1
0
forked from baron/baron-sso

chore: consolidate local integration changes

This commit is contained in:
2026-06-09 21:03:05 +09:00
parent aa2848c3b6
commit 1341f07ef9
158 changed files with 10995 additions and 1490 deletions

View File

@@ -56,6 +56,11 @@ func main() {
slog.Error("clear-orphan-user-tenant-memberships failed", "error", err)
os.Exit(1)
}
case "worksmobile-sync":
if err := runWorksmobileSync(os.Args[2:]); err != nil {
slog.Error("worksmobile-sync failed", "error", err)
os.Exit(1)
}
default:
printUsage()
os.Exit(2)
@@ -227,4 +232,5 @@ func printUsage() {
fmt.Fprintln(os.Stderr, "usage:")
fmt.Fprintln(os.Stderr, " adminctl create-super-admin [--email EMAIL] [--password PASSWORD] [--name NAME] [--update-password]")
fmt.Fprintln(os.Stderr, " adminctl clear-orphan-user-tenant-memberships [--dry-run]")
fmt.Fprintln(os.Stderr, " adminctl worksmobile-sync [--orgunits] [--users-csv PATH] [--credential-batch-id ID] [--process] [--serialize-orgunits] [--serialize-users-batch ID] [--batch-size N] [--delay DURATION]")
}

View File

@@ -1,6 +1,11 @@
package main
import "testing"
import (
"baron-sso-backend/internal/service"
"context"
"strings"
"testing"
)
func TestResolveCreateSuperAdminConfigUsesEnvDefaults(t *testing.T) {
t.Setenv("ADMIN_EMAIL", "admin@example.com")
@@ -71,3 +76,66 @@ func TestResolveClearOrphanUserTenantMembershipsConfig(t *testing.T) {
t.Fatal("dry-run flag was not set")
}
}
func TestAuditWorksmobileDuplicatePhoneCountryCodesReportsAndFixes(t *testing.T) {
client := &fakeWorksmobilePhoneAuditClient{
users: []service.WorksmobileRemoteUser{
{
ID: "works-user-1",
ExternalID: "baron-user-1",
Email: "one@example.com",
DisplayName: "One",
CellPhone: "+82 +821091917771",
DomainID: 1001,
DomainName: "samaneng.com",
},
{
ID: "works-user-2",
Email: "two@example.com",
CellPhone: "+821012345678",
DomainID: 1001,
},
},
}
output := &strings.Builder{}
count, err := auditWorksmobileDuplicatePhoneCountryCodes(context.Background(), output, true, client)
if err != nil {
t.Fatalf("auditWorksmobileDuplicatePhoneCountryCodes returned error: %v", err)
}
if count != 1 {
t.Fatalf("count=%d, want 1", count)
}
if !strings.Contains(output.String(), "one@example.com") || !strings.Contains(output.String(), "+821091917771") {
t.Fatalf("audit output did not include normalized duplicate phone row: %s", output.String())
}
if len(client.patches) != 1 {
t.Fatalf("patch count=%d, want 1", len(client.patches))
}
if client.patches[0].identifier != "works-user-1" {
t.Fatalf("patch identifier=%q, want works-user-1", client.patches[0].identifier)
}
if client.patches[0].payload.CellPhone != "+821091917771" {
t.Fatalf("patch cellPhone=%q, want +821091917771", client.patches[0].payload.CellPhone)
}
}
type fakeWorksmobilePhoneAuditClient struct {
users []service.WorksmobileRemoteUser
patches []fakeWorksmobilePhonePatch
}
type fakeWorksmobilePhonePatch struct {
identifier string
payload service.WorksmobileUserPatchPayload
}
func (f *fakeWorksmobilePhoneAuditClient) ListUsers(ctx context.Context) ([]service.WorksmobileRemoteUser, error) {
return f.users, nil
}
func (f *fakeWorksmobilePhoneAuditClient) PatchUser(ctx context.Context, identifier string, payload service.WorksmobileUserPatchPayload) error {
f.patches = append(f.patches, fakeWorksmobilePhonePatch{identifier: identifier, payload: payload})
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,38 @@
package main
import (
"baron-sso-backend/internal/service"
"testing"
)
func TestClassifyWorksmobileAlignFromWorksAllowsDomainOnlyEmailMismatch(t *testing.T) {
item := service.WorksmobileComparisonItem{
BaronEmail: "user@typo.example.com",
WorksmobileEmail: "user@example.com",
}
status, ok := classifyWorksmobileAlignFromWorks(item)
if !ok {
t.Fatalf("expected domain-only email mismatch to be alignable, status=%s", status)
}
if status != "updated" {
t.Fatalf("expected updated status, got %s", status)
}
}
func TestClassifyWorksmobileAlignFromWorksSkipsLocalPartChange(t *testing.T) {
item := service.WorksmobileComparisonItem{
BaronEmail: "old@example.com",
WorksmobileEmail: "new@example.com",
}
status, ok := classifyWorksmobileAlignFromWorks(item)
if ok {
t.Fatalf("expected local-part change to be skipped")
}
if status != "skipped_email_local_part_changed" {
t.Fatalf("expected skipped_email_local_part_changed status, got %s", status)
}
}

View File

@@ -4,11 +4,21 @@ import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"context"
"flag"
"fmt"
"log"
)
func main() {
dryRun := flag.Bool("dry-run", true, "변경 대상만 출력하고 Kratos identity를 수정하지 않습니다")
maintenanceWindow := flag.Bool("maintenance-window", false, "승인된 정비 시간에만 실제 변경을 허용합니다")
markMirrorStale := flag.Bool("mark-mirror-stale", false, "실행 전 Redis identity mirror를 stale로 표시했음을 확인합니다")
flag.Parse()
if !*dryRun && (!*maintenanceWindow || !*markMirrorStale) {
log.Fatal("refusing to update Kratos identities: pass --dry-run=false --maintenance-window --mark-mirror-stale after marking identity mirror stale")
}
kratosAdmin := service.NewKratosAdminService()
ctx := context.Background()
@@ -37,6 +47,11 @@ func main() {
}
if changed {
if *dryRun {
count++
fmt.Printf("Would update %s\n", id.ID)
continue
}
_, err := kratosAdmin.UpdateIdentity(ctx, id.ID, traits, id.State)
if err != nil {
log.Printf("Failed to update %s: %v", id.ID, err)
@@ -46,5 +61,10 @@ func main() {
}
}
}
fmt.Printf("Total updated: %d\n", count)
if *dryRun {
fmt.Printf("Total candidates: %d\n", count)
} else {
fmt.Printf("Total updated: %d\n", count)
fmt.Println("Identity mirror was marked stale before maintenance; run full mirror refresh and drift report before trusting cached user lists.")
}
}

View File

@@ -336,7 +336,12 @@ func main() {
)
configureWorksmobileClientFromEnv(worksmobileClient)
worksmobileService := service.NewWorksmobileSyncService(tenantService, userRepo, worksmobileOutboxRepo, worksmobileClient)
worksmobileRelayWorker := service.NewWorksmobileRelayWorker(worksmobileOutboxRepo, worksmobileClient)
worksmobileRelayClient := *worksmobileClient
worksmobileRelayClient.RateLimiter = service.NewWorksmobileAPIRateLimiter(240, time.Minute)
worksmobileRelayWorker := service.NewWorksmobileRelayWorker(worksmobileOutboxRepo, &worksmobileRelayClient)
if lock := service.NewWorksmobileRedisRelayLeaderLock(redisService); lock != nil {
worksmobileRelayWorker.SetLeaderLock(lock)
}
go worksmobileRelayWorker.Start(context.Background())
slog.Info("✅ Worksmobile Relay Worker started")
rpUsageEmitter := service.NewRPUsageEventEmitter(rpUsageOutboxRepo)
@@ -370,12 +375,13 @@ func main() {
authHandler.RPUserMetadataRepo = rpUserMetadataRepo
authHandler.RPUsageSink = rpUsageEmitter
adminHandler := handler.NewAdminHandler(ketoService, ketoOutboxRepo)
adminHandler.DB = db
adminHandler.RPUsageQueries = rpUsageQueryRepo
adminHandler.TenantRepo = tenantRepo
adminHandler.Hydra = hydraService
adminHandler.AuditRepo = auditRepo
adminHandler.UserProjectionRepo = userProjectionRepo
adminHandler.UserProjectionSyncer = userProjectionSyncer
adminHandler.IdentityCache = redisService
adminHandler.IntegrityChecker = repository.NewDataIntegrityChecker(db)
devHandler := handler.NewDevHandler(redisService, secretRepo, consentRepo, relyingPartyService, ketoService, ketoOutboxRepo, tenantService, developerService, authHandler)
devHandler.HeadlessJWKS = headlessJWKSCache
@@ -383,6 +389,7 @@ func main() {
devHandler.RPUserMetadataRepo = rpUserMetadataRepo
devHandler.RPUsageQueries = rpUsageQueryRepo
tenantHandler := handler.NewTenantHandler(db, tenantService, userRepo, userProjectionRepo, ketoService, ketoOutboxRepo, kratosAdminService, sharedLinkService, hydraService, consentRepo)
tenantHandler.OrgChartCache = redisService
userGroupHandler := handler.NewUserGroupHandler(userGroupService)
relyingPartyHandler := handler.NewRelyingPartyHandler(relyingPartyService, kratosAdminService)
userHandler := handler.NewUserHandler(kratosAdminService, oryAdminProvider, tenantService, ketoService, ketoOutboxRepo, userRepo, userGroupRepo, auditRepo)
@@ -718,12 +725,15 @@ func main() {
admin.Get("/integrity/orphan-user-login-ids", requireSuperAdmin, adminHandler.ListOrphanUserLoginIDs)
admin.Delete("/integrity/orphan-user-login-ids", requireSuperAdmin, adminHandler.DeleteOrphanUserLoginIDs)
admin.Get("/projections/users", requireSuperAdmin, adminHandler.GetUserProjectionStatus)
admin.Post("/projections/users/reconcile", requireSuperAdmin, adminHandler.ReconcileUserProjection)
admin.Post("/projections/users/reset", requireSuperAdmin, adminHandler.ResetUserProjection)
admin.Get("/ory/ssot", requireSuperAdmin, adminHandler.GetOrySSOTSystemStatus)
admin.Post("/ory/ssot/identity-cache/flush", requireSuperAdmin, adminHandler.FlushIdentityCache)
admin.Get("/rp-usage/daily", requireAdmin, adminHandler.GetRPUsageDaily)
admin.Get("/global-custom-claims", requireSuperAdmin, adminHandler.GetGlobalCustomClaimDefinitions)
admin.Put("/global-custom-claims", requireSuperAdmin, adminHandler.UpdateGlobalCustomClaimDefinitions)
// Tenant Management (Mixed roles, handler filters results)
admin.Get("/tenants", requireAnyUser, tenantHandler.ListTenants)
admin.Get("/orgchart/snapshot", requireAnyUser, tenantHandler.GetOrgChartSnapshot)
admin.Get("/tenants/export", requireSuperAdmin, tenantHandler.ExportTenantsCSV)
admin.Post("/tenants/import", requireSuperAdmin, tenantHandler.ImportTenantsCSV)
admin.Post("/tenants", requireSuperAdmin, tenantHandler.CreateTenant)

View File

@@ -63,6 +63,7 @@ func migrateSchemas(db *gorm.DB) error {
&domain.SharedLink{},
&domain.DeveloperRequest{},
&domain.RPUserMetadata{},
&domain.SystemSetting{},
// &domain.RelyingParty{}, // Removed: SSOT is Hydra + Keto
)
}

View File

@@ -0,0 +1,19 @@
package domain
import "time"
type IdentityCacheStatus struct {
Status string `json:"status"`
RedisReady bool `json:"redisReady"`
ObservedCount int64 `json:"observedCount"`
KeyCount int64 `json:"keyCount"`
LastRefreshedAt *time.Time `json:"lastRefreshedAt,omitempty"`
LastError string `json:"lastError,omitempty"`
UpdatedAt *time.Time `json:"updatedAt,omitempty"`
}
type IdentityCacheFlushResult struct {
Status string `json:"status"`
FlushedKeys int64 `json:"flushedKeys"`
UpdatedAt time.Time `json:"updatedAt"`
}

View File

@@ -0,0 +1,11 @@
package domain
import "time"
// SystemSetting stores small global configuration documents.
type SystemSetting struct {
Key string `gorm:"primaryKey;size:128" json:"key"`
Value JSONMap `gorm:"type:jsonb" json:"value"`
CreatedAt time.Time
UpdatedAt time.Time
}

View File

@@ -174,13 +174,7 @@ func ValidateLoginID(loginID string, emails []string, phone string) error {
}
if phone != "" {
normalizedPhone := strings.ReplaceAll(phone, "-", "")
normalizedPhone = strings.ReplaceAll(normalizedPhone, " ", "")
if strings.HasPrefix(normalizedPhone, "010") {
normalizedPhone = "+82" + normalizedPhone[1:]
} else if strings.HasPrefix(normalizedPhone, "82") {
normalizedPhone = "+" + normalizedPhone
}
normalizedPhone := NormalizePhoneNumber(phone)
if loginID == phone || loginID == normalizedPhone {
return fmt.Errorf("ID cannot be the same as the phone number")
@@ -211,3 +205,43 @@ func ValidateLoginID(loginID string, emails []string, phone string) error {
return nil
}
func NormalizePhoneNumber(phone string) string {
trimmed := strings.TrimSpace(phone)
if trimmed == "" {
return ""
}
hasLeadingPlus := false
digits := strings.Builder{}
for _, r := range trimmed {
switch {
case r >= '0' && r <= '9':
digits.WriteRune(r)
case r == '+' && digits.Len() == 0 && !hasLeadingPlus:
hasLeadingPlus = true
}
}
number := digits.String()
if number == "" {
return ""
}
if strings.HasPrefix(number, "010") {
return "+82" + number[1:]
}
if strings.HasPrefix(number, "82") {
rest := number[2:]
for strings.HasPrefix(rest, "82") {
rest = rest[2:]
}
if strings.HasPrefix(rest, "0") {
rest = rest[1:]
}
return "+82" + rest
}
if hasLeadingPlus {
return "+" + number
}
return number
}

View File

@@ -39,3 +39,26 @@ func TestValidateLoginID(t *testing.T) {
})
}
}
func TestNormalizePhoneNumberDeduplicatesKoreanCountryCode(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{"Local mobile", "010-9191-7771", "+821091917771"},
{"Korean country code", "+82 10-9191-7771", "+821091917771"},
{"Duplicate plus Korean country code", "+82 +821091917771", "+821091917771"},
{"Duplicate compact Korean country code", "+82821091917771", "+821091917771"},
{"Duplicate spaced Korean country code", "+82 8210 9191 7771", "+821091917771"},
{"Non Korean international phone preserved", "+1 914 481 2222", "+19144812222"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := NormalizePhoneNumber(tt.input); got != tt.want {
t.Fatalf("NormalizePhoneNumber(%q)=%q, want %q", tt.input, got, tt.want)
}
})
}
}

View File

@@ -11,22 +11,44 @@ import (
"time"
"github.com/gofiber/fiber/v2"
"gorm.io/gorm"
)
type adminHydraClientLister interface {
ListClients(ctx context.Context, limit, offset int) ([]domain.HydraClient, error)
}
type identityCacheAdmin interface {
GetIdentityCacheStatus(ctx context.Context) (domain.IdentityCacheStatus, error)
FlushIdentityCache(ctx context.Context) (domain.IdentityCacheFlushResult, error)
}
type AdminHandler struct {
Keto service.KetoService
KetoOutbox repository.KetoOutboxRepository
RPUsageQueries domain.RPUsageQueryRepository
TenantRepo repository.TenantRepository
Hydra adminHydraClientLister
AuditRepo domain.AuditRepository
UserProjectionRepo repository.UserProjectionRepository
UserProjectionSyncer service.UserProjectionReconciler
IntegrityChecker repository.DataIntegrityChecker
DB *gorm.DB
Keto service.KetoService
KetoOutbox repository.KetoOutboxRepository
RPUsageQueries domain.RPUsageQueryRepository
TenantRepo repository.TenantRepository
Hydra adminHydraClientLister
AuditRepo domain.AuditRepository
UserProjectionRepo repository.UserProjectionRepository
IdentityCache identityCacheAdmin
IntegrityChecker repository.DataIntegrityChecker
}
const globalCustomClaimsSettingKey = "global_custom_claim_definitions"
type globalCustomClaimDefinition struct {
Key string `json:"key"`
Label string `json:"label"`
ValueType string `json:"valueType"`
ReadPermission string `json:"readPermission"`
WritePermission string `json:"writePermission"`
Description string `json:"description,omitempty"`
}
type globalCustomClaimDefinitionsResponse struct {
Items []globalCustomClaimDefinition `json:"items"`
}
func NewAdminHandler(keto service.KetoService, ketoOutbox repository.KetoOutboxRepository) *AdminHandler {
@@ -110,6 +132,154 @@ func (h *AdminHandler) CheckAuth(c *fiber.Ctx) error {
return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "ok"})
}
func (h *AdminHandler) GetGlobalCustomClaimDefinitions(c *fiber.Ctx) error {
if h == nil || h.DB == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{
"error": "settings store unavailable",
})
}
var setting domain.SystemSetting
if err := h.DB.WithContext(c.Context()).First(&setting, "key = ?", globalCustomClaimsSettingKey).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return c.JSON(globalCustomClaimDefinitionsResponse{Items: []globalCustomClaimDefinition{}})
}
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(globalCustomClaimDefinitionsResponse{
Items: normalizeGlobalCustomClaimDefinitions(setting.Value["items"]),
})
}
func (h *AdminHandler) UpdateGlobalCustomClaimDefinitions(c *fiber.Ctx) error {
if h == nil || h.DB == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{
"error": "settings store unavailable",
})
}
var req globalCustomClaimDefinitionsResponse
if err := c.BodyParser(&req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid request body"})
}
items, err := validateGlobalCustomClaimDefinitions(req.Items)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
}
setting := domain.SystemSetting{
Key: globalCustomClaimsSettingKey,
Value: domain.JSONMap{"items": globalCustomClaimDefinitionsToJSON(items)},
}
if err := h.DB.WithContext(c.Context()).Save(&setting).Error; err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(globalCustomClaimDefinitionsResponse{Items: items})
}
func normalizeGlobalCustomClaimDefinitions(value any) []globalCustomClaimDefinition {
rawItems, ok := value.([]any)
if !ok {
return []globalCustomClaimDefinition{}
}
items := make([]globalCustomClaimDefinition, 0, len(rawItems))
for _, item := range rawItems {
raw, ok := item.(map[string]any)
if !ok {
continue
}
def := globalCustomClaimDefinition{
Key: strings.TrimSpace(stringValue(raw["key"])),
Label: strings.TrimSpace(stringValue(raw["label"])),
ValueType: normalizeGlobalCustomClaimType(stringValue(raw["valueType"])),
ReadPermission: adminNormalizeCustomClaimPermission(stringValue(raw["readPermission"])),
WritePermission: adminNormalizeCustomClaimPermission(stringValue(raw["writePermission"])),
Description: strings.TrimSpace(stringValue(raw["description"])),
}
if def.Key != "" {
items = append(items, def)
}
}
return items
}
func validateGlobalCustomClaimDefinitions(items []globalCustomClaimDefinition) ([]globalCustomClaimDefinition, error) {
seen := map[string]struct{}{}
normalized := make([]globalCustomClaimDefinition, 0, len(items))
for _, item := range items {
key := strings.TrimSpace(item.Key)
if key == "" {
continue
}
if !isValidCustomClaimKey(key) {
return nil, fiber.NewError(fiber.StatusBadRequest, "claim key must use letters, numbers, underscore, dot, or hyphen")
}
if _, exists := seen[key]; exists {
return nil, fiber.NewError(fiber.StatusBadRequest, "duplicate claim key: "+key)
}
seen[key] = struct{}{}
normalized = append(normalized, globalCustomClaimDefinition{
Key: key,
Label: strings.TrimSpace(item.Label),
ValueType: normalizeGlobalCustomClaimType(item.ValueType),
ReadPermission: adminNormalizeCustomClaimPermission(item.ReadPermission),
WritePermission: adminNormalizeCustomClaimPermission(item.WritePermission),
Description: strings.TrimSpace(item.Description),
})
}
return normalized, nil
}
func globalCustomClaimDefinitionsToJSON(items []globalCustomClaimDefinition) []any {
values := make([]any, 0, len(items))
for _, item := range items {
values = append(values, map[string]any{
"key": item.Key,
"label": item.Label,
"valueType": item.ValueType,
"readPermission": item.ReadPermission,
"writePermission": item.WritePermission,
"description": item.Description,
})
}
return values
}
func normalizeGlobalCustomClaimType(value string) string {
switch strings.ToLower(strings.TrimSpace(value)) {
case "number", "boolean", "array", "object", "date", "datetime":
return strings.ToLower(strings.TrimSpace(value))
default:
return "text"
}
}
func adminNormalizeCustomClaimPermission(value string) string {
if strings.TrimSpace(value) == "user_and_admin" {
return "user_and_admin"
}
return "admin_only"
}
func isValidCustomClaimKey(value string) bool {
for _, r := range value {
if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '_' || r == '-' || r == '.' {
continue
}
return false
}
return true
}
func stringValue(value any) string {
if text, ok := value.(string); ok {
return text
}
return ""
}
func requireSuperAdminProfile(c *fiber.Ctx) bool {
profile, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
if profile == nil || domain.NormalizeRole(profile.Role) != domain.RoleSuperAdmin {
@@ -133,26 +303,48 @@ func (h *AdminHandler) GetUserProjectionStatus(c *fiber.Ctx) error {
return c.JSON(status)
}
func (h *AdminHandler) ReconcileUserProjection(c *fiber.Ctx) error {
func (h *AdminHandler) GetOrySSOTSystemStatus(c *fiber.Ctx) error {
if !requireSuperAdminProfile(c) {
return nil
}
if h == nil || h.UserProjectionSyncer == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "user projection sync service unavailable"})
if h == nil || h.UserProjectionRepo == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "user projection service unavailable"})
}
count, err := h.UserProjectionSyncer.Reconcile(c.Context())
projectionStatus, err := h.UserProjectionRepo.GetStatus(c.Context())
if err != nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": err.Error()})
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
cacheStatus := domain.IdentityCacheStatus{
Status: "unavailable",
RedisReady: false,
LastError: "identity cache service unavailable",
}
if h.IdentityCache != nil {
cacheStatus, err = h.IdentityCache.GetIdentityCacheStatus(c.Context())
if err != nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": err.Error()})
}
}
return c.JSON(fiber.Map{
"status": "success",
"syncedUsers": count,
"updatedAt": time.Now().UTC().Format(time.RFC3339),
"userProjection": projectionStatus,
"identityCache": cacheStatus,
})
}
func (h *AdminHandler) ResetUserProjection(c *fiber.Ctx) error {
return h.ReconcileUserProjection(c)
func (h *AdminHandler) FlushIdentityCache(c *fiber.Ctx) error {
if !requireSuperAdminProfile(c) {
return nil
}
if h == nil || h.IdentityCache == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "identity cache service unavailable"})
}
result, err := h.IdentityCache.FlushIdentityCache(c.Context())
if err != nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(result)
}
func (h *AdminHandler) GetDataIntegrity(c *fiber.Ctx) error {

View File

@@ -5,7 +5,6 @@ import (
"baron-sso-backend/internal/service"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
@@ -78,6 +77,10 @@ func (f *fakeAdminUserProjectionRepo) CountTenantMembers(ctx context.Context, te
return nil, nil
}
func (f *fakeAdminUserProjectionRepo) CountTenantMembersRecursive(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
return nil, nil
}
func (f *fakeAdminUserProjectionRepo) ReplaceAllFromKratos(ctx context.Context, users []domain.User) error {
return nil
}
@@ -90,15 +93,22 @@ func (f *fakeAdminUserProjectionRepo) GetStatus(ctx context.Context) (domain.Use
return f.status, nil
}
type fakeAdminUserProjectionSyncer struct {
count int
err error
calls int
type fakeIdentityCacheAdmin struct {
status domain.IdentityCacheStatus
flush domain.IdentityCacheFlushResult
err error
statusHit int
flushCalls int
}
func (f *fakeAdminUserProjectionSyncer) Reconcile(ctx context.Context) (int, error) {
f.calls++
return f.count, f.err
func (f *fakeIdentityCacheAdmin) GetIdentityCacheStatus(ctx context.Context) (domain.IdentityCacheStatus, error) {
f.statusHit++
return f.status, f.err
}
func (f *fakeIdentityCacheAdmin) FlushIdentityCache(ctx context.Context) (domain.IdentityCacheFlushResult, error) {
f.flushCalls++
return f.flush, f.err
}
func TestAdminHandler_GetRPUsageDaily(t *testing.T) {
@@ -199,42 +209,81 @@ func TestAdminHandler_UserProjectionStatusReturnsProjectionStateForSuperAdmin(t
require.Equal(t, int64(152), body.ProjectedUsers)
}
func TestAdminHandler_ReconcileUserProjectionRequiresSuperAdminAndRunsSyncer(t *testing.T) {
syncer := &fakeAdminUserProjectionSyncer{count: 4}
h := &AdminHandler{UserProjectionSyncer: syncer}
func TestAdminHandler_GetOrySSOTSystemStatusReturnsProjectionAndIdentityCache(t *testing.T) {
syncedAt := time.Date(2026, 5, 11, 3, 0, 0, 0, time.UTC)
cache := &fakeIdentityCacheAdmin{
status: domain.IdentityCacheStatus{
Status: "ready",
RedisReady: true,
ObservedCount: 151,
KeyCount: 153,
LastRefreshedAt: &syncedAt,
UpdatedAt: &syncedAt,
},
}
h := &AdminHandler{
UserProjectionRepo: &fakeAdminUserProjectionRepo{
status: domain.UserProjectionStatus{
Name: domain.UserProjectionNameKratos,
Status: domain.UserProjectionStatusReady,
Ready: true,
LastSyncedAt: &syncedAt,
ProjectedUsers: 152,
},
},
IdentityCache: cache,
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Post("/api/v1/admin/projections/users/reconcile", h.ReconcileUserProjection)
app.Get("/api/v1/admin/ory/ssot", h.GetOrySSOTSystemStatus)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/projections/users/reconcile", nil)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/ory/ssot", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, 1, syncer.calls)
var body map[string]any
var body struct {
UserProjection domain.UserProjectionStatus `json:"userProjection"`
IdentityCache domain.IdentityCacheStatus `json:"identityCache"`
}
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, "success", body["status"])
require.Equal(t, float64(4), body["syncedUsers"])
require.Equal(t, int64(152), body.UserProjection.ProjectedUsers)
require.True(t, body.IdentityCache.RedisReady)
require.Equal(t, int64(151), body.IdentityCache.ObservedCount)
require.Equal(t, int64(153), body.IdentityCache.KeyCount)
require.Equal(t, 1, cache.statusHit)
}
func TestAdminHandler_ReconcileUserProjectionReturnsServiceUnavailableOnSyncFailure(t *testing.T) {
syncer := &fakeAdminUserProjectionSyncer{err: errors.New("kratos down")}
h := &AdminHandler{UserProjectionSyncer: syncer}
func TestAdminHandler_FlushIdentityCacheRequiresSuperAdminAndFlushesCacheOnly(t *testing.T) {
cache := &fakeIdentityCacheAdmin{
flush: domain.IdentityCacheFlushResult{
Status: "success",
FlushedKeys: 7,
UpdatedAt: time.Date(2026, 5, 11, 3, 2, 0, 0, time.UTC),
},
}
h := &AdminHandler{
IdentityCache: cache,
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Post("/api/v1/admin/projections/users/reconcile", h.ReconcileUserProjection)
app.Post("/api/v1/admin/ory/ssot/identity-cache/flush", h.FlushIdentityCache)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/projections/users/reconcile", nil)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/ory/ssot/identity-cache/flush", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
require.Equal(t, http.StatusOK, resp.StatusCode)
var body domain.IdentityCacheFlushResult
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, int64(7), body.FlushedKeys)
require.Equal(t, 1, cache.flushCalls)
}
func TestAdminHandler_GetRPUsageDailyChecksTenantPermission(t *testing.T) {

View File

@@ -776,13 +776,7 @@ func (h *AuthHandler) Signup(c *fiber.Ctx) error {
}
// Normalize Phone (E.164 형태로 보관)
normalizedPhone := strings.ReplaceAll(req.Phone, "-", "")
normalizedPhone = strings.ReplaceAll(normalizedPhone, " ", "")
if strings.HasPrefix(normalizedPhone, "010") {
normalizedPhone = "+82" + normalizedPhone[1:]
} else if strings.HasPrefix(normalizedPhone, "82") {
normalizedPhone = "+" + normalizedPhone
}
normalizedPhone := domain.NormalizePhoneNumber(req.Phone)
slog.Info("[Signup] Phone normalization", "raw", req.Phone, "normalized", normalizedPhone)
@@ -1092,15 +1086,7 @@ func (h *AuthHandler) GetTenantInfo(c *fiber.Ctx) error {
// normalizePhoneForLoginID는 전화번호를 IDP 조회에 적합한 형태(E.164)로 정규화합니다.
func normalizePhoneForLoginID(phone string) string {
normalized := strings.ReplaceAll(phone, "-", "")
normalized = strings.ReplaceAll(normalized, " ", "")
if strings.HasPrefix(normalized, "010") {
return "+82" + normalized[1:]
}
if strings.HasPrefix(normalized, "82") {
return "+" + normalized
}
return normalized
return domain.NormalizePhoneNumber(phone)
}
func buildOidcClaimsFromTraits(traits map[string]any, scopes []string, tenantID string) map[string]any {
@@ -1226,7 +1212,7 @@ func buildOidcClaimsFromTraits(traits map[string]any, scopes []string, tenantID
// Heuristic: if a trait value is a map, it's treated as namespaced metadata for a tenant
for k, v := range traits {
if k == "metadata" {
if k == "metadata" || k == "global_custom_claims" || k == "global_custom_claim_types" || k == "global_custom_claim_permissions" {
continue
}
if m, ok := v.(map[string]any); ok {
@@ -1242,7 +1228,7 @@ func buildOidcClaimsFromTraits(traits map[string]any, scopes []string, tenantID
claims["tenants"] = allTenants
}
return claims
return applyGlobalCustomClaims(claims, traits)
}
func withOidcSessionMetadata(claims map[string]any, sessionID string) map[string]any {
@@ -1263,6 +1249,39 @@ func composeOIDCSessionClaims(client domain.HydraClient, traits map[string]any,
return withOidcSessionMetadata(claims, sessionID)
}
func applyGlobalCustomClaims(baseClaims map[string]any, traits map[string]any) map[string]any {
if baseClaims == nil {
baseClaims = map[string]any{}
}
if traits == nil {
return baseClaims
}
rawClaims, ok := traits["global_custom_claims"]
if !ok || rawClaims == nil {
return baseClaims
}
customClaims, ok := rawClaims.(map[string]any)
if !ok {
return baseClaims
}
for key, value := range customClaims {
key = strings.TrimSpace(key)
if key == "" || value == nil {
continue
}
if key == "rp_claims" || key == "rp_profiles" {
continue
}
if _, exists := baseClaims[key]; exists {
continue
}
baseClaims[key] = value
}
return baseClaims
}
func (h *AuthHandler) withHanmacFamilyTenantClaims(ctx context.Context, claims map[string]any, traits map[string]any, scopes []string) map[string]any {
if claims == nil {
claims = map[string]any{}
@@ -4666,7 +4685,7 @@ func extractFirstString(data map[string]any, keys ...string) string {
}
func sanitizePhoneForSms(phone string) string {
sanitized := strings.TrimSpace(phone)
sanitized := domain.NormalizePhoneNumber(phone)
if strings.HasPrefix(sanitized, "+82") {
sanitized = "0" + sanitized[3:]
}
@@ -4685,11 +4704,7 @@ func (h *AuthHandler) formatPhoneForDisplay(phone string) string {
}
func (h *AuthHandler) formatPhoneForStorage(phone string) string {
phone = strings.ReplaceAll(phone, "-", "")
if strings.HasPrefix(phone, "010") && len(phone) == 11 {
return "+8210" + phone[3:]
}
return phone
return domain.NormalizePhoneNumber(phone)
}
// GetMe - Returns current user's profile with enriched data from local DB
@@ -5920,6 +5935,12 @@ func (h *AuthHandler) RevokeLinkedRp(c *fiber.Ctx) error {
slog.Error("failed to revoke hydra consent sessions", "error", err)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to revoke link")
}
if h.ConsentRepo != nil {
if err := h.ConsentRepo.Delete(c.Context(), subject, clientID); err != nil {
slog.Error("failed to delete local consent after hydra revoke", "error", err, "subject", subject, "client_id", clientID)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to revoke local consent")
}
}
if h.AuditRepo != nil {
detailsMap := map[string]any{
@@ -7611,35 +7632,6 @@ func (h *AuthHandler) getKratosSessionIDWithCookie(cookie string) (string, error
return result.ID, nil
}
func (h *AuthHandler) updateKratosIdentity(identityID string, traits map[string]any) error {
kratosAdminURL := strings.TrimRight(os.Getenv("KRATOS_ADMIN_URL"), "/")
if kratosAdminURL == "" {
kratosAdminURL = "http://kratos:4434"
}
payload := map[string]any{
"schema_id": "default",
"traits": traits,
}
body, _ := json.Marshal(payload)
req, err := http.NewRequestWithContext(context.Background(), http.MethodPut, fmt.Sprintf("%s/admin/identities/%s", kratosAdminURL, identityID), bytes.NewReader(body))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return fmt.Errorf("kratos admin update failed status=%d body=%s", resp.StatusCode, string(respBody))
}
return nil
}
func (h *AuthHandler) getHydraProfile(ctx context.Context, token string) (*domain.UserProfileResponse, error) {
intro, err := h.Hydra.IntrospectToken(ctx, token)
if err != nil {
@@ -7952,10 +7944,17 @@ func (h *AuthHandler) UpdateMe(c *fiber.Ctx) error {
}
}
if err := h.updateKratosIdentity(identityID, traits); err != nil {
if h.KratosAdmin == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "identity provider not available")
}
updatedIdentity, err := h.KratosAdmin.UpdateIdentity(c.Context(), identityID, traits, "")
if err != nil {
slog.Error("Failed to update profile in Kratos", "error", err)
return errorJSON(c, fiber.StatusInternalServerError, "프로필 업데이트에 실패했습니다.")
}
if updatedIdentity != nil && updatedIdentity.Traits != nil {
traits = updatedIdentity.Traits
}
// [New] Local DB Sync - Sync synchronously to ensure immediate consistency
if h.UserRepo != nil {

View File

@@ -28,6 +28,8 @@ func TestRevokeLinkedRp_Success(t *testing.T) {
}
// 2. Hydra Revoke
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
assert.Equal(t, "user-123", r.URL.Query().Get("subject"))
assert.Equal(t, "app-1", r.URL.Query().Get("client"))
return httpResponse(r, http.StatusNoContent, ""), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
@@ -40,12 +42,22 @@ func TestRevokeLinkedRp_Success(t *testing.T) {
auditRepo := &mockAuditRepo{}
rpUsageSink := &mockRPUsageEventSink{}
consentRepo := &mockConsentRepo{
consents: []domain.ClientConsent{
{
ClientID: "app-1",
Subject: "user-123",
GrantedScopes: []string{"openid", "profile"},
},
},
}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
AuditRepo: auditRepo,
ConsentRepo: consentRepo,
RPUsageSink: rpUsageSink,
}
app := fiber.New()
@@ -67,6 +79,9 @@ func TestRevokeLinkedRp_Success(t *testing.T) {
assert.Equal(t, domain.RPUsageEventTypeAuthorizationRevoked, rpUsageSink.events[0].EventType)
assert.Equal(t, "user-123", rpUsageSink.events[0].Subject)
assert.Equal(t, "app-1", rpUsageSink.events[0].ClientID)
remaining, err := consentRepo.Find(req.Context(), "app-1", "user-123")
assert.NoError(t, err)
assert.Nil(t, remaining)
}
func TestRevokeLinkedRp_SendsBackchannelLogoutTokenWhenConfigured(t *testing.T) {

View File

@@ -696,6 +696,31 @@ func TestGetConsentRequest_Skip_DynamicClaims(t *testing.T) {
assert.Equal(t, "Officer", capturedClaims["position"])
}
func TestBuildOidcClaimsFromTraits_IncludesGlobalCustomClaims(t *testing.T) {
claims := buildOidcClaimsFromTraits(map[string]any{
"email": "user@test.com",
"name": "Test User",
"global_custom_claims": map[string]any{
"contract_date": "2026-06-09",
"approved_at": "2026-06-09T09:30:00+09:00",
"email": "override@test.com",
"rp_claims": "reserved",
},
"global_custom_claim_permissions": map[string]any{
"contract_date": map[string]any{
"readPermission": "user_and_admin",
"writePermission": "admin_only",
},
},
}, []string{"openid", "profile", "email"}, "")
assert.Equal(t, "2026-06-09", claims["contract_date"])
assert.Equal(t, "2026-06-09T09:30:00+09:00", claims["approved_at"])
assert.Equal(t, "user@test.com", claims["email"])
assert.NotEqual(t, "reserved", claims["rp_claims"])
assert.NotContains(t, claims, "global_custom_claim_permissions")
}
func TestAcceptConsentRequest_AppliesConfiguredIDTokenClaims(t *testing.T) {
var capturedClaims map[string]any

View File

@@ -2,6 +2,7 @@ package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"context"
"encoding/json"
@@ -31,6 +32,28 @@ func (r *recordingUpdateMeUserRepo) UpdateUserLoginIDs(ctx context.Context, user
return nil
}
type recordingUpdateMeKratosAdmin struct {
MockKratosAdminService
updatedIdentityID string
updatedTraits map[string]any
updatedState string
storedTraits map[string]any
}
func (r *recordingUpdateMeKratosAdmin) UpdateIdentity(ctx context.Context, identityID string, traits map[string]any, state string) (*service.KratosIdentity, error) {
r.updatedIdentityID = identityID
r.updatedTraits = maps.Clone(traits)
r.updatedState = state
if r.storedTraits != nil {
maps.Copy(r.storedTraits, traits)
}
return &service.KratosIdentity{
ID: identityID,
Traits: traits,
State: state,
}, nil
}
func TestUpdateMe_InvalidatesProfileCacheForTokenSession(t *testing.T) {
token := "token-abc"
identityID := "user-1"
@@ -79,8 +102,10 @@ func TestUpdateMe_InvalidatesProfileCacheForTokenSession(t *testing.T) {
t.Setenv("KRATOS_ADMIN_URL", "http://kratos.test")
redis := &mockRedisRepo{data: make(map[string]string)}
kratosAdmin := &recordingUpdateMeKratosAdmin{storedTraits: traits}
h := &AuthHandler{
RedisService: redis,
KratosAdmin: kratosAdmin,
}
app := fiber.New()
app.Get("/api/v1/user/me", h.GetMe)
@@ -113,6 +138,8 @@ func TestUpdateMe_InvalidatesProfileCacheForTokenSession(t *testing.T) {
require.NoError(t, err)
require.Equal(t, http.StatusOK, updateResp.StatusCode)
require.Equal(t, "New Dept", traits["department"])
require.Equal(t, identityID, kratosAdmin.updatedIdentityID)
require.Equal(t, "New Dept", kratosAdmin.updatedTraits["department"])
// 3) 새로고침 재조회 시 New Dept가 보여야 함(캐시 무효화 회귀 방지)
getReq2 := httptest.NewRequest(http.MethodGet, "/api/v1/user/me", nil)
@@ -177,9 +204,11 @@ func TestUpdateMe_SyncsLocalReadModelFields(t *testing.T) {
"verify_update_phone:" + identityID + ":+821087654321": "verified",
}}
userRepo := &recordingUpdateMeUserRepo{}
kratosAdmin := &recordingUpdateMeKratosAdmin{storedTraits: traits}
h := &AuthHandler{
RedisService: redis,
UserRepo: userRepo,
KratosAdmin: kratosAdmin,
}
app := fiber.New()
app.Put("/api/v1/user/me", h.UpdateMe)
@@ -199,6 +228,9 @@ func TestUpdateMe_SyncsLocalReadModelFields(t *testing.T) {
updateResp, err := app.Test(updateReq, -1)
require.NoError(t, err)
require.Equal(t, http.StatusOK, updateResp.StatusCode)
require.Equal(t, identityID, kratosAdmin.updatedIdentityID)
require.Equal(t, "New Name", kratosAdmin.updatedTraits["name"])
require.Equal(t, "+821087654321", kratosAdmin.updatedTraits["phone_number"])
require.NotNil(t, userRepo.updated)
require.Equal(t, identityID, userRepo.updated.ID)

View File

@@ -196,7 +196,17 @@ func (m *mockConsentRepo) Find(ctx context.Context, clientID, subject string) (*
return nil, nil
}
func (m *mockConsentRepo) Delete(ctx context.Context, subject, clientID string) error { return nil }
func (m *mockConsentRepo) Delete(ctx context.Context, subject, clientID string) error {
filtered := m.consents[:0]
for _, consent := range m.consents {
if consent.Subject == subject && (clientID == "" || consent.ClientID == clientID) {
continue
}
filtered = append(filtered, consent)
}
m.consents = filtered
return nil
}
func (m *mockConsentRepo) DeleteByClient(ctx context.Context, clientID string) error {
filtered := m.consents[:0]

View File

@@ -176,17 +176,18 @@ type clientRelationUpsertRequest struct {
}
type consentSummary struct {
Subject string `json:"subject"`
UserName string `json:"userName,omitempty"`
ClientID string `json:"clientId"`
ClientName string `json:"clientName,omitempty"`
GrantedScopes []string `json:"grantedScopes"`
AuthenticatedAt string `json:"authenticatedAt,omitempty"`
CreatedAt time.Time `json:"createdAt"`
DeletedAt *time.Time `json:"deletedAt,omitempty"`
Status string `json:"status"`
TenantID string `json:"tenantId,omitempty"`
TenantName string `json:"tenantName,omitempty"`
Subject string `json:"subject"`
UserName string `json:"userName,omitempty"`
ClientID string `json:"clientId"`
ClientName string `json:"clientName,omitempty"`
GrantedScopes []string `json:"grantedScopes"`
AuthenticatedAt string `json:"authenticatedAt,omitempty"`
CreatedAt time.Time `json:"createdAt"`
DeletedAt *time.Time `json:"deletedAt,omitempty"`
Status string `json:"status"`
TenantID string `json:"tenantId,omitempty"`
TenantName string `json:"tenantName,omitempty"`
RPMetadata domain.JSONMap `json:"rpMetadata,omitempty"`
}
type consentListResponse struct {
@@ -217,10 +218,12 @@ type clientUpsertRequest struct {
}
type normalizedIDTokenClaim struct {
Namespace string `json:"namespace"`
Key string `json:"key"`
Value string `json:"value"`
ValueType string `json:"valueType"`
Namespace string `json:"namespace"`
Key string `json:"key"`
Value string `json:"value"`
ValueType string `json:"valueType"`
ReadPermission string `json:"readPermission"`
WritePermission string `json:"writePermission"`
}
var protectedSystemClientIDs = map[string]struct{}{
@@ -1535,19 +1538,202 @@ func (h *DevHandler) UpsertRPUserMetadata(c *fiber.Ctx) error {
if req.Metadata == nil {
req.Metadata = map[string]any{}
}
normalizedMetadata, err := normalizeRPUserMetadataForClient(req.Metadata, summary.Metadata)
if err != nil {
return errorJSON(c, fiber.StatusBadRequest, err.Error())
}
row := &domain.RPUserMetadata{
ClientID: clientID,
UserID: userID,
Metadata: domain.JSONMap(req.Metadata),
Metadata: normalizedMetadata,
}
if err := h.RPUserMetadataRepo.Upsert(c.Context(), row); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
if err := h.syncRPUserMetadataToKratos(c.Context(), userID, clientID, normalizedMetadata); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.JSON(row)
}
func (h *DevHandler) syncRPUserMetadataToKratos(ctx context.Context, userID string, clientID string, metadata domain.JSONMap) error {
if h == nil || h.KratosAdmin == nil {
return nil
}
identity, err := h.KratosAdmin.GetIdentity(ctx, userID)
if err != nil {
return fmt.Errorf("failed to load kratos identity for rp user metadata: %w", err)
}
if identity == nil {
return errors.New("kratos identity not found for rp user metadata")
}
traits := identity.Traits
if traits == nil {
traits = map[string]any{}
}
rawRPClaims, _ := traits["rp_custom_claims"].(map[string]any)
if rawRPClaims == nil {
rawRPClaims = map[string]any{}
}
rawRPClaims[clientID] = metadata
traits["rp_custom_claims"] = rawRPClaims
_, err = h.KratosAdmin.UpdateIdentity(ctx, identity.ID, traits, identity.State)
if err != nil {
return fmt.Errorf("failed to update kratos rp user metadata: %w", err)
}
return nil
}
type rpUserMetadataClaimSchema struct {
Key string
ValueType string
ReadPermission string
WritePermission string
}
func normalizeCustomClaimPermission(value any) string {
permission := strings.TrimSpace(readInterfaceString(value, ""))
switch permission {
case "user_and_admin":
return "user_and_admin"
default:
return "admin_only"
}
}
func normalizeCustomClaimPermissions(value any, fallbackRead string, fallbackWrite string) map[string]any {
var record map[string]any
switch typed := value.(type) {
case map[string]any:
record = typed
case domain.JSONMap:
record = map[string]any(typed)
}
return map[string]any{
"readPermission": normalizeCustomClaimPermission(readMapValueOrFallback(record, "readPermission", fallbackRead)),
"writePermission": normalizeCustomClaimPermission(readMapValueOrFallback(record, "writePermission", fallbackWrite)),
}
}
func readMapValueOrFallback(values map[string]any, key string, fallback string) any {
if values == nil {
return fallback
}
if value, ok := values[key]; ok {
return value
}
return fallback
}
func normalizeRPUserMetadataForClient(metadata map[string]any, clientMetadata map[string]any) (domain.JSONMap, error) {
schemas, err := rpUserMetadataClaimSchemas(clientMetadata)
if err != nil {
return nil, err
}
normalized := domain.JSONMap{}
for rawKey, rawValue := range metadata {
key := strings.TrimSpace(rawKey)
if key == "" || isEmptyRPUserMetadataValue(rawValue) {
continue
}
if strings.HasSuffix(key, "_permissions") {
claimKey := strings.TrimSuffix(key, "_permissions")
schema, ok := schemas[claimKey]
if !ok {
return nil, fmt.Errorf("rp user metadata claim is not configured: %s", claimKey)
}
normalized[key] = normalizeCustomClaimPermissions(rawValue, schema.ReadPermission, schema.WritePermission)
continue
}
schema, ok := schemas[key]
if !ok {
return nil, fmt.Errorf("rp user metadata claim is not configured: %s", key)
}
textValue, err := stringifyRPUserMetadataValue(rawValue)
if err != nil {
return nil, fmt.Errorf("rp user metadata %s is invalid: %w", key, err)
}
parsed, err := parseConfiguredClaimValue(textValue, schema.ValueType)
if err != nil {
return nil, fmt.Errorf("rp user metadata %s is invalid: %w", key, err)
}
normalized[key] = parsed
permissionKey := key + "_permissions"
if _, exists := normalized[permissionKey]; !exists {
normalized[permissionKey] = map[string]any{
"readPermission": schema.ReadPermission,
"writePermission": schema.WritePermission,
}
}
}
return normalized, nil
}
func rpUserMetadataClaimSchemas(clientMetadata map[string]any) (map[string]rpUserMetadataClaimSchema, error) {
rawClaims, ok := clientMetadata[domain.MetadataIDTokenClaims]
if !ok || rawClaims == nil {
return map[string]rpUserMetadataClaimSchema{}, nil
}
claims, err := normalizeIDTokenClaimsForDevConsole(rawClaims)
if err != nil {
return nil, err
}
schemas := make(map[string]rpUserMetadataClaimSchema, len(claims))
for _, claim := range claims {
if claim.Namespace != "rp_claims" {
continue
}
schemas[claim.Key] = rpUserMetadataClaimSchema{
Key: claim.Key,
ValueType: claim.ValueType,
ReadPermission: claim.ReadPermission,
WritePermission: claim.WritePermission,
}
}
return schemas, nil
}
func isEmptyRPUserMetadataValue(value any) bool {
if value == nil {
return true
}
if text, ok := value.(string); ok {
return strings.TrimSpace(text) == ""
}
return false
}
func stringifyRPUserMetadataValue(value any) (string, error) {
switch typed := value.(type) {
case string:
return strings.TrimSpace(typed), nil
case bool:
return strconv.FormatBool(typed), nil
case float64:
return strconv.FormatFloat(typed, 'f', -1, 64), nil
case float32:
return strconv.FormatFloat(float64(typed), 'f', -1, 32), nil
case int:
return strconv.Itoa(typed), nil
case int64:
return strconv.FormatInt(typed, 10), nil
case int32:
return strconv.FormatInt(int64(typed), 10), nil
case json.Number:
return typed.String(), nil
default:
data, err := json.Marshal(value)
if err != nil {
return "", err
}
return string(data), nil
}
}
func (h *DevHandler) syncHeadlessJWKSCache(ctx context.Context, client domain.HydraClient, reason string) {
if h.HeadlessJWKS == nil {
h.HeadlessJWKS = service.NewHeadlessJWKSCacheService(h.Redis, nil)
@@ -2262,6 +2448,13 @@ func (h *DevHandler) ListConsents(c *fiber.Ctx) error {
}
}
var rpMetadata domain.JSONMap
if h.RPUserMetadataRepo != nil {
if row, err := h.RPUserMetadataRepo.Get(c.Context(), consent.ClientID, consent.Subject); err == nil && row != nil && len(row.Metadata) > 0 {
rpMetadata = row.Metadata
}
}
items = append(items, consentSummary{
Subject: consent.Subject,
UserName: userName,
@@ -2273,6 +2466,7 @@ func (h *DevHandler) ListConsents(c *fiber.Ctx) error {
Status: status,
TenantID: consent.TenantID,
TenantName: consent.TenantName,
RPMetadata: rpMetadata,
})
}
@@ -3107,7 +3301,7 @@ func normalizeIDTokenClaimsMetadata(metadata map[string]any) (map[string]any, er
return metadata, nil
}
normalized, err := normalizeIDTokenClaims(rawClaims)
normalized, err := normalizeIDTokenClaimsForDevConsole(rawClaims)
if err != nil {
return nil, err
}
@@ -3116,6 +3310,14 @@ func normalizeIDTokenClaimsMetadata(metadata map[string]any) (map[string]any, er
}
func normalizeIDTokenClaims(rawClaims any) ([]normalizedIDTokenClaim, error) {
return normalizeIDTokenClaimsWithOptions(rawClaims, true)
}
func normalizeIDTokenClaimsForDevConsole(rawClaims any) ([]normalizedIDTokenClaim, error) {
return normalizeIDTokenClaimsWithOptions(rawClaims, false)
}
func normalizeIDTokenClaimsWithOptions(rawClaims any, allowTopLevel bool) ([]normalizedIDTokenClaim, error) {
rawList, ok := rawClaims.([]any)
if !ok {
if typedList, ok := rawClaims.([]map[string]any); ok {
@@ -3154,6 +3356,9 @@ func normalizeIDTokenClaims(rawClaims any) ([]normalizedIDTokenClaim, error) {
if namespace != "top_level" && namespace != "rp_claims" {
return nil, fmt.Errorf("metadata.id_token_claims namespace must be top_level or rp_claims: %s", namespace)
}
if !allowTopLevel && namespace == "top_level" {
return nil, errors.New("metadata.id_token_claims top_level namespace is managed from admin user custom claims")
}
key := strings.TrimSpace(readInterfaceString(record["key"], ""))
if key == "" {
@@ -3168,7 +3373,7 @@ func normalizeIDTokenClaims(rawClaims any) ([]normalizedIDTokenClaim, error) {
valueType = "text"
}
switch valueType {
case "text", "number", "boolean", "array", "object":
case "text", "number", "boolean", "array", "object", "date", "datetime":
default:
return nil, fmt.Errorf("metadata.id_token_claims valueType is invalid: %s", valueType)
}
@@ -3185,10 +3390,12 @@ func normalizeIDTokenClaims(rawClaims any) ([]normalizedIDTokenClaim, error) {
seen[signature] = struct{}{}
normalized = append(normalized, normalizedIDTokenClaim{
Namespace: namespace,
Key: key,
Value: value,
ValueType: valueType,
Namespace: namespace,
Key: key,
Value: value,
ValueType: valueType,
ReadPermission: normalizeCustomClaimPermission(record["readPermission"]),
WritePermission: normalizeCustomClaimPermission(record["writePermission"]),
})
}
@@ -3258,6 +3465,25 @@ func parseConfiguredClaimValue(rawValue string, valueType string) (any, error) {
return nil, errors.New("object value must be valid JSON object")
}
return parsed, nil
case "date":
if trimmed == "" {
return nil, errors.New("date value is required")
}
if _, err := time.Parse("2006-01-02", trimmed); err != nil {
return nil, errors.New("date value must use YYYY-MM-DD")
}
return trimmed, nil
case "datetime":
if trimmed == "" {
return nil, errors.New("datetime value is required")
}
if _, err := time.Parse(time.RFC3339, trimmed); err == nil {
return trimmed, nil
}
if _, err := time.Parse("2006-01-02T15:04", trimmed); err == nil {
return trimmed, nil
}
return nil, errors.New("datetime value must use RFC3339 or YYYY-MM-DDTHH:mm")
default:
return nil, fmt.Errorf("unsupported claim value type: %s", valueType)
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
type devMockRPUserMetadataRepo struct {
@@ -40,6 +41,14 @@ func TestDevHandler_RPUserMetadataRoundTrip(t *testing.T) {
"client_name": "Client One",
"metadata": map[string]any{
"tenant_id": "tenant-1",
"id_token_claims": []map[string]any{
{
"namespace": "rp_claims",
"key": "approvalLevel",
"valueType": "text",
"value": "A",
},
},
},
}), nil
}
@@ -50,7 +59,9 @@ func TestDevHandler_RPUserMetadataRoundTrip(t *testing.T) {
repo.On("Upsert", mock.Anything, mock.MatchedBy(func(row *domain.RPUserMetadata) bool {
return row.ClientID == "client-1" &&
row.UserID == "user-1" &&
row.Metadata["approvalLevel"] == "A"
row.Metadata["approvalLevel"] == "A" &&
row.Metadata["approvalLevel_permissions"].(map[string]any)["readPermission"] == "admin_only" &&
row.Metadata["approvalLevel_permissions"].(map[string]any)["writePermission"] == "user_and_admin"
})).Return(nil).Once()
repo.On("Get", mock.Anything, "client-1", "user-1").Return(&domain.RPUserMetadata{
ClientID: "client-1",
@@ -74,7 +85,12 @@ func TestDevHandler_RPUserMetadataRoundTrip(t *testing.T) {
app.Get("/api/v1/dev/clients/:id/users/:userId/metadata", h.GetRPUserMetadata)
body, _ := json.Marshal(map[string]any{
"metadata": map[string]any{"approvalLevel": "A"},
"metadata": map[string]any{
"approvalLevel": "A",
"approvalLevel_permissions": map[string]any{
"writePermission": "user_and_admin",
},
},
})
putReq := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
putReq.Header.Set("Content-Type", "application/json")
@@ -92,3 +108,171 @@ func TestDevHandler_RPUserMetadataRoundTrip(t *testing.T) {
assert.Equal(t, "A", got["metadata"].(map[string]any)["approvalLevel"])
repo.AssertExpectations(t)
}
func TestDevHandler_RPUserMetadataMirrorsToKratosTraits(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/clients/client-1" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-1",
"client_name": "Client One",
"metadata": map[string]any{
"tenant_id": "tenant-1",
"id_token_claims": []map[string]any{
{
"namespace": "rp_claims",
"key": "approvalLevel",
"valueType": "text",
"value": "A",
"readPermission": "user_and_admin",
"writePermission": "admin_only",
},
},
},
}), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
repo := new(devMockRPUserMetadataRepo)
repo.On("Upsert", mock.Anything, mock.AnythingOfType("*domain.RPUserMetadata")).Return(nil).Once()
kratos := new(MockKratosAdmin)
kratos.On("GetIdentity", mock.Anything, "user-1").Return(&service.KratosIdentity{
ID: "user-1",
State: "active",
Traits: map[string]any{
"email": "user@example.com",
"name": "User One",
},
}, nil).Once()
var capturedTraits map[string]any
kratos.On("UpdateIdentity", mock.Anything, "user-1", mock.Anything, "active").Run(func(args mock.Arguments) {
capturedTraits = args.Get(2).(map[string]any)
}).Return(&service.KratosIdentity{ID: "user-1", State: "active", Traits: map[string]any{}}, nil).Once()
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
},
KratosAdmin: kratos,
RPUserMetadataRepo: repo,
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
body, _ := json.Marshal(map[string]any{
"metadata": map[string]any{"approvalLevel": "B"},
})
req := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
require.Equal(t, http.StatusOK, resp.StatusCode)
rpClaims := capturedTraits["rp_custom_claims"].(map[string]any)
clientClaims := rpClaims["client-1"].(domain.JSONMap)
require.Equal(t, "B", clientClaims["approvalLevel"])
require.Equal(t, map[string]any{
"readPermission": "user_and_admin",
"writePermission": "admin_only",
}, clientClaims["approvalLevel_permissions"])
repo.AssertExpectations(t)
kratos.AssertExpectations(t)
}
func TestDevHandler_RPUserMetadataRejectsUndefinedClaimKey(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/clients/client-1" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-1",
"client_name": "Client One",
"metadata": map[string]any{
"id_token_claims": []map[string]any{
{
"namespace": "rp_claims",
"key": "contract_date",
"valueType": "date",
"value": "2026-06-09",
},
},
},
}), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
repo := new(devMockRPUserMetadataRepo)
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
},
RPUserMetadataRepo: repo,
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
body, _ := json.Marshal(map[string]any{
"metadata": map[string]any{"unknown_claim": "A"},
})
req := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
repo.AssertNotCalled(t, "Upsert", mock.Anything, mock.Anything)
}
func TestDevHandler_RPUserMetadataRejectsInvalidTypedClaimValue(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/clients/client-1" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-1",
"client_name": "Client One",
"metadata": map[string]any{
"id_token_claims": []map[string]any{
{
"namespace": "rp_claims",
"key": "contract_date",
"valueType": "date",
"value": "2026-06-09",
},
},
},
}), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
repo := new(devMockRPUserMetadataRepo)
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
},
RPUserMetadataRepo: repo,
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
body, _ := json.Marshal(map[string]any{
"metadata": map[string]any{"contract_date": "2026/06/09"},
})
req := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
repo.AssertNotCalled(t, "Upsert", mock.Anything, mock.Anything)
}

View File

@@ -726,7 +726,7 @@ func TestUpdateClient_AuditDetailsIncludeGeneralSettingChanges(t *testing.T) {
"tenant_id": "tenant-1",
"tenant_access_restricted": true,
"allowed_tenants": []any{"tenant-1", "tenant-2"},
"id_token_claims": []any{map[string]any{"namespace": "top_level", "key": "locale", "valueType": "text", "value": "ko-KR"}},
"id_token_claims": []any{map[string]any{"namespace": "rp_claims", "key": "locale", "valueType": "text", "value": "ko-KR"}},
"headless_login_enabled": true,
"headless_jwks_uri": "https://rp.example.com/jwks.json",
"headless_token_endpoint_auth_method": "private_key_jwt",
@@ -766,7 +766,7 @@ func TestUpdateClient_AuditDetailsIncludeGeneralSettingChanges(t *testing.T) {
"allowed_tenants": []string{"tenant-1", "tenant-2"},
"id_token_claims": []map[string]any{
{
"namespace": "top_level",
"namespace": "rp_claims",
"key": "locale",
"valueType": "text",
"value": "ko-KR",
@@ -2306,7 +2306,7 @@ func TestCreateClient_NormalizesIDTokenClaimsMetadata(t *testing.T) {
"id_token_claims": []map[string]any{
{
"id": "claim-1",
"namespace": "top_level",
"namespace": "rp_claims",
"key": "locale",
"value": " ko-KR ",
"valueType": "text",
@@ -2331,7 +2331,7 @@ func TestCreateClient_NormalizesIDTokenClaimsMetadata(t *testing.T) {
if assert.True(t, ok) && assert.Len(t, claims, 2) {
first, ok := claims[0].(map[string]any)
if assert.True(t, ok) {
assert.Equal(t, "top_level", first["namespace"])
assert.Equal(t, "rp_claims", first["namespace"])
assert.Equal(t, "locale", first["key"])
assert.Equal(t, "ko-KR", first["value"])
assert.Equal(t, "text", first["valueType"])
@@ -2393,7 +2393,7 @@ func TestCreateClient_RejectsInvalidIDTokenClaimsMetadata(t *testing.T) {
defer resp.Body.Close()
bodyBytes, _ := io.ReadAll(resp.Body)
assert.Contains(t, string(bodyBytes), "top-level key rp_claims is reserved")
assert.Contains(t, string(bodyBytes), "top_level namespace is managed from admin user custom claims")
assert.False(t, hydraCalled)
}
@@ -3134,6 +3134,147 @@ func TestListConsents_UserAllowedByRPAdminsRelation(t *testing.T) {
mockKeto.AssertExpectations(t)
}
func TestListConsents_IncludesRPUserMetadata(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.Method == http.MethodGet && r.URL.Path == "/clients/client-1" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-1",
"client_name": "App One",
"metadata": map[string]any{
"tenant_id": "tenant-1",
"status": "active",
},
}), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
repo := new(devMockRPUserMetadataRepo)
repo.On("Get", mock.Anything, "client-1", "subject-1").Return(&domain.RPUserMetadata{
ClientID: "client-1",
UserID: "subject-1",
Metadata: domain.JSONMap{
"approvalLevel": "A",
"reviewedAt": "2026-06-09T09:30:00+09:00",
},
}, nil).Once()
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
},
ConsentRepo: &mockConsentRepo{
consents: []domain.ClientConsent{
{
ClientID: "client-1",
Subject: "subject-1",
GrantedScopes: []string{"openid", "profile"},
CreatedAt: time.Now().UTC(),
},
},
},
RPUserMetadataRepo: repo,
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Get("/api/v1/dev/consents", h.ListConsents)
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/consents?client_id=client-1", nil)
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var result consentListResponse
_ = json.NewDecoder(resp.Body).Decode(&result)
if assert.Len(t, result.Items, 1) {
assert.Equal(t, domain.JSONMap{
"approvalLevel": "A",
"reviewedAt": "2026-06-09T09:30:00+09:00",
}, result.Items[0].RPMetadata)
}
repo.AssertExpectations(t)
}
func TestNormalizeIDTokenClaimsMetadata_AllowsDateAndDatetime(t *testing.T) {
metadata, err := normalizeIDTokenClaimsMetadata(map[string]any{
domain.MetadataIDTokenClaims: []any{
map[string]any{
"namespace": "rp_claims",
"key": "contract_date",
"value": "2026-06-09",
"valueType": "date",
},
map[string]any{
"namespace": "rp_claims",
"key": "approved_at",
"value": "2026-06-09T09:30:00+09:00",
"valueType": "datetime",
},
},
})
assert.NoError(t, err)
claims := metadata[domain.MetadataIDTokenClaims].([]normalizedIDTokenClaim)
assert.Equal(t, "date", claims[0].ValueType)
assert.Equal(t, "datetime", claims[1].ValueType)
}
func TestUpdateClient_RejectsTopLevelIDTokenClaimsFromDevConsole(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.Method == http.MethodGet && r.URL.Path == "/clients/client-1" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-1",
"client_name": "App One",
"redirect_uris": []string{"http://localhost/cb"},
"grant_types": []string{"authorization_code"},
"response_types": []string{"code"},
"scope": "openid profile",
"token_endpoint_auth_method": "none",
"metadata": map[string]any{"status": "active"},
}), nil
}
if r.Method == http.MethodPut && r.URL.Path == "/clients/client-1" {
t.Fatalf("hydra update should not be called for top-level id token claims")
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
},
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Put("/api/v1/dev/clients/:id", h.UpdateClient)
body, _ := json.Marshal(map[string]any{
"metadata": map[string]any{
domain.MetadataIDTokenClaims: []any{
map[string]any{
"namespace": "top_level",
"key": "employee_id",
"value": "EMP001",
"valueType": "text",
},
},
},
})
req := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
}
func TestListClientRelations_RPAdminAllowedByViewRelationshipsPermission(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.Method == http.MethodGet && r.URL.Path == "/clients/client-1" {

View File

@@ -10,12 +10,15 @@ import (
"bytes"
"context"
"encoding/csv"
"encoding/json"
"errors"
"fmt"
"io"
"maps"
"os"
"reflect"
"sort"
"strconv"
"strings"
"time"
@@ -28,6 +31,7 @@ type TenantHandler struct {
Service service.TenantService
UserRepo repository.UserRepository
UserProjectionRepo repository.UserProjectionRepository
OrgChartCache orgChartCacheStore
Keto service.KetoService
KetoOutbox repository.KetoOutboxRepository
KratosAdmin service.KratosAdminService
@@ -37,6 +41,11 @@ type TenantHandler struct {
ConsentRepo repository.ClientConsentRepository
}
type orgChartCacheStore interface {
Get(key string) (string, error)
Set(key string, value string, expiration time.Duration) error
}
func seedTenantDeleteError(c *fiber.Ctx) error {
return errorJSON(c, fiber.StatusConflict, "seed tenants cannot be deleted")
}
@@ -74,18 +83,19 @@ func (h *TenantHandler) SetWorksmobileSyncer(syncer service.WorksmobileSyncer) {
}
type tenantSummary struct {
ID string `json:"id"`
Type string `json:"type"`
ParentID *string `json:"parentId"`
Name string `json:"name"`
Slug string `json:"slug"`
Description string `json:"description"`
Status string `json:"status"`
Domains []string `json:"domains,omitempty"`
Config domain.JSONMap `json:"config,omitempty"`
MemberCount int64 `json:"memberCount"`
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
ID string `json:"id"`
Type string `json:"type"`
ParentID *string `json:"parentId"`
Name string `json:"name"`
Slug string `json:"slug"`
Description string `json:"description"`
Status string `json:"status"`
Domains []string `json:"domains,omitempty"`
Config domain.JSONMap `json:"config,omitempty"`
MemberCount int64 `json:"memberCount"`
TotalMemberCount int64 `json:"totalMemberCount"`
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
}
type tenantListResponse struct {
@@ -97,6 +107,18 @@ type tenantListResponse struct {
NextCursor string `json:"nextCursor,omitempty"`
}
type orgChartSnapshotCacheInfo struct {
Source string `json:"source"`
Hit bool `json:"hit"`
TTLSeconds int `json:"ttlSeconds,omitempty"`
}
type orgChartSnapshotResponse struct {
Tenants []tenantSummary `json:"tenants"`
Users []userSummary `json:"users"`
Cache orgChartSnapshotCacheInfo `json:"cache"`
}
func pageTenantsByCursor(tenants []domain.Tenant, limit int, cursorRaw string) ([]domain.Tenant, string, error) {
ordered := append([]domain.Tenant(nil), tenants...)
pagination.SortByKeyDesc(ordered, func(tenant domain.Tenant) (time.Time, string) {
@@ -360,7 +382,7 @@ func (h *TenantHandler) ListTenants(c *fiber.Ctx) error {
}
}
memberCounts, err := h.countTenantMembersFromProjection(c.Context(), tenants)
memberCounts, totalMemberCounts, err := h.countTenantMembersFromProjection(c.Context(), tenants)
if err != nil {
return errorJSON(c, fiber.StatusServiceUnavailable, err.Error())
}
@@ -369,6 +391,7 @@ func (h *TenantHandler) ListTenants(c *fiber.Ctx) error {
for _, t := range tenants {
summary := mapTenantSummary(t)
summary.MemberCount = memberCounts[t.ID]
summary.TotalMemberCount = totalMemberCounts[t.ID]
items = append(items, summary)
}
@@ -1656,13 +1679,14 @@ func (h *TenantHandler) GetTenant(c *fiber.Ctx) error {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
memberCounts, err := h.countTenantMembersFromProjection(c.Context(), []domain.Tenant{tenant})
memberCounts, totalMemberCounts, err := h.countTenantMembersFromProjection(c.Context(), []domain.Tenant{tenant})
if err != nil {
return errorJSON(c, fiber.StatusServiceUnavailable, err.Error())
}
summary := mapTenantSummary(tenant)
summary.MemberCount = memberCounts[tenant.ID]
summary.TotalMemberCount = totalMemberCounts[tenant.ID]
return c.JSON(summary)
}
@@ -1748,6 +1772,7 @@ func (h *TenantHandler) CreateTenant(c *fiber.Ctx) error {
summary := mapTenantSummary(*tenant)
summary.MemberCount = 0
summary.TotalMemberCount = 0
if req.Config != nil {
config, err := normalizeTenantConfig(req.Config)
@@ -2658,25 +2683,33 @@ func buildOrgContextTree(rootID string, tenants []domain.Tenant, tenantByID map[
return build(rootID)
}
func (h *TenantHandler) countTenantMembersFromProjection(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
func (h *TenantHandler) countTenantMembersFromProjection(ctx context.Context, tenants []domain.Tenant) (map[string]int64, map[string]int64, error) {
counts := make(map[string]int64, len(tenants))
for _, tenant := range tenants {
counts[tenant.ID] = 0
}
if len(tenants) == 0 {
return counts, nil
return counts, counts, nil
}
if h.UserProjectionRepo == nil {
return nil, errors.New("user projection is not configured")
return nil, nil, errors.New("user projection is not configured")
}
ready, err := h.UserProjectionRepo.IsReady(ctx)
if err != nil {
return nil, fmt.Errorf("user projection status unavailable: %w", err)
return nil, nil, fmt.Errorf("user projection status unavailable: %w", err)
}
if !ready {
return nil, errors.New("user projection is not ready")
return nil, nil, errors.New("user projection is not ready")
}
return h.UserProjectionRepo.CountTenantMembers(ctx, tenants)
directCounts, err := h.UserProjectionRepo.CountTenantMembers(ctx, tenants)
if err != nil {
return nil, nil, err
}
totalCounts, err := h.UserProjectionRepo.CountTenantMembersRecursive(ctx, tenants)
if err != nil {
return nil, nil, err
}
return directCounts, totalCounts, nil
}
func normalizeTenantStatus(value string) string {
@@ -2736,6 +2769,230 @@ func (h *TenantHandler) DeleteShareLink(c *fiber.Ctx) error {
return c.JSON(fiber.Map{"message": "Share link deleted successfully"})
}
func (h *TenantHandler) GetOrgChartSnapshot(c *fiber.Ctx) error {
profile, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
cacheMode := strings.ToLower(strings.TrimSpace(c.Query("cache")))
cacheKey := orgChartSnapshotCacheKey(profile, c.Get("X-Tenant-ID"))
ttl := orgChartSnapshotCacheTTL()
if cacheMode == "redis" && h.OrgChartCache != nil {
if raw, err := h.OrgChartCache.Get(cacheKey); err == nil && strings.TrimSpace(raw) != "" {
var cached orgChartSnapshotResponse
if err := json.Unmarshal([]byte(raw), &cached); err == nil {
cached.Cache = orgChartSnapshotCacheInfo{
Source: "redis",
Hit: true,
TTLSeconds: int(ttl.Seconds()),
}
c.Set("X-Orgfront-Cache", "HIT")
return c.JSON(cached)
}
}
}
snapshot, err := h.buildOrgChartSnapshot(c.Context(), profile)
if err != nil {
return errorJSON(c, fiber.StatusServiceUnavailable, err.Error())
}
snapshot.Cache = orgChartSnapshotCacheInfo{
Source: "database",
Hit: false,
TTLSeconds: int(ttl.Seconds()),
}
if cacheMode == "redis" && h.OrgChartCache != nil {
if raw, err := json.Marshal(snapshot); err == nil {
_ = h.OrgChartCache.Set(cacheKey, string(raw), ttl)
}
c.Set("X-Orgfront-Cache", "MISS")
} else {
c.Set("X-Orgfront-Cache", "BYPASS")
}
return c.JSON(snapshot)
}
func (h *TenantHandler) buildOrgChartSnapshot(ctx context.Context, profile *domain.UserProfileResponse) (orgChartSnapshotResponse, error) {
tenants, err := h.listOrgChartTenantsForProfile(ctx, profile)
if err != nil {
return orgChartSnapshotResponse{}, err
}
memberCounts, totalMemberCounts, err := h.countTenantMembersFromProjection(ctx, tenants)
if err != nil {
return orgChartSnapshotResponse{}, err
}
tenantSummaries := make([]tenantSummary, 0, len(tenants))
for _, tenant := range tenants {
summary := mapTenantSummary(tenant)
summary.MemberCount = memberCounts[tenant.ID]
summary.TotalMemberCount = totalMemberCounts[tenant.ID]
tenantSummaries = append(tenantSummaries, summary)
}
users, err := h.listOrgChartUsers(ctx, profile, tenants)
if err != nil {
return orgChartSnapshotResponse{}, err
}
return orgChartSnapshotResponse{
Tenants: tenantSummaries,
Users: users,
}, nil
}
func (h *TenantHandler) listOrgChartTenantsForProfile(ctx context.Context, profile *domain.UserProfileResponse) ([]domain.Tenant, error) {
if h.Service == nil {
return nil, errors.New("tenant service is not configured")
}
role := ""
if profile != nil {
role = domain.NormalizeRole(profile.Role)
}
if role == domain.RoleSuperAdmin {
tenants, _, err := h.Service.ListTenants(ctx, 10000, 0, "", "")
return tenants, err
}
allTenants, _, err := h.Service.ListTenants(ctx, 10000, 0, "", "")
if err != nil {
return nil, err
}
if profile == nil {
return []domain.Tenant{}, nil
}
baseTenantIDs := make([]string, 0, len(profile.ManageableTenants)+len(profile.JoinedTenants)+1)
for _, tenant := range profile.ManageableTenants {
baseTenantIDs = append(baseTenantIDs, tenant.ID)
}
for _, tenant := range profile.JoinedTenants {
baseTenantIDs = append(baseTenantIDs, tenant.ID)
}
if profile.TenantID != nil {
baseTenantIDs = append(baseTenantIDs, *profile.TenantID)
}
parentMap := make(map[string]string)
for _, tenant := range allTenants {
if tenant.ParentID != nil {
parentMap[tenant.ID] = *tenant.ParentID
}
}
findRoot := func(id string) string {
curr := id
for {
parentID, exists := parentMap[curr]
if !exists || parentID == "" {
return curr
}
curr = parentID
}
}
roots := make(map[string]bool)
for _, id := range baseTenantIDs {
if strings.TrimSpace(id) != "" {
roots[findRoot(id)] = true
}
}
tenants := make([]domain.Tenant, 0, len(allTenants))
for _, tenant := range allTenants {
if roots[findRoot(tenant.ID)] {
tenants = append(tenants, tenant)
}
}
return h.filterPrivateTenantsForProfile(ctx, tenants, profile)
}
func (h *TenantHandler) listOrgChartUsers(ctx context.Context, profile *domain.UserProfileResponse, tenants []domain.Tenant) ([]userSummary, error) {
if h.UserRepo == nil {
return nil, errors.New("user repository is not configured")
}
role := ""
if profile != nil {
role = domain.NormalizeRole(profile.Role)
}
tenantIDs := []string{}
if role != domain.RoleSuperAdmin {
tenantIDs = make([]string, 0, len(tenants))
for _, tenant := range tenants {
tenantIDs = append(tenantIDs, tenant.ID)
}
}
users, _, _, err := h.UserRepo.List(ctx, 0, 10000, "", tenantIDs, "")
if err != nil {
return nil, err
}
summaries := make([]userSummary, 0, len(users))
for _, user := range users {
summary := userSummary{
ID: user.ID,
Email: user.Email,
LoginID: user.Email,
Name: user.Name,
Phone: user.Phone,
Role: domain.NormalizeRole(user.Role),
Status: normalizeStatus(user.Status),
TenantSlug: userTenantSlug(user),
CompanyCode: userTenantSlug(user),
Metadata: user.Metadata,
Tenant: user.Tenant,
Department: user.Department,
Grade: user.Grade,
Position: user.Position,
JobTitle: user.JobTitle,
CreatedAt: formatTime(user.CreatedAt),
UpdatedAt: formatTime(user.UpdatedAt),
}
if h.Service != nil {
if joined, err := h.Service.ListJoinedTenants(ctx, user.ID); err == nil {
summary.JoinedTenants = joined
}
}
summaries = append(summaries, summary)
}
return summaries, nil
}
func orgChartSnapshotCacheKey(profile *domain.UserProfileResponse, tenantHeader string) string {
role := "anonymous"
userID := "anonymous"
tenantID := strings.TrimSpace(tenantHeader)
if profile != nil {
role = domain.NormalizeRole(profile.Role)
userID = strings.TrimSpace(profile.ID)
if tenantID == "" && profile.TenantID != nil {
tenantID = strings.TrimSpace(*profile.TenantID)
}
}
if userID == "" {
userID = "anonymous"
}
if tenantID == "" {
tenantID = "none"
}
return fmt.Sprintf("orgchart:snapshot:v1:%s:%s:%s", role, userID, tenantID)
}
func orgChartSnapshotCacheTTL() time.Duration {
const defaultTTL = 5 * time.Minute
raw := strings.TrimSpace(os.Getenv("ORGFRONT_ORGCHART_CACHE_TTL_SECONDS"))
if raw == "" {
return defaultTTL
}
seconds, err := strconv.Atoi(raw)
if err != nil || seconds <= 0 {
return defaultTTL
}
return time.Duration(seconds) * time.Second
}
func (h *TenantHandler) GetPublicOrgChart(c *fiber.Ctx) error {
token := c.Query("token")
if token == "" {

View File

@@ -15,6 +15,7 @@ import (
"testing"
"time"
"github.com/go-redis/redis/v8"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
@@ -190,6 +191,25 @@ type MockUserProjectionRepoForHandler struct {
mock.Mock
}
type mockOrgChartCache struct {
mock.Mock
values map[string]string
}
func (m *mockOrgChartCache) Get(key string) (string, error) {
args := m.Called(key)
return args.String(0), args.Error(1)
}
func (m *mockOrgChartCache) Set(key string, value string, expiration time.Duration) error {
if m.values == nil {
m.values = make(map[string]string)
}
m.values[key] = value
args := m.Called(key, value, expiration)
return args.Error(0)
}
func (m *MockUserProjectionRepoForHandler) IsReady(ctx context.Context) (bool, error) {
args := m.Called(ctx)
return args.Bool(0), args.Error(1)
@@ -208,6 +228,14 @@ func (m *MockUserProjectionRepoForHandler) CountTenantMembers(ctx context.Contex
return args.Get(0).(map[string]int64), args.Error(1)
}
func (m *MockUserProjectionRepoForHandler) CountTenantMembersRecursive(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
args := m.Called(ctx, tenants)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(map[string]int64), args.Error(1)
}
func (m *MockUserProjectionRepoForHandler) ReplaceAllFromKratos(ctx context.Context, users []domain.User) error {
args := m.Called(ctx, users)
return args.Error(0)
@@ -278,6 +306,8 @@ func TestTenantHandler_ListTenantsUsesReadyUserProjectionCountsWithoutKratos(t *
mockProjection.On("IsReady", mock.Anything).Return(true, nil).Once()
mockProjection.On("CountTenantMembers", mock.Anything, tenants).
Return(map[string]int64{"00000000-0000-0000-0000-000000000001": 2}, nil).Once()
mockProjection.On("CountTenantMembersRecursive", mock.Anything, tenants).
Return(map[string]int64{"00000000-0000-0000-0000-000000000001": 7}, nil).Once()
req := httptest.NewRequest("GET", "/tenants?limit=10&offset=0", nil)
resp, _ := app.Test(req)
@@ -289,6 +319,135 @@ func TestTenantHandler_ListTenantsUsesReadyUserProjectionCountsWithoutKratos(t *
require.Len(t, res.Items, 1)
assert.Equal(t, int64(2), res.Items[0].MemberCount)
assert.Equal(t, int64(7), res.Items[0].TotalMemberCount)
mockProjection.AssertExpectations(t)
}
func TestTenantHandler_GetOrgChartSnapshotReturnsRedisCacheHit(t *testing.T) {
app := fiber.New()
cache := &mockOrgChartCache{}
cached := `{"tenants":[{"id":"family","type":"COMPANY_GROUP","name":"한맥가족","slug":"hanmac-family","description":"","status":"active","memberCount":0,"totalMemberCount":2,"createdAt":"2026-06-09T00:00:00Z","updatedAt":"2026-06-09T00:00:00Z"}],"users":[],"cache":{"source":"redis","hit":true}}`
cache.On("Get", mock.MatchedBy(func(key string) bool {
return strings.HasPrefix(key, "orgchart:snapshot:")
})).Return(cached, nil).Once()
h := &TenantHandler{OrgChartCache: cache}
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Get("/admin/orgchart/snapshot", h.GetOrgChartSnapshot)
req := httptest.NewRequest(http.MethodGet, "/admin/orgchart/snapshot?cache=redis", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, "HIT", resp.Header.Get("X-Orgfront-Cache"))
var body map[string]any
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, "redis", body["cache"].(map[string]any)["source"])
cache.AssertExpectations(t)
}
func TestTenantHandler_GetOrgChartSnapshotCachesMissResult(t *testing.T) {
app := fiber.New()
mockSvc := new(MockTenantService)
mockProjection := new(MockUserProjectionRepoForHandler)
mockUsers := new(MockUserRepoForHandler)
cache := &mockOrgChartCache{}
now := time.Date(2026, 6, 9, 0, 0, 0, 0, time.UTC)
familyID := "family"
samanID := "saman"
tenants := []domain.Tenant{
{ID: familyID, Type: domain.TenantTypeCompanyGroup, Name: "한맥가족", Slug: "hanmac-family", Status: domain.TenantStatusActive, CreatedAt: now, UpdatedAt: now},
{ID: samanID, Type: domain.TenantTypeCompany, Name: "삼안", Slug: "saman", ParentID: &familyID, Status: domain.TenantStatusActive, CreatedAt: now, UpdatedAt: now},
}
users := []domain.User{
{ID: "user-1", Email: "user@example.com", Name: "User One", Role: domain.RoleUser, Status: "active", TenantID: &samanID, Tenant: &tenants[1], CreatedAt: now, UpdatedAt: now},
}
cache.On("Get", mock.Anything).Return("", redis.Nil).Once()
cache.On("Set", mock.MatchedBy(func(key string) bool {
return strings.HasPrefix(key, "orgchart:snapshot:")
}), mock.Anything, mock.AnythingOfType("time.Duration")).Return(nil).Once()
mockSvc.On("ListTenants", mock.Anything, 10000, 0, "", "").Return(tenants, int64(2), nil).Once()
mockSvc.On("ListJoinedTenants", mock.Anything, "user-1").Return([]domain.Tenant{tenants[1]}, nil).Once()
mockProjection.On("IsReady", mock.Anything).Return(true, nil).Once()
mockProjection.On("CountTenantMembers", mock.Anything, tenants).Return(map[string]int64{familyID: 0, samanID: 1}, nil).Once()
mockProjection.On("CountTenantMembersRecursive", mock.Anything, tenants).Return(map[string]int64{familyID: 1, samanID: 1}, nil).Once()
mockUsers.On("List", mock.Anything, 0, 10000, "", []string{}, "").Return(users, int64(1), "", nil).Once()
h := &TenantHandler{Service: mockSvc, UserRepo: mockUsers, UserProjectionRepo: mockProjection, OrgChartCache: cache}
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Get("/admin/orgchart/snapshot", h.GetOrgChartSnapshot)
req := httptest.NewRequest(http.MethodGet, "/admin/orgchart/snapshot?cache=redis", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, "MISS", resp.Header.Get("X-Orgfront-Cache"))
var body struct {
Tenants []tenantSummary `json:"tenants"`
Users []userSummary `json:"users"`
}
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Len(t, body.Tenants, 2)
require.Len(t, body.Users, 1)
require.Equal(t, int64(1), body.Tenants[0].TotalMemberCount)
cache.AssertExpectations(t)
mockSvc.AssertExpectations(t)
mockProjection.AssertExpectations(t)
mockUsers.AssertExpectations(t)
}
func TestTenantHandler_ListTenantsReturnsTotalMemberCountForDescendants(t *testing.T) {
app := fiber.New()
mockSvc := new(MockTenantService)
mockProjection := new(MockUserProjectionRepoForHandler)
h := &TenantHandler{
Service: mockSvc,
UserProjectionRepo: mockProjection,
}
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
Role: "super_admin",
})
return c.Next()
})
app.Get("/tenants", h.ListTenants)
parentID := "00000000-0000-0000-0000-000000000001"
childID := "00000000-0000-0000-0000-000000000002"
tenants := []domain.Tenant{
{ID: parentID, Name: "Parent", Slug: "parent"},
{ID: childID, Name: "Child", Slug: "child", ParentID: &parentID},
}
mockSvc.On("ListTenants", mock.Anything, 10, 0, "", "").Return(tenants, int64(2), nil).Once()
mockProjection.On("IsReady", mock.Anything).Return(true, nil).Once()
mockProjection.On("CountTenantMembers", mock.Anything, tenants).
Return(map[string]int64{parentID: 1, childID: 2}, nil).Once()
mockProjection.On("CountTenantMembersRecursive", mock.Anything, tenants).
Return(map[string]int64{parentID: 3, childID: 2}, nil).Once()
req := httptest.NewRequest("GET", "/tenants?limit=10&offset=0", nil)
resp, _ := app.Test(req)
require.Equal(t, http.StatusOK, resp.StatusCode)
var res tenantListResponse
require.NoError(t, json.NewDecoder(resp.Body).Decode(&res))
require.Len(t, res.Items, 2)
assert.Equal(t, int64(1), res.Items[0].MemberCount)
assert.Equal(t, int64(3), res.Items[0].TotalMemberCount)
assert.Equal(t, int64(2), res.Items[1].MemberCount)
assert.Equal(t, int64(2), res.Items[1].TotalMemberCount)
mockProjection.AssertExpectations(t)
}
@@ -321,6 +480,7 @@ func TestTenantHandler_ListTenantsRejectsStatsWhenUserProjectionIsNotReady(t *te
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
mockProjection.AssertNotCalled(t, "CountTenantMembers", mock.Anything, mock.Anything)
mockProjection.AssertNotCalled(t, "CountTenantMembersRecursive", mock.Anything, mock.Anything)
}
func TestTenantHandler_ListTenants(t *testing.T) {
@@ -350,6 +510,8 @@ func TestTenantHandler_ListTenants(t *testing.T) {
mockProjection.On("IsReady", mock.Anything).Return(true, nil).Once()
mockProjection.On("CountTenantMembers", mock.Anything, tenants).
Return(map[string]int64{"t1": 5, "t2": 10}, nil).Once()
mockProjection.On("CountTenantMembersRecursive", mock.Anything, tenants).
Return(map[string]int64{"t1": 5, "t2": 10}, nil).Once()
req := httptest.NewRequest("GET", "/tenants?limit=10&offset=0", nil)
resp, _ := app.Test(req)
@@ -399,6 +561,7 @@ func TestTenantHandler_ListTenantsReturnsNextCursorWhenMoreRowsExist(t *testing.
mockSvc.On("ListTenants", mock.Anything, 2, 0, "", "").Return(tenants, int64(3), nil).Once()
mockProjection.On("IsReady", mock.Anything).Return(true, nil).Once()
mockProjection.On("CountTenantMembers", mock.Anything, tenants).Return(map[string]int64{}, nil).Once()
mockProjection.On("CountTenantMembersRecursive", mock.Anything, tenants).Return(map[string]int64{}, nil).Once()
req := httptest.NewRequest("GET", "/tenants?limit=2&offset=0", nil)
resp, _ := app.Test(req)
@@ -468,6 +631,9 @@ func TestTenantHandler_ListTenantsHidesPrivateSubtreeForUnauthorizedUser(t *test
mockProjection.On("CountTenantMembers", mock.Anything, mock.MatchedBy(func(got []domain.Tenant) bool {
return tenantSlugsMatch(got, "hanmac-family", "hanmac", "public-team")
})).Return(map[string]int64{}, nil).Once()
mockProjection.On("CountTenantMembersRecursive", mock.Anything, mock.MatchedBy(func(got []domain.Tenant) bool {
return tenantSlugsMatch(got, "hanmac-family", "hanmac", "public-team")
})).Return(map[string]int64{}, nil).Once()
req := httptest.NewRequest(http.MethodGet, "/tenants?limit=100&offset=0", nil)
resp, err := app.Test(req)
@@ -517,6 +683,9 @@ func TestTenantHandler_ListTenantsShowsPrivateSubtreeForManageableTenant(t *test
mockProjection.On("CountTenantMembers", mock.Anything, mock.MatchedBy(func(got []domain.Tenant) bool {
return tenantSlugsMatch(got, "hanmac-family", "hanmac", "private-team", "private-child")
})).Return(map[string]int64{}, nil).Once()
mockProjection.On("CountTenantMembersRecursive", mock.Anything, mock.MatchedBy(func(got []domain.Tenant) bool {
return tenantSlugsMatch(got, "hanmac-family", "hanmac", "private-team", "private-child")
})).Return(map[string]int64{}, nil).Once()
req := httptest.NewRequest(http.MethodGet, "/tenants?limit=100&offset=0", nil)
resp, err := app.Test(req)
@@ -936,6 +1105,8 @@ func TestTenantHandler_ListTenantsUsesProjectionCountsWhenAvailable(t *testing.T
mockProjection.On("IsReady", mock.Anything).Return(true, nil).Once()
mockProjection.On("CountTenantMembers", mock.Anything, tenants).
Return(map[string]int64{"00000000-0000-0000-0000-000000000001": 2}, nil).Once()
mockProjection.On("CountTenantMembersRecursive", mock.Anything, tenants).
Return(map[string]int64{"00000000-0000-0000-0000-000000000001": 2}, nil).Once()
mockUserRepo.On("CountByCompanyCodes", mock.Anything, []string{"saman"}).
Return(map[string]int64{"saman": 152}, nil).Maybe()

View File

@@ -23,6 +23,7 @@ import (
"time"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
)
// OryProviderAPI defines the subset of Ory Provider used by UserHandler
@@ -99,6 +100,76 @@ func sanitizeUserMetadata(metadata map[string]any) map[string]any {
return sanitized
}
func userAppointmentSliceFromRaw(raw any) []any {
switch values := raw.(type) {
case []any:
return append([]any(nil), values...)
case []map[string]any:
appointments := make([]any, 0, len(values))
for _, value := range values {
appointments = append(appointments, value)
}
return appointments
default:
return nil
}
}
func userAppointmentTenantKey(raw any) string {
appointment, ok := raw.(map[string]any)
if !ok {
return ""
}
if value := normalizeMetadataString(appointment["tenantId"]); value != "" {
return "id:" + strings.ToLower(value)
}
if value := normalizeMetadataString(appointment["tenantSlug"]); value != "" {
return "slug:" + strings.ToLower(value)
}
if value := normalizeMetadataString(appointment["slug"]); value != "" {
return "slug:" + strings.ToLower(value)
}
return ""
}
func mergeUserAddTenantAppointment(traits map[string]any, metadata map[string]any, tenant *domain.Tenant) map[string]any {
if tenant == nil {
return metadata
}
if metadata == nil {
metadata = map[string]any{}
}
appointments := userAppointmentSliceFromRaw(traits["additionalAppointments"])
if len(appointments) == 0 {
if legacyMetadata, ok := traits["metadata"].(map[string]any); ok {
appointments = userAppointmentSliceFromRaw(legacyMetadata["additionalAppointments"])
}
}
if incoming := userAppointmentSliceFromRaw(metadata["additionalAppointments"]); len(incoming) > 0 {
appointments = incoming
}
seen := make(map[string]bool, len(appointments)+1)
for _, appointment := range appointments {
if key := userAppointmentTenantKey(appointment); key != "" {
seen[key] = true
}
}
tenantIDKey := "id:" + strings.ToLower(strings.TrimSpace(tenant.ID))
tenantSlugKey := "slug:" + strings.ToLower(strings.TrimSpace(tenant.Slug))
if !seen[tenantIDKey] && !seen[tenantSlugKey] {
appointments = append(appointments, map[string]any{
"tenantId": tenant.ID,
"tenantSlug": tenant.Slug,
"tenantName": tenant.Name,
"isPrimary": false,
})
}
metadata["additionalAppointments"] = appointments
return metadata
}
func sanitizeUserRepresentativeTenants(ctx context.Context, tenantService service.TenantService, metadata map[string]any, appointments []map[string]any) (bool, error) {
if tenantService == nil || metadata == nil {
return false, nil
@@ -534,6 +605,66 @@ func (h *UserHandler) ListUsers(c *fiber.Ctx) error {
}
}
if h.UserRepo != nil {
var tenantIDs []string
if tenantSlug != "" {
if targetTenantID == "" {
return c.JSON(userListResponse{
Items: []userSummary{},
Limit: limit,
Offset: offset,
Total: 0,
Cursor: cursorRaw,
})
}
if requesterRole != domain.RoleSuperAdmin && !manageableSlugs[targetTenantID] && !manageableSlugs[strings.ToLower(tenantSlug)] {
return c.JSON(userListResponse{
Items: []userSummary{},
Limit: limit,
Offset: offset,
Total: 0,
Cursor: cursorRaw,
})
}
tenantIDs = append(tenantIDs, targetTenantID)
} else if requesterRole != domain.RoleSuperAdmin {
for key := range manageableSlugs {
if _, err := uuid.Parse(key); err == nil {
tenantIDs = append(tenantIDs, key)
}
}
if len(tenantIDs) == 0 {
return c.JSON(userListResponse{
Items: []userSummary{},
Limit: limit,
Offset: offset,
Total: 0,
Cursor: cursorRaw,
})
}
}
users, total, nextCursor, err := h.UserRepo.List(c.Context(), offset, limit, search, tenantIDs, cursorRaw)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, "failed to list users")
}
items := make([]userSummary, 0, len(users))
for _, user := range users {
items = append(items, h.mapLocalUserSummary(c.Context(), user))
}
if cursorRaw != "" {
offset = 0
}
return c.JSON(userListResponse{
Items: items,
Limit: limit,
Offset: offset,
Total: total,
Cursor: cursorRaw,
NextCursor: nextCursor,
})
}
if h.KratosAdmin == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "identity provider not available")
}
@@ -1615,11 +1746,18 @@ func (h *UserHandler) ExportUsersCSV(c *fiber.Ctx) error {
// 1. Fetch Users using Repo for efficiency
var exportTenantIDs []string
if tenantSlug != "" && h.TenantService != nil {
t, err := h.TenantService.GetTenantBySlug(c.Context(), tenantSlug)
if err == nil && t != nil {
exportTenantIDs = []string{t.ID}
if tenantSlug != "" {
if h.TenantService == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "tenant service unavailable for scoped export")
}
t, err := h.TenantService.GetTenantBySlug(c.Context(), tenantSlug)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, "failed to resolve tenant for export")
}
if t == nil || strings.TrimSpace(t.ID) == "" {
return errorJSON(c, fiber.StatusNotFound, "tenant not found for export")
}
exportTenantIDs = []string{t.ID}
}
users, _, _, err := h.UserRepo.List(c.Context(), 0, 10000, search, exportTenantIDs, "")
if err != nil {
@@ -2087,7 +2225,7 @@ func (h *UserHandler) UpdateUser(c *fiber.Ctx) error {
// All non-superadmins can only move users within tenants they can manage.
if requester != nil && domain.NormalizeRole(requester.Role) != domain.RoleSuperAdmin {
if !req.IsAddTenant && !req.IsRemoveTenant && req.CompanyCode != nil {
if !req.IsRemoveTenant && req.CompanyCode != nil {
targetSlug := strings.TrimSpace(*req.CompanyCode)
targetAllowed := profileCanAccessTenant(requester, "", targetSlug)
if !targetAllowed && h.TenantService != nil && targetSlug != "" {
@@ -2096,7 +2234,7 @@ func (h *UserHandler) UpdateUser(c *fiber.Ctx) error {
}
}
if !targetAllowed {
return errorJSON(c, fiber.StatusForbidden, "forbidden: non-superadmins cannot change user's tenant to an unmanageable one")
return errorJSON(c, fiber.StatusForbidden, "forbidden: non-superadmins cannot assign user's tenant to an unmanageable one")
}
}
}
@@ -2221,6 +2359,21 @@ func (h *UserHandler) UpdateUser(c *fiber.Ctx) error {
traits["tenant_id"] = ""
}
}
} else if h.TenantService != nil && code != "" {
tenant, err := h.TenantService.GetTenantBySlug(c.Context(), code)
if err != nil || tenant == nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid tenant assignment")
}
req.Metadata = mergeUserAddTenantAppointment(traits, req.Metadata, tenant)
if h.KetoOutboxRepo != nil {
_ = h.KetoOutboxRepo.Create(c.Context(), &domain.KetoOutbox{
Namespace: "Tenant",
Object: tenant.ID,
Relation: "members",
Subject: "User:" + userID,
Action: domain.KetoOutboxActionCreate,
})
}
}
}
delete(traits, "companyCode")
@@ -2775,6 +2928,45 @@ func (h *UserHandler) mapIdentitySummary(ctx context.Context, identity service.K
return summary
}
func (h *UserHandler) mapLocalUserSummary(ctx context.Context, user domain.User) userSummary {
tenantSlug := userTenantSlug(user)
customLoginIDs := make([]string, 0, len(user.UserLoginIDs))
for _, loginID := range user.UserLoginIDs {
if strings.TrimSpace(loginID.LoginID) != "" {
customLoginIDs = append(customLoginIDs, strings.TrimSpace(loginID.LoginID))
}
}
summary := userSummary{
ID: user.ID,
Email: user.Email,
LoginID: user.Email,
CustomLoginIDs: customLoginIDs,
Name: user.Name,
Phone: user.Phone,
Role: domain.NormalizeRole(user.Role),
Status: normalizeStatus(user.Status),
TenantSlug: tenantSlug,
CompanyCode: tenantSlug,
Department: user.Department,
Grade: user.Grade,
Position: user.Position,
JobTitle: user.JobTitle,
Metadata: user.Metadata,
Tenant: user.Tenant,
CreatedAt: formatTime(user.CreatedAt),
UpdatedAt: formatTime(user.UpdatedAt),
}
if h.TenantService != nil {
if joined, err := h.TenantService.ListJoinedTenants(ctx, user.ID); err == nil {
summary.JoinedTenants = joined
}
}
return summary
}
func (h *UserHandler) normalizePhoneNumber(phone string) string {
return normalizePhoneNumber(phone)
}
@@ -3302,18 +3494,7 @@ func normalizeKratosState(status *string) string {
}
func normalizePhoneNumber(phone string) string {
normalized := strings.ReplaceAll(phone, "-", "")
normalized = strings.ReplaceAll(normalized, " ", "")
if normalized == "" {
return ""
}
if strings.HasPrefix(normalized, "010") {
return "+82" + normalized[1:]
}
if strings.HasPrefix(normalized, "82") {
return "+" + normalized
}
return normalized
return domain.NormalizePhoneNumber(phone)
}
func (h *UserHandler) validateMetadata(metadata map[string]any, schema []any, checkRequired bool) error {

View File

@@ -320,7 +320,8 @@ func (m *MockTenantServiceForUser) RegisterTenant(ctx context.Context, name, slu
func TestUserHandler_ExportUsersCSV_UsesTenantSlugAliasAndOmitsRole(t *testing.T) {
app := fiber.New()
mockRepo := new(MockUserRepoForHandler)
h := &UserHandler{UserRepo: mockRepo}
mockTenant := new(MockTenantServiceForUser)
h := &UserHandler{UserRepo: mockRepo, TenantService: mockTenant}
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
@@ -332,7 +333,11 @@ func TestUserHandler_ExportUsersCSV_UsesTenantSlugAliasAndOmitsRole(t *testing.T
createdAt := time.Date(2026, 4, 29, 12, 0, 0, 0, time.UTC)
tenantID := "tenant-uuid"
mockRepo.On("List", mock.Anything, 0, 10000, "", []string(nil), "").
mockTenant.On("GetTenantBySlug", mock.Anything, "test-tenant").Return(&domain.Tenant{
ID: tenantID,
Slug: "test-tenant",
}, nil).Once()
mockRepo.On("List", mock.Anything, 0, 10000, "", []string{tenantID}, "").
Return([]domain.User{
{
ID: "u-1",
@@ -362,9 +367,34 @@ func TestUserHandler_ExportUsersCSV_UsesTenantSlugAliasAndOmitsRole(t *testing.T
assert.Contains(t, body, "u-1,user@test.com,Test User,010-1111-2222,active,tenant-uuid,test-tenant,책임,팀장")
assert.NotContains(t, body, "Role")
assert.NotContains(t, body, "Department")
mockTenant.AssertExpectations(t)
mockRepo.AssertExpectations(t)
}
func TestUserHandler_ExportUsersCSV_UnknownTenantSlugDoesNotFallbackToAllUsers(t *testing.T) {
app := fiber.New()
mockRepo := new(MockUserRepoForHandler)
mockTenant := new(MockTenantServiceForUser)
h := &UserHandler{UserRepo: mockRepo, TenantService: mockTenant}
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
Role: domain.RoleSuperAdmin,
})
return c.Next()
})
app.Get("/users/export", h.ExportUsersCSV)
mockTenant.On("GetTenantBySlug", mock.Anything, "missing-tenant").Return(nil, nil).Once()
req := httptest.NewRequest("GET", "/users/export?tenantSlug=missing-tenant&includeIds=true", nil)
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
mockRepo.AssertNotCalled(t, "List", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything)
mockTenant.AssertExpectations(t)
}
func TestUserHandler_ExportUsersCSV_OmitsIDsAndUsesTenantSlug(t *testing.T) {
app := fiber.New()
mockRepo := new(MockUserRepoForHandler)
@@ -951,10 +981,11 @@ func TestUserHandler_BulkCreateUsers_UsesEmailDomainTenantAsPrimaryWhenExplicitT
mockOry.AssertExpectations(t)
}
func TestUserHandler_ListUsersReturnsServiceUnavailableWhenKratosFails(t *testing.T) {
func TestUserHandler_ListUsersUsesLocalProjectionWhenKratosFails(t *testing.T) {
app := fiber.New()
mockKratos := new(MockKratosAdmin)
mockRepo := new(MockUserRepoForHandler)
createdAt := time.Date(2026, 6, 8, 6, 30, 0, 0, time.UTC)
h := &UserHandler{
KratosAdmin: mockKratos,
@@ -970,14 +1001,86 @@ func TestUserHandler_ListUsersReturnsServiceUnavailableWhenKratosFails(t *testin
app.Get("/users", h.ListUsers)
mockKratos.On("ListIdentities", mock.Anything).Return([]service.KratosIdentity{}, errors.New("kratos down")).Maybe()
mockRepo.On("List", mock.Anything, 0, 10, "", []string(nil), "").Return([]domain.User{
{
ID: "local-user-1",
Email: "local1@example.com",
Name: "Local One",
Role: domain.RoleUser,
Status: domain.UserStatusActive,
CreatedAt: createdAt,
UpdatedAt: createdAt,
},
}, int64(1), "", nil)
req := httptest.NewRequest("GET", "/users?limit=10&offset=0", nil)
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
mockRepo.AssertNotCalled(t, "List", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything)
mockKratos.AssertExpectations(t)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var res userListResponse
require.NoError(t, json.NewDecoder(resp.Body).Decode(&res))
require.Equal(t, int64(1), res.Total)
require.Len(t, res.Items, 1)
require.Equal(t, "local1@example.com", res.Items[0].Email)
mockRepo.AssertExpectations(t)
}
func TestUserHandler_ListUsersUsesLocalProjectionTotalBeyondKratosPageLimit(t *testing.T) {
app := fiber.New()
mockKratos := new(MockKratosAdmin)
mockRepo := new(MockUserRepoForHandler)
createdAt := time.Date(2026, 6, 8, 6, 40, 0, 0, time.UTC)
h := &UserHandler{
KratosAdmin: mockKratos,
UserRepo: mockRepo,
}
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
Role: domain.RoleSuperAdmin,
})
return c.Next()
})
app.Get("/users", h.ListUsers)
kratosIdentities := make([]service.KratosIdentity, 250)
for i := range kratosIdentities {
kratosIdentities[i] = service.KratosIdentity{
ID: "kratos-user",
State: "active",
CreatedAt: createdAt.Add(-time.Duration(i) * time.Second),
Traits: map[string]any{"email": "kratos@example.com", "name": "Kratos"},
}
}
mockKratos.On("ListIdentities", mock.Anything).Return(kratosIdentities, nil).Maybe()
mockRepo.On("List", mock.Anything, 0, 50, "", []string(nil), "").Return([]domain.User{
{
ID: "local-user-1",
Email: "local1@example.com",
Name: "Local One",
Role: domain.RoleUser,
Status: domain.UserStatusActive,
CreatedAt: createdAt,
UpdatedAt: createdAt,
},
}, int64(2114), "next-local-cursor", nil)
req := httptest.NewRequest("GET", "/users?limit=50&offset=0", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
var res userListResponse
require.NoError(t, json.NewDecoder(resp.Body).Decode(&res))
require.Equal(t, int64(2114), res.Total)
require.Len(t, res.Items, 1)
require.Equal(t, "local1@example.com", res.Items[0].Email)
require.Equal(t, "next-local-cursor", res.NextCursor)
mockRepo.AssertExpectations(t)
}
func TestUserHandler_ListUsersReturnsNextCursorWhenMoreRowsExist(t *testing.T) {
@@ -2363,6 +2466,157 @@ func TestUserHandler_UpdateUserAcceptsTenantSlugAndRejectsCompanyCode(t *testing
mockKratos.AssertExpectations(t)
}
func TestUserHandler_UpdateUserAddTenantKeepsPrimaryAndAddsAppointment(t *testing.T) {
app := fiber.New()
mockKratos := new(MockKratosAdmin)
mockTenant := new(MockTenantServiceForUser)
h := &UserHandler{
KratosAdmin: mockKratos,
TenantService: mockTenant,
}
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
ID: "admin-id",
Role: domain.RoleSuperAdmin,
})
return c.Next()
})
app.Put("/users/:id", h.UpdateUser)
mockKratos.On("GetIdentity", mock.Anything, "user-id").Return(&service.KratosIdentity{
ID: "user-id",
State: "active",
Traits: map[string]any{
"email": "user@test.com",
"name": "Test User",
"tenant_id": "primary-tenant-id",
"role": domain.RoleUser,
"additionalAppointments": []any{
map[string]any{
"tenantId": "primary-tenant-id",
"tenantSlug": "primary-tenant",
"tenantName": "대표 조직",
"isPrimary": true,
},
},
},
}, nil)
mockTenant.On("GetTenantBySlug", mock.Anything, "private-team").Return(&domain.Tenant{
ID: "private-team-id",
Name: "비공개 팀",
Slug: "private-team",
Config: domain.JSONMap{
"visibility": "private",
},
}, nil)
mockTenant.On("GetTenant", mock.Anything, "private-team-id").Return(&domain.Tenant{
ID: "private-team-id",
Name: "비공개 팀",
Slug: "private-team",
Config: domain.JSONMap{
"visibility": "private",
},
}, nil).Maybe()
mockTenant.On("GetTenant", mock.Anything, "primary-tenant-id").Return(&domain.Tenant{
ID: "primary-tenant-id",
Name: "대표 조직",
Slug: "primary-tenant",
}, nil).Maybe()
var capturedTraits map[string]any
mockKratos.On("UpdateIdentity", mock.Anything, "user-id", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
capturedTraits = args.Get(2).(map[string]any)
}).Return(&service.KratosIdentity{
ID: "user-id",
State: "active",
Traits: map[string]any{
"email": "user@test.com",
"name": "Test User",
"tenant_id": "primary-tenant-id",
"role": domain.RoleUser,
"additionalAppointments": []any{
map[string]any{
"tenantId": "primary-tenant-id",
"tenantSlug": "primary-tenant",
"tenantName": "대표 조직",
"isPrimary": true,
},
map[string]any{
"tenantId": "private-team-id",
"tenantSlug": "private-team",
"tenantName": "비공개 팀",
"isPrimary": false,
},
},
},
}, nil)
body := `{"tenantSlug":"private-team","isAddTenant":true}`
req := httptest.NewRequest(http.MethodPut, "/users/user-id", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, "primary-tenant-id", capturedTraits["tenant_id"])
appointments, ok := capturedTraits["additionalAppointments"].([]any)
require.True(t, ok)
require.Len(t, appointments, 2)
added := appointments[1].(map[string]any)
require.Equal(t, "private-team-id", added["tenantId"])
require.Equal(t, "private-team", added["tenantSlug"])
require.Equal(t, "비공개 팀", added["tenantName"])
require.Equal(t, false, added["isPrimary"])
}
func TestUserHandler_UpdateUserAddTenantRejectsUnmanageableTenantForTenantAdmin(t *testing.T) {
app := fiber.New()
mockKratos := new(MockKratosAdmin)
mockTenant := new(MockTenantServiceForUser)
allowedTenantID := "allowed-tenant-id"
h := &UserHandler{
KratosAdmin: mockKratos,
TenantService: mockTenant,
}
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
ID: "tenant-admin-id",
Role: "tenant_admin",
ManageableTenants: []domain.Tenant{
{ID: allowedTenantID, Slug: "allowed-team"},
},
})
return c.Next()
})
app.Put("/users/:id", h.UpdateUser)
mockKratos.On("GetIdentity", mock.Anything, "user-id").Return(&service.KratosIdentity{
ID: "user-id",
State: "active",
Traits: map[string]any{
"email": "user@test.com",
"name": "Test User",
"tenant_id": allowedTenantID,
"role": domain.RoleUser,
},
}, nil)
mockTenant.On("GetTenantBySlug", mock.Anything, "outside-team").Return(&domain.Tenant{
ID: "outside-tenant-id",
Name: "관리 외부 팀",
Slug: "outside-team",
}, nil).Once()
body := `{"tenantSlug":"outside-team","isAddTenant":true}`
req := httptest.NewRequest(http.MethodPut, "/users/user-id", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusForbidden, resp.StatusCode)
mockKratos.AssertNotCalled(t, "UpdateIdentity", mock.Anything, mock.Anything, mock.Anything, mock.Anything)
mockTenant.AssertExpectations(t)
}
func TestUserHandler_BulkUpdateUsersAcceptsTenantSlugAndRejectsCompanyCode(t *testing.T) {
app := fiber.New()
mockKratos := new(MockKratosAdmin)

View File

@@ -72,11 +72,17 @@ func (h *WorksmobileHandler) DeleteOrgUnit(c *fiber.Ctx) error {
func (h *WorksmobileHandler) SyncUser(c *fiber.Ctx) error {
userID := strings.TrimSpace(c.Params("userId"))
credentialBatchID, err := parseWorksmobileCredentialBatchID(c)
credentialRequest, err := parseWorksmobileCredentialRequest(c)
if err != nil {
return errorJSON(c, fiber.StatusBadRequest, err.Error())
}
job, err := h.Service.EnqueueUserSync(c.Context(), strings.TrimSpace(c.Params("tenantId")), userID, credentialBatchID)
job, err := h.Service.EnqueueUserSync(
c.Context(),
strings.TrimSpace(c.Params("tenantId")),
userID,
credentialRequest.CredentialBatchID,
credentialRequest.InitialPassword,
)
if err != nil {
return worksmobileGuardError(c, err, "sync_user", "user_id", userID)
}
@@ -158,21 +164,30 @@ func (h *WorksmobileHandler) DeleteCredentialBatchPasswords(c *fiber.Ctx) error
type worksmobileCredentialBatchRequest struct {
CredentialBatchID string `json:"credentialBatchId"`
InitialPassword string `json:"initialPassword"`
}
func parseWorksmobileCredentialBatchID(c *fiber.Ctx) (string, error) {
req, err := parseWorksmobileCredentialRequest(c)
return req.CredentialBatchID, err
}
func parseWorksmobileCredentialRequest(c *fiber.Ctx) (worksmobileCredentialBatchRequest, error) {
batchID := strings.TrimSpace(c.Query("credentialBatchId"))
req := worksmobileCredentialBatchRequest{CredentialBatchID: batchID}
if len(bytes.TrimSpace(c.Body())) == 0 {
return batchID, nil
return req, nil
}
var req worksmobileCredentialBatchRequest
if err := c.BodyParser(&req); err != nil {
return "", err
return worksmobileCredentialBatchRequest{}, err
}
req.InitialPassword = strings.TrimSpace(req.InitialPassword)
if bodyBatchID := strings.TrimSpace(req.CredentialBatchID); bodyBatchID != "" {
return bodyBatchID, nil
req.CredentialBatchID = bodyBatchID
return req, nil
}
return batchID, nil
req.CredentialBatchID = batchID
return req, nil
}
func worksmobileOverviewAllowed(overview service.WorksmobileTenantOverview) bool {

View File

@@ -97,13 +97,14 @@ func TestWorksmobileHandlerPassesSyncUserCredentialBatchID(t *testing.T) {
app := fiber.New()
app.Post("/tenants/:tenantId/worksmobile/users/:userId/sync", h.SyncUser)
req := httptest.NewRequest("POST", "/tenants/hanmac-id/worksmobile/users/user-1/sync", strings.NewReader(`{"credentialBatchId":"batch-1"}`))
req := httptest.NewRequest("POST", "/tenants/hanmac-id/worksmobile/users/user-1/sync", strings.NewReader(`{"credentialBatchId":"batch-1","initialPassword":"InputPass1!"}`))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusAccepted, resp.StatusCode)
require.Equal(t, "batch-1", fakeService.syncUserCredentialBatchID)
require.Equal(t, "InputPass1!", fakeService.syncUserInitialPassword)
}
func TestWorksmobileHandlerPassesPasswordResetCredentialBatchID(t *testing.T) {
@@ -199,6 +200,7 @@ type fakeWorksmobileAdminService struct {
credentials []service.WorksmobileInitialPasswordCredential
syncUserErr error
syncUserCredentialBatchID string
syncUserInitialPassword string
resetPasswordCredentialBatchID string
downloadCredentialBatchID string
deletedCredentialBatchID string
@@ -227,8 +229,9 @@ func (f *fakeWorksmobileAdminService) EnqueueOrgUnitDelete(ctx context.Context,
return &domain.WorksmobileOutbox{ID: "job-orgunit-delete", ResourceID: orgUnitID, Action: domain.WorksmobileActionDelete}, nil
}
func (f *fakeWorksmobileAdminService) EnqueueUserSync(ctx context.Context, tenantID, userID, credentialBatchID string) (*domain.WorksmobileOutbox, error) {
func (f *fakeWorksmobileAdminService) EnqueueUserSync(ctx context.Context, tenantID, userID, credentialBatchID, initialPassword string) (*domain.WorksmobileOutbox, error) {
f.syncUserCredentialBatchID = credentialBatchID
f.syncUserInitialPassword = initialPassword
if f.syncUserErr != nil {
return nil, f.syncUserErr
}

View File

@@ -16,6 +16,7 @@ type UserProjectionRepository interface {
IsReady(ctx context.Context) (bool, error)
GetStatus(ctx context.Context) (domain.UserProjectionStatus, error)
CountTenantMembers(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error)
CountTenantMembersRecursive(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error)
ReplaceAllFromKratos(ctx context.Context, users []domain.User) error
MarkFailed(ctx context.Context, syncErr error) error
}
@@ -108,10 +109,63 @@ func (r *userProjectionRepository) CountTenantMembers(ctx context.Context, tenan
return counts, nil
}
func (r *userProjectionRepository) CountTenantMembersRecursive(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
counts := make(map[string]int64, len(tenants))
for _, tenant := range tenants {
counts[tenant.ID] = 0
}
if len(tenants) == 0 {
return counts, nil
}
valuePlaceholders := make([]string, 0, len(tenants))
args := make([]any, 0, len(tenants))
for _, tenant := range tenants {
valuePlaceholders = append(valuePlaceholders, "(?)")
args = append(args, strings.TrimSpace(tenant.ID))
}
query := fmt.Sprintf(`
WITH RECURSIVE requested(tenant_id) AS (
VALUES %s
),
descendants(root_tenant_id, tenant_id) AS (
SELECT requested.tenant_id, requested.tenant_id
FROM requested
UNION ALL
SELECT descendants.root_tenant_id, child.id::text
FROM descendants
JOIN tenants child
ON child.parent_id::text = descendants.tenant_id
AND child.deleted_at IS NULL
)
SELECT requested.tenant_id, COUNT(DISTINCT users.id) AS count
FROM requested
LEFT JOIN descendants
ON descendants.root_tenant_id = requested.tenant_id
LEFT JOIN users
ON users.deleted_at IS NULL
AND users.tenant_id::text = descendants.tenant_id
GROUP BY requested.tenant_id
`, strings.Join(valuePlaceholders, ","))
type result struct {
TenantID string
Count int64
}
var rows []result
if err := r.db.WithContext(ctx).Raw(query, args...).Scan(&rows).Error; err != nil {
return nil, err
}
for _, row := range rows {
counts[row.TenantID] = row.Count
}
return counts, nil
}
func (r *userProjectionRepository) ReplaceAllFromKratos(ctx context.Context, users []domain.User) error {
now := time.Now()
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
ids := make([]string, 0, len(users))
for i := range users {
users[i].DeletedAt = gorm.DeletedAt{}
if users[i].CreatedAt.IsZero() {
@@ -120,7 +174,6 @@ func (r *userProjectionRepository) ReplaceAllFromKratos(ctx context.Context, use
if users[i].UpdatedAt.IsZero() {
users[i].UpdatedAt = now
}
ids = append(ids, users[i].ID)
}
if len(users) > 0 {
@@ -138,11 +191,6 @@ func (r *userProjectionRepository) ReplaceAllFromKratos(ctx context.Context, use
}).Create(&users).Error; err != nil {
return err
}
if err := tx.Where("id NOT IN ?", ids).Delete(&domain.User{}).Error; err != nil {
return err
}
} else if err := tx.Where("1 = 1").Delete(&domain.User{}).Error; err != nil {
return err
}
return upsertUserProjectionState(tx, domain.UserProjectionStatusReady, &now, "")

View File

@@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestUserProjectionRepository_ReplaceAllFromKratosMarksReadyAndRemovesStaleUsers(t *testing.T) {
func TestUserProjectionRepository_ReplaceAllFromKratosMarksReadyWithoutDeletingUsersMissingFromPartialList(t *testing.T) {
ctx := context.Background()
repo := NewUserProjectionRepository(testDB)
@@ -28,13 +28,14 @@ func TestUserProjectionRepository_ReplaceAllFromKratosMarksReadyAndRemovesStaleU
Type: domain.TenantTypeCompany,
Status: domain.TenantStatusActive,
}).Error)
stale := &domain.User{
existing := &domain.User{
ID: "00000000-0000-0000-0000-000000000099",
Email: "stale@example.com",
Name: "Stale",
Email: "existing@example.com",
Name: "Existing",
CompanyCode: tenantSlug,
TenantID: &tenantID,
}
require.NoError(t, NewUserRepository(testDB).Create(ctx, stale))
require.NoError(t, NewUserRepository(testDB).Create(ctx, existing))
users := []domain.User{
{
@@ -66,11 +67,91 @@ func TestUserProjectionRepository_ReplaceAllFromKratosMarksReadyAndRemovesStaleU
{ID: tenantID, Slug: tenantSlug},
})
require.NoError(t, err)
assert.Equal(t, int64(2), counts[tenantID])
assert.Equal(t, int64(3), counts[tenantID])
var activeCount int64
require.NoError(t, testDB.Model(&domain.User{}).Count(&activeCount).Error)
assert.Equal(t, int64(2), activeCount)
assert.Equal(t, int64(3), activeCount)
var existingCount int64
require.NoError(t, testDB.Model(&domain.User{}).Where("id = ?", existing.ID).Count(&existingCount).Error)
assert.Equal(t, int64(1), existingCount)
var existingRow domain.User
require.NoError(t, testDB.Unscoped().First(&existingRow, "id = ?", existing.ID).Error)
assert.False(t, existingRow.DeletedAt.Valid)
}
func TestUserProjectionRepository_CountTenantMembersRecursiveIncludesDescendantsAndExcludesSoftDeletedUsers(t *testing.T) {
ctx := context.Background()
repo := NewUserProjectionRepository(testDB)
parentID := "20000000-0000-0000-0000-000000000001"
childID := "20000000-0000-0000-0000-000000000002"
grandchildID := "20000000-0000-0000-0000-000000000003"
siblingID := "20000000-0000-0000-0000-000000000004"
tenantIDs := []string{parentID, childID, grandchildID, siblingID}
require.NoError(t, testDB.Exec("DELETE FROM user_login_ids").Error)
require.NoError(t, testDB.Exec("DELETE FROM users").Error)
require.NoError(t, testDB.Unscoped().Where("id IN ?", tenantIDs).Delete(&domain.Tenant{}).Error)
require.NoError(t, testDB.Create(&domain.Tenant{
ID: parentID,
Name: "Recursive Parent",
Slug: "recursive-parent",
Type: domain.TenantTypeCompany,
Status: domain.TenantStatusActive,
}).Error)
require.NoError(t, testDB.Create(&domain.Tenant{
ID: childID,
Name: "Recursive Child",
Slug: "recursive-child",
Type: domain.TenantTypeOrganization,
Status: domain.TenantStatusActive,
ParentID: &parentID,
}).Error)
require.NoError(t, testDB.Create(&domain.Tenant{
ID: grandchildID,
Name: "Recursive Grandchild",
Slug: "recursive-grandchild",
Type: domain.TenantTypeUserGroup,
Status: domain.TenantStatusActive,
ParentID: &childID,
}).Error)
require.NoError(t, testDB.Create(&domain.Tenant{
ID: siblingID,
Name: "Recursive Sibling",
Slug: "recursive-sibling",
Type: domain.TenantTypeCompany,
Status: domain.TenantStatusActive,
}).Error)
users := []domain.User{
{ID: "21000000-0000-0000-0000-000000000001", Email: "parent@example.com", Name: "Parent", TenantID: &parentID},
{ID: "21000000-0000-0000-0000-000000000002", Email: "child@example.com", Name: "Child", TenantID: &childID},
{ID: "21000000-0000-0000-0000-000000000003", Email: "grandchild@example.com", Name: "Grandchild", TenantID: &grandchildID},
{ID: "21000000-0000-0000-0000-000000000004", Email: "deleted-grandchild@example.com", Name: "Deleted Grandchild", TenantID: &grandchildID},
{ID: "21000000-0000-0000-0000-000000000005", Email: "sibling@example.com", Name: "Sibling", TenantID: &siblingID},
}
for i := range users {
require.NoError(t, testDB.Create(&users[i]).Error)
}
require.NoError(t, testDB.Delete(&domain.User{}, "id = ?", users[3].ID).Error)
directCounts, err := repo.CountTenantMembers(ctx, []domain.Tenant{{ID: parentID}, {ID: childID}, {ID: grandchildID}, {ID: siblingID}})
require.NoError(t, err)
assert.Equal(t, int64(1), directCounts[parentID])
assert.Equal(t, int64(1), directCounts[childID])
assert.Equal(t, int64(1), directCounts[grandchildID])
assert.Equal(t, int64(1), directCounts[siblingID])
recursiveCounts, err := repo.CountTenantMembersRecursive(ctx, []domain.Tenant{{ID: parentID}, {ID: childID}, {ID: grandchildID}, {ID: siblingID}})
require.NoError(t, err)
assert.Equal(t, int64(3), recursiveCounts[parentID])
assert.Equal(t, int64(2), recursiveCounts[childID])
assert.Equal(t, int64(1), recursiveCounts[grandchildID])
assert.Equal(t, int64(1), recursiveCounts[siblingID])
}
func TestUserProjectionRepository_MarkFailedMakesProjectionNotReady(t *testing.T) {

View File

@@ -4,6 +4,7 @@ import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/pagination"
"context"
"encoding/json"
"fmt"
"strings"
@@ -46,6 +47,31 @@ func (r *userRepository) DB() *gorm.DB {
return r.db
}
func (r *userRepository) withTenantMembershipFilter(db *gorm.DB, tenantIDs []string) *gorm.DB {
if len(tenantIDs) == 0 {
return db
}
clauses := []string{"tenant_id IN ?"}
args := []any{tenantIDs}
for _, tenantID := range tenantIDs {
tenantID = strings.TrimSpace(tenantID)
if tenantID == "" {
continue
}
payload, err := json.Marshal(map[string]any{
"additionalAppointments": []map[string]string{
{"tenantId": tenantID},
},
})
if err != nil {
continue
}
clauses = append(clauses, "metadata @> ?::jsonb")
args = append(args, string(payload))
}
return db.Where("("+strings.Join(clauses, " OR ")+")", args...)
}
func (r *userRepository) Create(ctx context.Context, user *domain.User) error {
return r.db.WithContext(ctx).Create(user).Error
}
@@ -124,7 +150,7 @@ func (r *userRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.
func (r *userRepository) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
var users []domain.User
if err := r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Find(&users).Error; err != nil {
if err := r.withTenantMembershipFilter(r.db.WithContext(ctx), []string{tenantID}).Find(&users).Error; err != nil {
return nil, err
}
return users, nil
@@ -132,40 +158,23 @@ func (r *userRepository) ListByTenant(ctx context.Context, tenantID string) ([]d
func (r *userRepository) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&domain.User{}).Where("tenant_id = ?", tenantID).Count(&count).Error
err := r.withTenantMembershipFilter(r.db.WithContext(ctx).Model(&domain.User{}), []string{tenantID}).Count(&count).Error
return count, err
}
func (r *userRepository) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
type result struct {
TenantID *string
Count int64
}
var results []result
counts := make(map[string]int64)
if len(tenantIDs) == 0 {
return counts, nil
}
if err := r.db.WithContext(ctx).Model(&domain.User{}).
Select("tenant_id, count(*) as count").
Where("tenant_id IN ?", tenantIDs).
Group("tenant_id").
Find(&results).Error; err != nil {
return nil, err
}
for _, res := range results {
if res.TenantID != nil && *res.TenantID != "" {
counts[*res.TenantID] = res.Count
}
}
// Ensure all requested tenant IDs are in the map, even if count is 0
for _, id := range tenantIDs {
if _, ok := counts[id]; !ok {
counts[id] = 0
for _, tenantID := range tenantIDs {
var count int64
if err := r.withTenantMembershipFilter(r.db.WithContext(ctx).Model(&domain.User{}), []string{tenantID}).Count(&count).Error; err != nil {
return nil, err
}
counts[tenantID] = count
}
return counts, nil
}
@@ -222,7 +231,7 @@ func (r *userRepository) List(ctx context.Context, offset, limit int, search str
db := r.db.WithContext(ctx).Model(&domain.User{})
if len(tenantIDs) > 0 {
db = db.Where("tenant_id IN ?", tenantIDs)
db = r.withTenantMembershipFilter(db, tenantIDs)
}
if search != "" {
@@ -311,7 +320,7 @@ func (r *userRepository) FindTenantIDByLoginID(ctx context.Context, loginID stri
func (r *userRepository) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
var users []domain.User
err := r.db.WithContext(ctx).Where("tenant_id IN ?", tenantIDs).Find(&users).Error
err := r.withTenantMembershipFilter(r.db.WithContext(ctx), tenantIDs).Find(&users).Error
return users, err
}

View File

@@ -204,6 +204,60 @@ func TestUserRepository(t *testing.T) {
})
}
func TestUserRepository_ListIncludesAdditionalTenantAppointments(t *testing.T) {
repo := NewUserRepository(testDB)
ctx := context.Background()
require.NoError(t, testDB.Exec("DELETE FROM user_login_ids").Error)
require.NoError(t, testDB.Exec("DELETE FROM users").Error)
primaryTenant := createUserRepositoryTestTenant(t, "repo-primary-tenant")
additionalTenant := createUserRepositoryTestTenant(t, "repo-additional-tenant")
primaryTenantID := primaryTenant.ID
additionalTenantID := additionalTenant.ID
users := []domain.User{
{
ID: uuid.NewString(),
Email: "primary-member@example.com",
Name: "Primary Member",
Role: domain.RoleUser,
TenantID: &additionalTenantID,
},
{
ID: uuid.NewString(),
Email: "additional-member@example.com",
Name: "Additional Member",
Role: domain.RoleUser,
TenantID: &primaryTenantID,
Metadata: domain.JSONMap{
"additionalAppointments": []any{
map[string]any{
"tenantId": additionalTenant.ID,
"tenantSlug": additionalTenant.Slug,
"tenantName": additionalTenant.Name,
"isPrimary": false,
},
},
},
},
}
for i := range users {
require.NoError(t, repo.Create(ctx, &users[i]))
}
listed, total, _, err := repo.List(ctx, 0, 20, "", []string{additionalTenant.ID}, "")
require.NoError(t, err)
require.Equal(t, int64(2), total)
require.Len(t, listed, 2)
emails := []string{listed[0].Email, listed[1].Email}
assert.Contains(t, emails, "primary-member@example.com")
assert.Contains(t, emails, "additional-member@example.com")
counts, err := repo.CountByTenantIDs(ctx, []string{additionalTenant.ID})
require.NoError(t, err)
assert.Equal(t, int64(2), counts[additionalTenant.ID])
}
func createUserRepositoryTestTenant(t *testing.T, slug string) domain.Tenant {
t.Helper()
require.NoError(t, testDB.Unscoped().Where("slug = ?", slug).Delete(&domain.Tenant{}).Error)

View File

@@ -9,6 +9,7 @@ import (
"io"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
@@ -68,27 +69,76 @@ func NewKratosAdminService() KratosAdminService {
func (s *kratosAdminService) ListIdentities(ctx context.Context) ([]KratosIdentity, error) {
endpoint := strings.TrimRight(s.AdminURL, "/") + "/admin/identities"
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
resp, err := s.httpClient().Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return nil, fmt.Errorf("kratos admin list identities failed status=%d body=%s", resp.StatusCode, string(body))
}
var identities []KratosIdentity
if err := json.NewDecoder(resp.Body).Decode(&identities); err != nil {
return nil, err
pageToken := ""
seenTokens := make(map[string]bool)
for {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
query := req.URL.Query()
query.Set("page_size", "250")
if pageToken != "" {
query.Set("page_token", pageToken)
}
req.URL.RawQuery = query.Encode()
resp, err := s.httpClient().Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
_ = resp.Body.Close()
return nil, fmt.Errorf("kratos admin list identities failed status=%d body=%s", resp.StatusCode, string(body))
}
var page []KratosIdentity
if err := json.NewDecoder(resp.Body).Decode(&page); err != nil {
_ = resp.Body.Close()
return nil, err
}
_ = resp.Body.Close()
identities = append(identities, page...)
nextToken := kratosNextPageToken(resp.Header.Values("Link"))
if nextToken == "" {
return identities, nil
}
if seenTokens[nextToken] {
return nil, fmt.Errorf("kratos admin list identities pagination loop detected page_token=%s", nextToken)
}
seenTokens[nextToken] = true
pageToken = nextToken
}
return identities, nil
}
func kratosNextPageToken(linkHeaders []string) string {
for _, header := range linkHeaders {
for _, part := range strings.Split(header, ",") {
part = strings.TrimSpace(part)
if !strings.Contains(part, `rel="next"`) && !strings.Contains(part, `rel=next`) {
continue
}
start := strings.Index(part, "<")
end := strings.Index(part, ">")
if start < 0 || end <= start+1 {
continue
}
rawURL := part[start+1 : end]
parsed, err := url.Parse(rawURL)
if err != nil {
continue
}
if token := strings.TrimSpace(parsed.Query().Get("page_token")); token != "" {
return token
}
}
}
return ""
}
func (s *kratosAdminService) FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error) {

View File

@@ -0,0 +1,63 @@
package service
import (
"bytes"
"context"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func TestKratosAdminService_ListIdentitiesFollowsNextPagination(t *testing.T) {
var requestedTokens []string
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
require.Equal(t, "/admin/identities", r.URL.Path)
token := r.URL.Query().Get("page_token")
requestedTokens = append(requestedTokens, token)
header := make(http.Header)
header.Set("Content-Type", "application/json")
status := http.StatusOK
body := "[]"
switch token {
case "":
header.Set(
"Link",
`</admin/identities?page_size=2&page_token=identity-2>; rel="next"`,
)
body = `[{"id":"identity-1","traits":{"email":"one@example.com"}},{"id":"identity-2","traits":{"email":"two@example.com"}}]`
case "identity-2":
body = `[{"id":"identity-3","traits":{"email":"three@example.com"}}]`
default:
t.Fatalf("unexpected page_token %q", token)
}
return &http.Response{
StatusCode: status,
Header: header,
Body: io.NopCloser(bytes.NewBufferString(body)),
Request: r,
}, nil
})}
service := &kratosAdminService{
AdminURL: "http://kratos.example",
HTTPClient: client,
}
identities, err := service.ListIdentities(context.Background())
require.NoError(t, err)
require.Equal(t, []string{"", "identity-2"}, requestedTokens)
require.Len(t, identities, 3)
require.Equal(t, "identity-1", identities[0].ID)
require.Equal(t, "identity-2", identities[1].ID)
require.Equal(t, "identity-3", identities[2].ID)
}

View File

@@ -1,7 +1,9 @@
package service
import (
"baron-sso-backend/internal/domain"
"context"
"encoding/json"
"os"
"time"
@@ -14,6 +16,14 @@ type RedisService struct {
Client *redis.Client
}
type identityMirrorStateStore struct {
Status string `json:"status"`
LastRefreshedAt *time.Time `json:"lastRefreshedAt,omitempty"`
LastError string `json:"lastError,omitempty"`
ObservedCount int64 `json:"observedCount,omitempty"`
UpdatedAt *time.Time `json:"updatedAt,omitempty"`
}
// NewRedisService creates and returns a new RedisService
func NewRedisService() (*RedisService, error) {
redisAddr := os.Getenv("REDIS_ADDR")
@@ -90,3 +100,139 @@ func (s *RedisService) Get(key string) (string, error) {
func (s *RedisService) Delete(key string) error {
return s.Client.Del(ctx, key).Err()
}
func (s *RedisService) GetIdentityCacheStatus(ctx context.Context) (domain.IdentityCacheStatus, error) {
if s == nil || s.Client == nil {
return domain.IdentityCacheStatus{
Status: "unavailable",
RedisReady: false,
LastError: "redis service unavailable",
}, nil
}
if err := s.Client.Ping(ctx).Err(); err != nil {
return domain.IdentityCacheStatus{
Status: "failed",
RedisReady: false,
LastError: err.Error(),
}, nil
}
keyCount, err := s.countIdentityCacheKeys(ctx)
if err != nil {
return domain.IdentityCacheStatus{
Status: "failed",
RedisReady: true,
LastError: err.Error(),
}, nil
}
raw, err := s.Client.Get(ctx, "identity:mirror:state").Result()
if err == redis.Nil {
return domain.IdentityCacheStatus{
Status: "empty",
RedisReady: true,
KeyCount: keyCount,
}, nil
}
if err != nil {
return domain.IdentityCacheStatus{
Status: "failed",
RedisReady: true,
KeyCount: keyCount,
LastError: err.Error(),
}, nil
}
var stored identityMirrorStateStore
if err := json.Unmarshal([]byte(raw), &stored); err != nil {
return domain.IdentityCacheStatus{
Status: "failed",
RedisReady: true,
KeyCount: keyCount,
LastError: err.Error(),
}, nil
}
status := stored.Status
if status == "" {
status = "unknown"
}
return domain.IdentityCacheStatus{
Status: status,
RedisReady: true,
ObservedCount: stored.ObservedCount,
KeyCount: keyCount,
LastRefreshedAt: stored.LastRefreshedAt,
LastError: stored.LastError,
UpdatedAt: stored.UpdatedAt,
}, nil
}
func (s *RedisService) FlushIdentityCache(ctx context.Context) (domain.IdentityCacheFlushResult, error) {
if s == nil || s.Client == nil {
return domain.IdentityCacheFlushResult{}, os.ErrInvalid
}
keys, err := s.identityCacheKeys(ctx)
if err != nil {
return domain.IdentityCacheFlushResult{}, err
}
var deleted int64
for len(keys) > 0 {
chunkSize := len(keys)
if chunkSize > 500 {
chunkSize = 500
}
chunk := keys[:chunkSize]
count, err := s.Client.Del(ctx, chunk...).Result()
if err != nil {
return domain.IdentityCacheFlushResult{}, err
}
deleted += count
keys = keys[chunkSize:]
}
return domain.IdentityCacheFlushResult{
Status: "success",
FlushedKeys: deleted,
UpdatedAt: time.Now().UTC(),
}, nil
}
func (s *RedisService) countIdentityCacheKeys(ctx context.Context) (int64, error) {
keys, err := s.identityCacheKeys(ctx)
if err != nil {
return 0, err
}
return int64(len(keys)), nil
}
func (s *RedisService) identityCacheKeys(ctx context.Context) ([]string, error) {
seen := make(map[string]bool)
patterns := []string{
"identity:mirror:*",
"identity:index:*",
}
for _, pattern := range patterns {
var cursor uint64
for {
keys, next, err := s.Client.Scan(ctx, cursor, pattern, 250).Result()
if err != nil {
return nil, err
}
for _, key := range keys {
seen[key] = true
}
cursor = next
if cursor == 0 {
break
}
}
}
keys := make([]string, 0, len(seen))
for key := range seen {
keys = append(keys, key)
}
return keys, nil
}

View File

@@ -0,0 +1,150 @@
package service
import (
"context"
"encoding/json"
"os"
"testing"
"time"
"github.com/go-redis/redis/v8"
"github.com/stretchr/testify/require"
)
type redisCommandStub struct {
scans map[string][]string
stateValue string
deleted []string
}
func (h *redisCommandStub) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
return ctx, nil
}
func (h *redisCommandStub) AfterProcess(ctx context.Context, cmd redis.Cmder) error {
switch cmd.Name() {
case "ping":
if status, ok := cmd.(*redis.StatusCmd); ok {
status.SetVal("PONG")
}
case "scan":
if scan, ok := cmd.(*redis.ScanCmd); ok {
scan.SetVal(h.scans[scanPattern(cmd.Args())], 0)
}
case "get":
if str, ok := cmd.(*redis.StringCmd); ok {
if h.stateValue == "" {
str.SetErr(redis.Nil)
return nil
}
str.SetVal(h.stateValue)
}
case "del":
args := cmd.Args()
keys := make([]string, 0, len(args)-1)
for _, arg := range args[1:] {
keys = append(keys, arg.(string))
}
h.deleted = append(h.deleted, keys...)
if count, ok := cmd.(*redis.IntCmd); ok {
count.SetVal(int64(len(keys)))
}
}
cmd.SetErr(nil)
return nil
}
func (h *redisCommandStub) BeforeProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
return ctx, nil
}
func (h *redisCommandStub) AfterProcessPipeline(ctx context.Context, cmds []redis.Cmder) error {
return nil
}
func scanPattern(args []interface{}) string {
for index := 0; index < len(args)-1; index++ {
value, ok := args[index].(string)
if ok && value == "match" {
if pattern, ok := args[index+1].(string); ok {
return pattern
}
}
}
return ""
}
func newStubbedRedisService(stub *redisCommandStub) *RedisService {
client := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:1",
MaxRetries: -1,
})
client.AddHook(stub)
return &RedisService{Client: client}
}
func TestRedisServiceGetIdentityCacheStatusReadsStateAndCountsCacheKeys(t *testing.T) {
now := time.Date(2026, 6, 9, 3, 20, 0, 0, time.UTC)
state, err := json.Marshal(identityMirrorStateStore{
Status: "ready",
LastRefreshedAt: &now,
ObservedCount: 42,
UpdatedAt: &now,
})
require.NoError(t, err)
stub := &redisCommandStub{
stateValue: string(state),
scans: map[string][]string{
"identity:mirror:*": {"identity:mirror:state", "identity:mirror:user:1"},
"identity:index:*": {"identity:index:email:a", "identity:mirror:user:1"},
},
}
service := newStubbedRedisService(stub)
status, err := service.GetIdentityCacheStatus(context.Background())
require.NoError(t, err)
require.Equal(t, "ready", status.Status)
require.True(t, status.RedisReady)
require.Equal(t, int64(42), status.ObservedCount)
require.Equal(t, int64(3), status.KeyCount)
require.Equal(t, &now, status.LastRefreshedAt)
require.Equal(t, &now, status.UpdatedAt)
}
func TestRedisServiceFlushIdentityCacheDeletesOnlyIdentityMirrorAndIndexKeys(t *testing.T) {
stub := &redisCommandStub{
scans: map[string][]string{
"identity:mirror:*": {"identity:mirror:state", "identity:mirror:user:1"},
"identity:index:*": {"identity:index:email:a", "identity:mirror:user:1"},
},
}
service := newStubbedRedisService(stub)
result, err := service.FlushIdentityCache(context.Background())
require.NoError(t, err)
require.Equal(t, "success", result.Status)
require.Equal(t, int64(3), result.FlushedKeys)
require.ElementsMatch(t, []string{
"identity:mirror:state",
"identity:mirror:user:1",
"identity:index:email:a",
}, stub.deleted)
}
func TestRedisServiceGetIdentityCacheStatusReturnsUnavailableWithoutClient(t *testing.T) {
status, err := (*RedisService)(nil).GetIdentityCacheStatus(context.Background())
require.NoError(t, err)
require.Equal(t, "unavailable", status.Status)
require.False(t, status.RedisReady)
require.NotEmpty(t, status.LastError)
}
func TestRedisServiceFlushIdentityCacheFailsWithoutClient(t *testing.T) {
_, err := (*RedisService)(nil).FlushIdentityCache(context.Background())
require.ErrorIs(t, err, os.ErrInvalid)
}

View File

@@ -222,44 +222,19 @@ func (s *userGroupService) AddMember(ctx context.Context, groupID, userID string
tenant, _ = s.tenantRepo.FindByID(ctx, group.TenantID)
}
var updatedIdentity *KratosIdentity
// [Fix] Sync Kratos Traits & Local DB when a user is added to an organization
if s.kratos != nil && tenant != nil {
// Fetch Kratos Identity
identity, err := s.kratos.GetIdentity(ctx, userID)
if err == nil && identity != nil {
traits := identity.Traits
if traits == nil {
traits = make(map[string]any)
}
delete(traits, "companyCode")
delete(traits, "companyCodes")
traits["tenant_id"] = tenant.ID
traits["department"] = group.Name
// Update Kratos
updated, updateErr := s.kratos.UpdateIdentity(ctx, userID, traits, identity.State)
if updateErr != nil {
slog.Error("Failed to update identity traits during AddMember", "user", userID, "error", updateErr)
} else if updated != nil {
updatedIdentity = updated
} else {
identity.Traits = traits
updatedIdentity = identity
}
}
}
// Sync local user repo
// Kratos는 identity SSOT이고 조직/부서 정보의 원장이 아니므로 AddMember에서 traits를 수정하지 않습니다.
if s.userRepo != nil && tenant != nil {
localUser, err := s.userRepo.FindByID(ctx, userID)
if err != nil || localUser == nil {
if updatedIdentity != nil {
localUser = mapUserGroupKratosIdentityToLocalUser(*updatedIdentity)
if s.kratos != nil {
identity, identityErr := s.kratos.GetIdentity(ctx, userID)
if identityErr == nil && identity != nil {
localUser = mapUserGroupKratosIdentityToLocalUser(*identity)
} else {
slog.Warn("Skipping local user sync during AddMember because identity read is unavailable", "user", userID, "error", identityErr)
}
} else {
slog.Warn("Skipping local user sync during AddMember because identity projection is unavailable", "user", userID, "error", err)
localUser = nil
}
}
if localUser != nil {
@@ -326,7 +301,7 @@ func mapUserGroupKratosIdentityToLocalUser(identity KratosIdentity) *domain.User
ID: identity.ID,
Email: userGroupTraitString(traits, "email"),
Name: userGroupTraitString(traits, "name"),
Phone: userGroupTraitString(traits, "phone_number"),
Phone: domain.NormalizePhoneNumber(userGroupTraitString(traits, "phone_number")),
Role: role,
Status: userGroupIdentityStatus(identity.State),
Department: userGroupTraitString(traits, "department"),

View File

@@ -272,14 +272,6 @@ func TestUserGroupService_AddMember(t *testing.T) {
mockUserRepo.On("FindByID", mock.Anything, userID).Return(&domain.User{ID: userID}, nil)
mockTenantRepo.On("FindByID", mock.Anything, tenantID).Return(&domain.Tenant{ID: tenantID, Slug: tenantSlug}, nil)
// Mock Kratos
mockKratos.On("GetIdentity", mock.Anything, userID).Return(&KratosIdentity{
ID: userID,
Traits: map[string]any{"email": "user@test.com"},
State: "active",
}, nil)
mockKratos.On("UpdateIdentity", mock.Anything, userID, mock.Anything, "active").Return(&KratosIdentity{}, nil)
// Mock local user repo update (Ignored since Update is hardcoded to return nil without calling m.Called)
// mockUserRepo.On("Update", mock.Anything, mock.MatchedBy(func(u *domain.User) bool {
// return u.CompanyCode == tenantSlug && *u.TenantID == tenantID && u.Department == "Sales"
@@ -299,6 +291,8 @@ func TestUserGroupService_AddMember(t *testing.T) {
assert.NoError(t, err)
mockOutbox.AssertExpectations(t)
mockKratos.AssertExpectations(t)
mockKratos.AssertNotCalled(t, "GetIdentity", mock.Anything, userID)
mockKratos.AssertNotCalled(t, "UpdateIdentity", mock.Anything, userID, mock.Anything, mock.Anything)
// mockUserRepo.AssertExpectations(t)
}
@@ -326,19 +320,6 @@ func TestUserGroupService_AddMemberUpsertsLocalReadModelWhenMissing(t *testing.T
},
State: "active",
}, nil)
mockKratos.On("UpdateIdentity", mock.Anything, userID, mock.MatchedBy(func(traits map[string]any) bool {
_, hasCompanyCode := traits["companyCode"]
return !hasCompanyCode && traits["tenant_id"] == tenantID && traits["department"] == "Sales"
}), "active").Return(&KratosIdentity{
ID: userID,
Traits: map[string]any{
"email": "user@test.com",
"name": "User Test",
"tenant_id": tenantID,
"department": "Sales",
},
State: "active",
}, nil)
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
return e.Namespace == "Tenant" && e.Object == groupID && e.Relation == "members" && e.Subject == "User:"+userID
})).Return(nil).Once()
@@ -356,6 +337,7 @@ func TestUserGroupService_AddMemberUpsertsLocalReadModelWhenMissing(t *testing.T
assert.Equal(t, "Sales", mockUserRepo.updatedUsers[0].Department)
mockOutbox.AssertExpectations(t)
mockKratos.AssertExpectations(t)
mockKratos.AssertNotCalled(t, "UpdateIdentity", mock.Anything, userID, mock.Anything, mock.Anything)
}
func TestUserGroupService_AddMemberEnqueuesWorksmobileUserSync(t *testing.T) {
@@ -380,16 +362,6 @@ func TestUserGroupService_AddMemberEnqueuesWorksmobileUserSync(t *testing.T) {
Status: "active",
}, nil)
mockTenantRepo.On("FindByID", mock.Anything, tenantID).Return(&domain.Tenant{ID: tenantID, Slug: "tenant-slug"}, nil)
mockKratos.On("GetIdentity", mock.Anything, userID).Return(&KratosIdentity{
ID: userID,
Traits: map[string]any{"email": "user@test.com"},
State: "active",
}, nil)
mockKratos.On("UpdateIdentity", mock.Anything, userID, mock.Anything, "active").Return(&KratosIdentity{
ID: userID,
Traits: map[string]any{"email": "user@test.com", "tenant_id": tenantID, "department": "Sales"},
State: "active",
}, nil)
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
return e.Namespace == "Tenant" && e.Object == groupID && e.Relation == "members" && e.Subject == "User:"+userID
})).Return(nil).Once()
@@ -407,6 +379,8 @@ func TestUserGroupService_AddMemberEnqueuesWorksmobileUserSync(t *testing.T) {
assert.Equal(t, "Sales", worksmobile.userUpserts[0].Department)
mockOutbox.AssertExpectations(t)
mockKratos.AssertExpectations(t)
mockKratos.AssertNotCalled(t, "GetIdentity", mock.Anything, userID)
mockKratos.AssertNotCalled(t, "UpdateIdentity", mock.Anything, userID, mock.Anything, mock.Anything)
}
func TestUserGroupService_AssignRoleToTenant(t *testing.T) {

View File

@@ -75,7 +75,7 @@ func MapKratosIdentityToLocalUser(identity KratosIdentity) domain.User {
ID: identity.ID,
Email: kratosProjectionTraitString(traits, "email"),
Name: kratosProjectionTraitString(traits, "name"),
Phone: kratosProjectionTraitString(traits, "phone_number"),
Phone: domain.NormalizePhoneNumber(kratosProjectionTraitString(traits, "phone_number")),
Role: role,
Status: normalizeProjectionStatus(identity.State),
Department: kratosProjectionTraitString(traits, "department"),

View File

@@ -28,6 +28,10 @@ func (f *fakeUserProjectionRepo) CountTenantMembers(ctx context.Context, tenants
return nil, nil
}
func (f *fakeUserProjectionRepo) CountTenantMembersRecursive(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
return nil, nil
}
func (f *fakeUserProjectionRepo) ReplaceAllFromKratos(ctx context.Context, users []domain.User) error {
f.replacedUsers = append([]domain.User(nil), users...)
return f.replaceErr
@@ -79,6 +83,33 @@ func TestUserProjectionSyncService_ReconcileReplacesProjectionFromKratos(t *test
kratos.AssertExpectations(t)
}
func TestUserProjectionSyncService_ReconcileDeduplicatesKoreanCountryCodePhone(t *testing.T) {
ctx := context.Background()
kratos := new(MockKratosAdminServiceShared)
repo := &fakeUserProjectionRepo{}
svc := NewUserProjectionSyncService(kratos, repo)
kratos.On("ListIdentities", ctx).Return([]KratosIdentity{
{
ID: "00000000-0000-0000-0000-000000000102",
Traits: map[string]any{
"email": "two@example.com",
"name": "Two",
"phone_number": "+82 +821091917771",
},
State: "active",
},
}, nil).Once()
count, err := svc.Reconcile(ctx)
require.NoError(t, err)
assert.Equal(t, 1, count)
require.Len(t, repo.replacedUsers, 1)
assert.Equal(t, "+821091917771", repo.replacedUsers[0].Phone)
kratos.AssertExpectations(t)
}
func TestUserProjectionSyncService_ReconcileMarksFailedWhenKratosFails(t *testing.T) {
ctx := context.Background()
kratos := new(MockKratosAdminServiceShared)

View File

@@ -1,6 +1,7 @@
package service
import (
"baron-sso-backend/internal/domain"
"bytes"
"context"
"crypto"
@@ -17,11 +18,13 @@ import (
"net/url"
"strconv"
"strings"
"sync"
"time"
)
const (
defaultWorksmobileOAuthScope = "directory"
defaultWorksmobileOAuthScope = "directory"
worksmobileAPIRateLimitPerMinute = 240
)
type WorksmobileDirectoryClient interface {
@@ -43,6 +46,7 @@ type WorksmobileHTTPClient struct {
DirectoryToken string
SCIMToken string
HTTPClient *http.Client
RateLimiter WorksmobileRateLimiter
OAuthConfig WorksmobileOAuthConfig
DomainIDs []int64
OrgUnitWriteDelay time.Duration
@@ -50,6 +54,16 @@ type WorksmobileHTTPClient struct {
now func() time.Time
}
type WorksmobileRateLimiter interface {
Wait(ctx context.Context, key string) error
}
type worksmobileAPIRateLimiter struct {
interval time.Duration
mu sync.Mutex
next map[string]time.Time
}
type WorksmobileOAuthConfig struct {
ClientID string
ClientSecret string
@@ -64,6 +78,46 @@ type worksmobileAccessTokenCache struct {
ExpiresAt time.Time
}
func NewWorksmobileAPIRateLimiter(limit int, window time.Duration) WorksmobileRateLimiter {
if limit <= 0 || window <= 0 {
return &worksmobileAPIRateLimiter{}
}
return &worksmobileAPIRateLimiter{
interval: window / time.Duration(limit),
next: map[string]time.Time{},
}
}
func (l *worksmobileAPIRateLimiter) Wait(ctx context.Context, key string) error {
if l == nil || l.interval <= 0 {
return nil
}
key = strings.TrimSpace(key)
if key == "" {
key = "UNKNOWN"
}
l.mu.Lock()
now := time.Now()
waitUntil := l.next[key]
if waitUntil.Before(now) {
waitUntil = now
}
l.next[key] = waitUntil.Add(l.interval)
l.mu.Unlock()
if delay := time.Until(waitUntil); delay > 0 {
timer := time.NewTimer(delay)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
}
}
return nil
}
func (c WorksmobileOAuthConfig) normalized() WorksmobileOAuthConfig {
c.ClientID = strings.Trim(strings.TrimSpace(c.ClientID), `"`)
c.ClientSecret = strings.Trim(strings.TrimSpace(c.ClientSecret), `"`)
@@ -280,7 +334,10 @@ func (c *WorksmobileHTTPClient) UpsertUser(ctx context.Context, payload Worksmob
if identifier == "" {
identifier = strings.TrimSpace(payload.UserExternalKey)
}
return c.PatchUser(ctx, identifier, NewWorksmobileUserPatchPayload(payload))
if patchErr := c.PatchUser(ctx, identifier, NewWorksmobileUserPatchPayload(payload)); patchErr != nil {
return fmt.Errorf("worksmobile user create conflict: %w; patch after conflict failed: %v", err, patchErr)
}
return nil
}
return err
}
@@ -306,6 +363,23 @@ func (c *WorksmobileHTTPClient) AddUserAliasEmail(ctx context.Context, userID st
return err
}
func (c *WorksmobileHTTPClient) RemoveUserAliasEmail(ctx context.Context, userID string, email string) error {
userID = strings.TrimSpace(userID)
email = strings.TrimSpace(email)
if userID == "" {
return fmt.Errorf("worksmobile user id is required")
}
if email == "" {
return fmt.Errorf("worksmobile alias email is required")
}
return c.sendDirectoryJSON(
ctx,
http.MethodDelete,
"/v1.0/users/"+url.PathEscape(userID)+"/alias-emails/"+url.PathEscape(email),
nil,
)
}
func (c *WorksmobileHTTPClient) ResetUserPassword(ctx context.Context, userID string, password string) error {
userID = strings.TrimSpace(userID)
password = strings.TrimSpace(password)
@@ -315,15 +389,38 @@ func (c *WorksmobileHTTPClient) ResetUserPassword(ctx context.Context, userID st
if password == "" {
return fmt.Errorf("worksmobile password is required")
}
changePasswordAtNextLogin := true
payload := map[string]any{
"passwordConfig": WorksmobilePasswordConfig{
PasswordCreationType: "ADMIN",
Password: password,
PasswordCreationType: "ADMIN",
Password: password,
ChangePasswordAtNextLogin: &changePasswordAtNextLogin,
},
}
return c.sendDirectoryJSON(ctx, http.MethodPatch, "/v1.0/users/"+url.PathEscape(userID), payload)
}
func (c *WorksmobileHTTPClient) GetUser(ctx context.Context, userID string) (*WorksmobileRemoteUser, error) {
userID = strings.TrimSpace(userID)
if userID == "" {
return nil, fmt.Errorf("worksmobile user id is required")
}
var response map[string]any
if err := c.getDirectoryJSON(ctx, "/v1.0/users/"+url.PathEscape(userID), &response); err != nil {
return nil, err
}
user := parseWorksmobileDirectoryUser(response)
return &user, nil
}
func (c *WorksmobileHTTPClient) UndeleteUser(ctx context.Context, userID string) error {
userID = strings.TrimSpace(userID)
if userID == "" {
return fmt.Errorf("worksmobile user id is required")
}
return c.sendDirectoryJSON(ctx, http.MethodPost, "/v1.0/users/"+url.PathEscape(userID)+"/undelete", nil)
}
func (c *WorksmobileHTTPClient) PatchUser(ctx context.Context, identifier string, payload WorksmobileUserPatchPayload) error {
identifier = strings.TrimSpace(identifier)
if identifier == "" {
@@ -484,6 +581,9 @@ func (c *WorksmobileHTTPClient) getJSON(ctx context.Context, path string, target
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Accept", "application/json")
if err := c.waitForWorksmobileAPI(ctx, req.Method, req.URL); err != nil {
return err
}
resp, err := c.httpClient().Do(req)
if err != nil {
return err
@@ -512,6 +612,9 @@ func (c *WorksmobileHTTPClient) getDirectoryJSON(ctx context.Context, path strin
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Accept", "application/json")
if err := c.waitForWorksmobileAPI(ctx, req.Method, req.URL); err != nil {
return err
}
resp, err := c.httpClient().Do(req)
if err != nil {
return err
@@ -665,6 +768,9 @@ func (c *WorksmobileHTTPClient) requestDirectoryAccessToken(ctx context.Context,
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
if err := c.waitForWorksmobileAPI(ctx, req.Method, req.URL); err != nil {
return "", time.Time{}, err
}
resp, err := c.httpClient().Do(req)
if err != nil {
return "", time.Time{}, err
@@ -729,6 +835,9 @@ func (c *WorksmobileHTTPClient) sendJSONWithToken(ctx context.Context, method st
req.Header.Set("Content-Type", "application/json")
}
if err := c.waitForWorksmobileAPI(ctx, req.Method, req.URL); err != nil {
return err
}
resp, err := c.httpClient().Do(req)
if err != nil {
return err
@@ -801,6 +910,7 @@ type WorksmobileRemoteUser struct {
ExternalID string `json:"externalId"`
UserName string `json:"userName"`
Email string `json:"email"`
AliasEmails []string `json:"aliasEmails,omitempty"`
DisplayName string `json:"displayName"`
CellPhone string `json:"cellPhone,omitempty"`
EmployeeNumber string `json:"employeeNumber,omitempty"`
@@ -817,6 +927,10 @@ type WorksmobileRemoteUser struct {
OrgUnitManagers map[string]*bool `json:"orgUnitManagers,omitempty"`
Organizations []WorksmobileUserOrganization `json:"organizations,omitempty"`
Active bool `json:"active"`
IsAwaiting bool `json:"isAwaiting"`
IsPending bool `json:"isPending"`
IsSuspended bool `json:"isSuspended"`
IsDeleted bool `json:"isDeleted"`
}
type WorksmobileRemoteGroup struct {
@@ -852,18 +966,22 @@ func NewWorksmobileSCIMUserPayload(payload WorksmobileUserPayload) WorksmobileSC
},
}
if strings.TrimSpace(payload.CellPhone) != "" {
result.PhoneNumbers = []WorksmobileSCIMPhoneNumber{{Value: strings.TrimSpace(payload.CellPhone), Primary: true, Type: "mobile"}}
result.PhoneNumbers = []WorksmobileSCIMPhoneNumber{{Value: normalizeWorksmobileOutboundCellPhone(payload.CellPhone), Primary: true, Type: "mobile"}}
}
return result
}
func normalizeWorksmobileOutboundCellPhone(value string) string {
return domain.NormalizePhoneNumber(value)
}
func NewWorksmobileUserPatchPayload(payload WorksmobileUserPayload) WorksmobileUserPatchPayload {
return WorksmobileUserPatchPayload{
DomainID: payload.DomainID,
Email: strings.TrimSpace(payload.Email),
UserExternalKey: strings.TrimSpace(payload.UserExternalKey),
UserName: payload.UserName,
CellPhone: strings.TrimSpace(payload.CellPhone),
CellPhone: normalizeWorksmobileOutboundCellPhone(payload.CellPhone),
EmployeeNumber: strings.TrimSpace(payload.EmployeeNumber),
AliasEmails: payload.AliasEmails,
Locale: strings.TrimSpace(payload.Locale),
@@ -937,6 +1055,7 @@ func parseWorksmobileDirectoryUser(resource map[string]any) WorksmobileRemoteUse
ExternalID: firstStringFromMap(resource, "userExternalKey", "externalKey", "externalId"),
UserName: email,
Email: email,
AliasEmails: stringListFromMap(resource, "aliasEmails"),
DisplayName: parseWorksmobileDirectoryUserName(resource),
CellPhone: firstStringFromMap(resource, "cellPhone", "phoneNumber", "phone", "mobile", "mobilePhone"),
EmployeeNumber: firstStringFromMap(
@@ -954,6 +1073,10 @@ func parseWorksmobileDirectoryUser(resource map[string]any) WorksmobileRemoteUse
if active, ok := resource["active"].(bool); ok {
user.Active = active
}
user.IsAwaiting = boolFromMap(resource, "isAwaiting")
user.IsPending = boolFromMap(resource, "isPending")
user.IsSuspended = boolFromMap(resource, "isSuspended")
user.IsDeleted = boolFromMap(resource, "isDeleted")
primaryOrgUnit := parseWorksmobilePrimaryOrgUnitDetail(resource)
user.PrimaryOrgUnitID = primaryOrgUnit.ID
user.PrimaryOrgUnitName = primaryOrgUnit.Name
@@ -1285,6 +1408,25 @@ func firstStringFromMap(values map[string]any, keys ...string) string {
return ""
}
func stringListFromMap(values map[string]any, key string) []string {
raw, ok := values[key].([]any)
if !ok {
return nil
}
result := make([]string, 0, len(raw))
for _, item := range raw {
value, ok := item.(string)
if !ok {
continue
}
value = strings.TrimSpace(value)
if value != "" {
result = append(result, value)
}
}
return result
}
func boolFromMap(values map[string]any, key string) bool {
value, _ := values[key].(bool)
return value
@@ -1324,6 +1466,42 @@ func (c *WorksmobileHTTPClient) requestURL(path string) (string, error) {
return strings.TrimRight(baseURL, "/") + path, nil
}
func (c *WorksmobileHTTPClient) waitForWorksmobileAPI(ctx context.Context, method string, requestURL *url.URL) error {
if c.RateLimiter == nil {
return nil
}
return c.RateLimiter.Wait(ctx, worksmobileRateLimitKey(method, requestURL))
}
func worksmobileRateLimitKey(method string, requestURL *url.URL) string {
normalizedMethod := strings.ToUpper(strings.TrimSpace(method))
if normalizedMethod == "" {
normalizedMethod = "GET"
}
return normalizedMethod + " " + normalizeWorksmobileRateLimitPath(requestURL)
}
func normalizeWorksmobileRateLimitPath(requestURL *url.URL) string {
if requestURL == nil {
return "/"
}
path := requestURL.EscapedPath()
if path == "" {
path = "/"
}
segments := strings.Split(strings.Trim(path, "/"), "/")
if len(segments) == 1 && segments[0] == "" {
return "/"
}
for i := 1; i < len(segments); i++ {
switch strings.ToLower(segments[i-1]) {
case "users", "orgunits", "groups", "alias-emails":
segments[i] = "{id}"
}
}
return "/" + strings.Join(segments, "/")
}
func (c *WorksmobileHTTPClient) httpClient() *http.Client {
if c.HTTPClient != nil {
return c.HTTPClient

View File

@@ -64,6 +64,28 @@ func TestWorksmobileHTTPClientCreateUserPostsDirectoryAdminPasswordPayload(t *te
require.Len(t, passwordConfig["password"], 16)
}
func TestNewWorksmobileUserPatchPayloadNormalizesMalformedKoreanCellPhone(t *testing.T) {
payload := NewWorksmobileUserPatchPayload(WorksmobileUserPayload{
DomainID: 1001,
Email: "phone-canonical@samaneng.com",
CellPhone: "+82+821062836786",
UserName: WorksmobileUserName{LastName: "Phone Canonical User"},
})
require.Equal(t, "+821062836786", payload.CellPhone)
}
func TestNewWorksmobileSCIMUserPayloadNormalizesMalformedKoreanCellPhone(t *testing.T) {
payload := NewWorksmobileSCIMUserPayload(WorksmobileUserPayload{
Email: "phone-canonical@samaneng.com",
CellPhone: "+82+821062836786",
UserName: WorksmobileUserName{LastName: "Phone Canonical User"},
})
require.Len(t, payload.PhoneNumbers, 1)
require.Equal(t, "+821062836786", payload.PhoneNumbers[0].Value)
}
func TestWorksmobileHTTPClientUpsertUserPatchesOnCreateConflictWithoutPasswordOrPrivateEmail(t *testing.T) {
transport := &captureRoundTripper{
responses: []captureResponse{
@@ -155,6 +177,7 @@ func TestWorksmobileHTTPClientResetUserPasswordPatchesPasswordConfig(t *testing.
passwordConfig := payload["passwordConfig"].(map[string]any)
require.Equal(t, "ADMIN", passwordConfig["passwordCreationType"])
require.Equal(t, "Aa1!Aa1!Aa1!Aa1!", passwordConfig["password"])
require.Equal(t, true, passwordConfig["changePasswordAtNextLogin"])
}
func TestWorksmobileHTTPClientCreateUserRequiresDirectoryToken(t *testing.T) {
@@ -225,6 +248,84 @@ func TestWorksmobileHTTPClientRequestsJWTBearerAccessToken(t *testing.T) {
require.Equal(t, float64(1710003600), payload["exp"])
}
func TestWorksmobileHTTPClientAppliesRateLimitBeforeDirectoryAPICalls(t *testing.T) {
transport := &captureRoundTripper{
statusCode: http.StatusCreated,
body: `{}`,
}
limiter := &captureWorksmobileRateLimiter{}
client := &WorksmobileHTTPClient{
BaseURL: "https://works.example.test",
DirectoryToken: "directory-token-1",
HTTPClient: &http.Client{Transport: transport},
RateLimiter: limiter,
}
err := client.CreateUser(context.Background(), WorksmobileUserPayload{
Email: "tester@samaneng.com",
PasswordConfig: WorksmobilePasswordConfig{PasswordCreationType: "ADMIN", Password: "Aa1!Aa1!Aa1!Aa1!"},
})
require.NoError(t, err)
require.Equal(t, []string{"POST /v1.0/users"}, limiter.keys)
require.Len(t, transport.requests, 1)
}
func TestWorksmobileHTTPClientAppliesRateLimitBeforeOAuthTokenCalls(t *testing.T) {
privateKey := testRSAPrivateKeyPEM(t)
transport := &captureRoundTripper{
statusCode: http.StatusOK,
body: `{"access_token":"directory-token-from-jwt","token_type":"Bearer","expires_in":3600}`,
}
limiter := &captureWorksmobileRateLimiter{}
client := &WorksmobileHTTPClient{
HTTPClient: &http.Client{Transport: transport},
RateLimiter: limiter,
now: func() time.Time { return time.Unix(1710000000, 0) },
OAuthConfig: WorksmobileOAuthConfig{
ClientID: "client-id-1",
ClientSecret: "client-secret-1",
ServiceAccount: "service-account-1",
PrivateKey: privateKey,
Scope: "directory",
TokenURL: "https://auth.example.test/oauth2/v2.0/token",
},
}
_, _, err := client.requestDirectoryAccessToken(context.Background(), time.Unix(1710000000, 0))
require.NoError(t, err)
require.Equal(t, []string{"POST /oauth2/v2.0/token"}, limiter.keys)
require.Len(t, transport.requests, 1)
}
func TestWorksmobileHTTPClientRateLimitKeyNormalizesResourceIDsAndDropsQuery(t *testing.T) {
parsedURL, err := url.Parse("https://works.example.test/v1.0/users/user%40example.com/alias-emails/alias%40example.com?domainId=1")
require.NoError(t, err)
require.Equal(
t,
"POST /v1.0/users/{id}/alias-emails/{id}",
worksmobileRateLimitKey(http.MethodPost, parsedURL),
)
parsedURL, err = url.Parse("https://works.example.test/scim/v2/Users/works-user-1")
require.NoError(t, err)
require.Equal(t, "PATCH /scim/v2/Users/{id}", worksmobileRateLimitKey(http.MethodPatch, parsedURL))
}
func TestNewWorksmobileHTTPClientDoesNotInstallRateLimiterByDefault(t *testing.T) {
client := NewWorksmobileHTTPClientWithAuth("directory-token", "scim-token", WorksmobileOAuthConfig{})
require.Nil(t, client.RateLimiter)
}
func TestNewWorksmobileAPIRateLimiterCreatesLimiterForWorkerUse(t *testing.T) {
limiter := NewWorksmobileAPIRateLimiter(240, time.Minute)
require.NotNil(t, limiter)
}
func TestWorksmobileHTTPClientRequiresConfiguredAPIBaseURL(t *testing.T) {
client := &WorksmobileHTTPClient{
DirectoryToken: "directory-token-1",
@@ -608,6 +709,18 @@ func TestWorksmobileRelayWorkerProcessesUserSuspendAndMarksProcessed(t *testing.
require.Equal(t, []string{"tester@samaneng.com"}, client.suspendedUsers)
}
func TestWorksmobileRelayWorkerProcessOnceRecoversPanic(t *testing.T) {
repo := &fakeWorksmobileOutboxRepo{listReadyPanic: "list ready crashed"}
client := &fakeWorksmobileDirectoryClient{}
worker := NewWorksmobileRelayWorker(repo, client)
err := worker.ProcessOnce(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "worksmobile relay panic")
require.Contains(t, err.Error(), "list ready crashed")
}
func TestWorksmobileRelayWorkerProcessesActiveUserUpsertAndReactivates(t *testing.T) {
repo := &fakeWorksmobileOutboxRepo{
ready: []domain.WorksmobileOutbox{
@@ -736,6 +849,65 @@ func TestWorksmobileRelayWorkerSkipsDispatchWhenJobClaimFails(t *testing.T) {
require.Empty(t, client.createdOrgUnits)
}
func TestWorksmobileRelayWorkerSkipsProcessingWhenLeaderLockIsNotHeld(t *testing.T) {
repo := &fakeWorksmobileOutboxRepo{
ready: []domain.WorksmobileOutbox{
{
ID: "job-1",
ResourceType: domain.WorksmobileResourceUser,
ResourceID: "user-1",
Action: domain.WorksmobileActionUpsert,
Status: domain.WorksmobileOutboxStatusPending,
Payload: worksmobileUserOutboxPayload("root-1", WorksmobileUserPayload{
Email: "tester@samaneng.com",
UserExternalKey: "user-1",
}),
},
},
}
client := &fakeWorksmobileDirectoryClient{}
worker := NewWorksmobileRelayWorker(repo, client)
worker.SetLeaderLock(&fakeWorksmobileRelayLeaderLock{held: false})
err := worker.ProcessOnce(context.Background())
require.NoError(t, err)
require.Zero(t, repo.listReadyCalls)
require.Empty(t, repo.processingIDs)
require.Empty(t, repo.processedIDs)
require.Empty(t, client.createdUsers)
}
func TestWorksmobileRelayWorkerProcessesWhenLeaderLockIsHeld(t *testing.T) {
repo := &fakeWorksmobileOutboxRepo{
ready: []domain.WorksmobileOutbox{
{
ID: "job-1",
ResourceType: domain.WorksmobileResourceUser,
ResourceID: "user-1",
Action: domain.WorksmobileActionUpsert,
Status: domain.WorksmobileOutboxStatusPending,
Payload: worksmobileUserOutboxPayload("root-1", WorksmobileUserPayload{
Email: "tester@samaneng.com",
UserExternalKey: "user-1",
}),
},
},
}
client := &fakeWorksmobileDirectoryClient{}
lock := &fakeWorksmobileRelayLeaderLock{held: true}
worker := NewWorksmobileRelayWorker(repo, client)
worker.SetLeaderLock(lock)
err := worker.ProcessOnce(context.Background())
require.NoError(t, err)
require.Equal(t, 1, lock.ensureCalls)
require.Equal(t, 1, repo.listReadyCalls)
require.Equal(t, []string{"job-1"}, repo.processedIDs)
require.Equal(t, "tester@samaneng.com", client.createdUsers[0].Email)
}
func TestRedactWorksmobileOutboxPayloadsRemovesInitialPasswordFromOverview(t *testing.T) {
jobs := []domain.WorksmobileOutbox{
{
@@ -1179,6 +1351,8 @@ type fakeWorksmobileOutboxRepo struct {
payloadUpdates []domain.JSONMap
deletedPendingTenantRootID string
deletedPendingCount int
listReadyCalls int
listReadyPanic any
markProcessingClaims map[string]bool
processingIDs []string
processedIDs []string
@@ -1224,6 +1398,10 @@ func (f *fakeWorksmobileOutboxRepo) DeletePendingByTenantRoot(ctx context.Contex
}
func (f *fakeWorksmobileOutboxRepo) ListReady(ctx context.Context, limit int) ([]domain.WorksmobileOutbox, error) {
f.listReadyCalls++
if f.listReadyPanic != nil {
panic(f.listReadyPanic)
}
return f.ready, nil
}
@@ -1282,6 +1460,25 @@ type captureResponse struct {
body string
}
type captureWorksmobileRateLimiter struct {
keys []string
}
func (l *captureWorksmobileRateLimiter) Wait(ctx context.Context, key string) error {
l.keys = append(l.keys, key)
return ctx.Err()
}
type fakeWorksmobileRelayLeaderLock struct {
held bool
ensureCalls int
}
func (l *fakeWorksmobileRelayLeaderLock) EnsureLeadership(ctx context.Context) (bool, error) {
l.ensureCalls++
return l.held, ctx.Err()
}
func (t *captureRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
t.request = req
t.requests = append(t.requests, req)

View File

@@ -56,7 +56,7 @@ func TestWorksmobileLiveSamanUsersDirectoryProvisioning(t *testing.T) {
require.NoError(t, outboxRepo.MarkProcessed(ctx, job.ID))
continue
}
item, err := syncService.EnqueueUserSync(ctx, root.ID, user.ID, "")
item, err := syncService.EnqueueUserSync(ctx, root.ID, user.ID, "", "")
require.NoError(t, err)
require.NotEmpty(t, item)
require.NoError(t, outboxRepo.MarkRetry(ctx, job.ID))

View File

@@ -3,6 +3,7 @@ package service
import (
"baron-sso-backend/internal/domain"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
"math/big"
@@ -30,14 +31,14 @@ type WorksmobileOrgUnitPayload struct {
type WorksmobileUserPayload struct {
DomainID int64 `json:"domainId"`
Email string `json:"email"`
UserExternalKey string `json:"userExternalKey"`
UserExternalKey string `json:"userExternalKey,omitempty"`
UserName WorksmobileUserName `json:"userName"`
CellPhone string `json:"cellPhone,omitempty"`
EmployeeNumber string `json:"employeeNumber,omitempty"`
PrivateEmail string `json:"privateEmail,omitempty"`
AliasEmails []string `json:"aliasEmails,omitempty"`
Locale string `json:"locale,omitempty"`
PasswordConfig WorksmobilePasswordConfig `json:"passwordConfig"`
PasswordConfig WorksmobilePasswordConfig `json:"passwordConfig,omitempty"`
Task string `json:"task,omitempty"`
Organizations []WorksmobileUserOrganization `json:"organizations,omitempty"`
}
@@ -47,8 +48,52 @@ type WorksmobileUserName struct {
}
type WorksmobilePasswordConfig struct {
PasswordCreationType string `json:"passwordCreationType"`
Password string `json:"password"`
PasswordCreationType string `json:"passwordCreationType"`
Password string `json:"password"`
ChangePasswordAtNextLogin *bool `json:"changePasswordAtNextLogin,omitempty"`
}
func (c WorksmobilePasswordConfig) IsZero() bool {
return strings.TrimSpace(c.PasswordCreationType) == "" &&
strings.TrimSpace(c.Password) == "" &&
c.ChangePasswordAtNextLogin == nil
}
func (p WorksmobileUserPayload) MarshalJSON() ([]byte, error) {
type payloadJSON struct {
DomainID int64 `json:"domainId"`
Email string `json:"email"`
UserExternalKey string `json:"userExternalKey,omitempty"`
UserName WorksmobileUserName `json:"userName"`
CellPhone string `json:"cellPhone,omitempty"`
EmployeeNumber string `json:"employeeNumber,omitempty"`
PrivateEmail string `json:"privateEmail,omitempty"`
AliasEmails []string `json:"aliasEmails,omitempty"`
Locale string `json:"locale,omitempty"`
PasswordConfig *WorksmobilePasswordConfig `json:"passwordConfig,omitempty"`
Task string `json:"task,omitempty"`
Organizations []WorksmobileUserOrganization `json:"organizations,omitempty"`
}
var passwordConfig *WorksmobilePasswordConfig
if !p.PasswordConfig.IsZero() {
passwordConfig = &p.PasswordConfig
}
return json.Marshal(payloadJSON{
DomainID: p.DomainID,
Email: p.Email,
UserExternalKey: p.UserExternalKey,
UserName: p.UserName,
CellPhone: p.CellPhone,
EmployeeNumber: p.EmployeeNumber,
PrivateEmail: p.PrivateEmail,
AliasEmails: p.AliasEmails,
Locale: p.Locale,
PasswordConfig: passwordConfig,
Task: p.Task,
Organizations: p.Organizations,
})
}
type WorksmobilePasswordResetPayload struct {
@@ -184,15 +229,11 @@ func BuildWorksmobileUserPayloadForDomainTenants(user domain.User, tenant domain
Email: strings.TrimSpace(user.Email),
UserExternalKey: user.ID,
UserName: WorksmobileUserName{LastName: strings.TrimSpace(user.Name)},
CellPhone: strings.TrimSpace(user.Phone),
CellPhone: domain.NormalizePhoneNumber(user.Phone),
EmployeeNumber: employeeNumber,
Locale: "ko_KR",
PasswordConfig: WorksmobilePasswordConfig{
PasswordCreationType: "ADMIN",
Password: GenerateWorksmobileInitialPassword(),
},
Task: task,
Organizations: organizations,
Task: task,
Organizations: organizations,
}
payload.AliasEmails = BuildWorksmobileAliasEmails(user, tenant)
return payload, nil
@@ -205,12 +246,20 @@ type worksmobileAppointment struct {
HasManager bool
JobTitle string
PositionID string
Source string
}
func buildWorksmobileUserOrganizations(user domain.User, tenant domain.Tenant, tenantByID map[string]domain.Tenant, rootConfig domain.JSONMap) ([]WorksmobileUserOrganization, string, error) {
appointments := worksmobileAppointmentsFromMetadata(user.Metadata)
if len(appointments) == 0 {
appointments = []worksmobileAppointment{{TenantID: tenant.ID, IsPrimary: true}}
} else if !worksmobileAppointmentsContainTenant(appointments, tenant.ID) && !worksmobileAppointmentsHavePrimary(appointments) {
appointments = append([]worksmobileAppointment{{
TenantID: tenant.ID,
IsPrimary: true,
JobTitle: strings.TrimSpace(user.JobTitle),
PositionID: metadataString(user.Metadata, "worksmobilePositionId", "positionId", "position_id"),
}}, appointments...)
}
accountDomainTenant := worksmobileAccountDomainTenantFromEmail(user.Email, tenant, tenantByID)
accountDomainEnvKey := worksmobileTenantDomainIDEnvKey(accountDomainTenant)
@@ -235,6 +284,17 @@ func buildWorksmobileUserOrganizations(user domain.User, tenant domain.Tenant, t
if !ok {
continue
}
if worksmobileShouldSkipEmailDomainRootAppointment(appointment, appointmentTenant, appointments, tenantByID) {
seen[appointment.TenantID] = true
continue
}
if isWorksmobileDomainRootTenant(appointmentTenant) {
if appointment.IsPrimary && strings.TrimSpace(appointment.JobTitle) != "" && task == "" {
task = strings.TrimSpace(appointment.JobTitle)
}
seen[appointment.TenantID] = true
continue
}
if err := ValidateWorksmobileExternalKey(appointmentTenant.ID); err != nil {
return nil, "", err
}
@@ -276,7 +336,7 @@ func buildWorksmobileUserOrganizations(user domain.User, tenant domain.Tenant, t
seen[appointment.TenantID] = true
}
if len(organizations) == 0 {
return nil, "", errors.New("no valid worksmobile organization")
return nil, task, nil
}
if !worksmobileOrganizationsHavePrimary(organizations) {
organizations[0].Primary = true
@@ -288,6 +348,28 @@ func buildWorksmobileUserOrganizations(user domain.User, tenant domain.Tenant, t
return organizations, task, nil
}
func worksmobileAppointmentsContainTenant(appointments []worksmobileAppointment, tenantID string) bool {
tenantID = strings.TrimSpace(tenantID)
if tenantID == "" {
return false
}
for _, appointment := range appointments {
if strings.TrimSpace(appointment.TenantID) == tenantID {
return true
}
}
return false
}
func worksmobileAppointmentsHavePrimary(appointments []worksmobileAppointment) bool {
for _, appointment := range appointments {
if appointment.IsPrimary {
return true
}
}
return false
}
func worksmobileAppointmentsContainDomain(appointments []worksmobileAppointment, tenantByID map[string]domain.Tenant, envKey string) bool {
for _, appointment := range appointments {
tenant, ok := tenantByID[appointment.TenantID]
@@ -302,6 +384,26 @@ func worksmobileAppointmentsContainDomain(appointments []worksmobileAppointment,
return false
}
func worksmobileShouldSkipEmailDomainRootAppointment(appointment worksmobileAppointment, tenant domain.Tenant, appointments []worksmobileAppointment, tenantByID map[string]domain.Tenant) bool {
if strings.TrimSpace(appointment.Source) != "email_domain" || !isWorksmobileDomainRootTenant(tenant) {
return false
}
envKey := worksmobileTenantDomainIDEnvKey(tenant)
for _, candidate := range appointments {
if strings.TrimSpace(candidate.TenantID) == "" || strings.TrimSpace(candidate.TenantID) == tenant.ID {
continue
}
candidateTenant, ok := tenantByID[candidate.TenantID]
if !ok || isWorksmobileDomainRootTenant(candidateTenant) {
continue
}
if worksmobileTenantDomainIDEnvKey(worksmobileDomainClassificationTenant(candidateTenant, tenantByID)) == envKey {
return true
}
}
return false
}
func worksmobileOrganizationsHavePrimary(organizations []WorksmobileUserOrganization) bool {
for _, organization := range organizations {
if organization.Primary {
@@ -327,6 +429,7 @@ func worksmobileAppointmentsFromMetadata(metadata domain.JSONMap) []worksmobileA
IsPrimary: metadataBool(domain.JSONMap(item), "isPrimary", "primary"),
JobTitle: metadataString(domain.JSONMap(item), "jobTitle", "job_title", "task"),
PositionID: metadataString(domain.JSONMap(item), "worksmobilePositionId", "positionId", "position_id"),
Source: metadataString(domain.JSONMap(item), "assignmentSource", "source"),
}
if isManager, ok := metadataOptionalBool(domain.JSONMap(item), "isManager", "lead", "isLead"); ok {
appointment.IsManager = isManager
@@ -416,7 +519,7 @@ func ValidateWorksmobileAliasEmails(primaryEmail string, aliasEmails []string, e
func GenerateWorksmobileInitialPassword() string {
digits := "0123456789"
letters := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
symbols := "!@#$%^&*()-_=+[]{}"
symbols := "!@#$%"
all := digits + letters + symbols
password := []byte{

View File

@@ -84,6 +84,7 @@ func TestNormalizeRootChildWorksmobileOrgUnitParentClearsCrossDomainParent(t *te
func TestBuildWorksmobileUserPayloadMapsBaronUserAndPrimaryTenant(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
rootTenantID := "11111111-1111-1111-1111-111111111111"
tenantID := "33333333-3333-3333-3333-333333333333"
user := domain.User{
ID: "44444444-4444-4444-4444-444444444444",
@@ -98,9 +99,17 @@ func TestBuildWorksmobileUserPayloadMapsBaronUserAndPrimaryTenant(t *testing.T)
},
}
tenant := domain.Tenant{
ID: tenantID,
ID: tenantID,
Slug: "sales",
Name: "Sales",
Type: domain.TenantTypeOrganization,
ParentID: &rootTenantID,
}
rootTenant := domain.Tenant{
ID: rootTenantID,
Slug: "saman",
Name: "Saman",
Name: "삼안",
Type: domain.TenantTypeCompany,
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
}
rootConfig := domain.JSONMap{
@@ -111,7 +120,15 @@ func TestBuildWorksmobileUserPayloadMapsBaronUserAndPrimaryTenant(t *testing.T)
},
}
payload, err := BuildWorksmobileUserPayload(user, tenant, rootConfig)
payload, err := BuildWorksmobileUserPayloadForDomainTenants(
user,
tenant,
map[string]domain.Tenant{
rootTenantID: rootTenant,
tenantID: tenant,
},
rootConfig,
)
require.NoError(t, err)
require.Equal(t, int64(1001), payload.DomainID)
@@ -124,17 +141,76 @@ func TestBuildWorksmobileUserPayloadMapsBaronUserAndPrimaryTenant(t *testing.T)
require.Empty(t, payload.PrivateEmail)
require.Empty(t, payload.AliasEmails)
require.Equal(t, "ko_KR", payload.Locale)
require.Equal(t, "ADMIN", payload.PasswordConfig.PasswordCreationType)
require.Len(t, payload.PasswordConfig.Password, 16)
require.True(t, containsAny(payload.PasswordConfig.Password, "0123456789"))
require.True(t, containsAny(payload.PasswordConfig.Password, "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"))
require.True(t, containsAny(payload.PasswordConfig.Password, "!@#$%^&*()-_=+[]{}"))
require.Empty(t, payload.PasswordConfig.PasswordCreationType)
require.Empty(t, payload.PasswordConfig.Password)
require.Len(t, payload.Organizations, 1)
require.Equal(t, int64(1001), payload.Organizations[0].DomainID)
require.True(t, payload.Organizations[0].Primary)
require.Equal(t, "externalKey:"+tenantID, payload.Organizations[0].OrgUnits[0].OrgUnitID)
}
func TestBuildWorksmobileUserPayloadDeduplicatesKoreanCountryCodeInCellPhone(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
tenantID := "33333333-3333-3333-3333-333333333333"
user := domain.User{
ID: "44444444-4444-4444-4444-444444444444",
Email: "john1@samaneng.com",
Name: "John Doe",
Phone: "+82 +821091917771",
TenantID: &tenantID,
}
tenant := domain.Tenant{
ID: tenantID,
Slug: "saman",
Name: "Saman",
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
}
payload, err := BuildWorksmobileUserPayload(user, tenant, nil)
require.NoError(t, err)
require.Equal(t, "+821091917771", payload.CellPhone)
}
func TestWorksmobileUserPayloadJSONOmitsEmptyPasswordConfig(t *testing.T) {
data, err := json.Marshal(WorksmobileUserPayload{
DomainID: 1001,
Email: "target@samaneng.com",
UserExternalKey: "user-1",
UserName: WorksmobileUserName{LastName: "Target"},
})
require.NoError(t, err)
require.NotContains(t, string(data), "passwordConfig")
}
func TestBuildWorksmobileUserPayloadOmitsOrganizationsForSamanRootTenant(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
tenantID := "33333333-3333-3333-3333-333333333333"
user := domain.User{
ID: "44444444-4444-4444-4444-444444444444",
Email: "root-user@samaneng.com",
Name: "Root User",
JobTitle: "Advisor",
TenantID: &tenantID,
}
tenant := domain.Tenant{
ID: tenantID,
Slug: "saman",
Name: "삼안",
Type: domain.TenantTypeCompany,
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
}
payload, err := BuildWorksmobileUserPayload(user, tenant, nil)
require.NoError(t, err)
require.Equal(t, int64(1001), payload.DomainID)
require.Equal(t, "root-user@samaneng.com", payload.Email)
require.Equal(t, "Advisor", payload.Task)
require.Empty(t, payload.Organizations)
}
func TestBuildWorksmobileUserPayloadNormalizesLegacyCharacterMapEmployeeID(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
tenantID := "33333333-3333-3333-3333-333333333333"
@@ -168,6 +244,8 @@ func TestBuildWorksmobileUserPayloadNormalizesLegacyCharacterMapEmployeeID(t *te
func TestBuildWorksmobileUserPayloadMapsAdditionalAppointmentsToOrgUnits(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
t.Setenv("HANMAC_DOMAIN_ID", "1002")
samanRootID := "11111111-1111-1111-1111-111111111111"
hanmacRootID := "22222222-2222-2222-2222-222222222222"
primaryTenantID := "33333333-3333-3333-3333-333333333333"
secondaryTenantID := "55555555-5555-5555-5555-555555555555"
user := domain.User{
@@ -195,23 +273,41 @@ func TestBuildWorksmobileUserPayloadMapsAdditionalAppointmentsToOrgUnits(t *test
},
},
}
primaryTenant := domain.Tenant{
ID: primaryTenantID,
samanRoot := domain.Tenant{
ID: samanRootID,
Slug: "saman",
Name: "Saman",
Name: "삼안",
Type: domain.TenantTypeCompany,
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
}
secondaryTenant := domain.Tenant{
ID: secondaryTenantID,
hanmacRoot := domain.Tenant{
ID: hanmacRootID,
Slug: "hanmac",
Name: "Hanmac",
Name: "한맥기술",
Type: domain.TenantTypeCompany,
Domains: []domain.TenantDomain{{Domain: "hanmaceng.co.kr"}},
}
primaryTenant := domain.Tenant{
ID: primaryTenantID,
Slug: "saman-sales",
Name: "Saman Sales",
Type: domain.TenantTypeOrganization,
ParentID: &samanRootID,
}
secondaryTenant := domain.Tenant{
ID: secondaryTenantID,
Slug: "hanmac-sales",
Name: "Hanmac Sales",
Type: domain.TenantTypeOrganization,
ParentID: &hanmacRootID,
}
payload, err := BuildWorksmobileUserPayloadForDomainTenants(
user,
primaryTenant,
map[string]domain.Tenant{
samanRootID: samanRoot,
hanmacRootID: hanmacRoot,
primaryTenantID: primaryTenant,
secondaryTenantID: secondaryTenant,
},
@@ -234,9 +330,66 @@ func TestBuildWorksmobileUserPayloadMapsAdditionalAppointmentsToOrgUnits(t *test
require.True(t, *payload.Organizations[1].OrgUnits[0].IsManager)
}
func TestBuildWorksmobileUserPayloadKeepsPrimaryTenantWhenEmailDomainAppointmentExists(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
rootTenantID := "9caf62e1-297d-4e8f-870b-61780998bbeb"
primaryTenantID := "1edc196d-020c-4519-9ec4-3d23b99076e6"
user := domain.User{
ID: "64231465-d5c0-4085-b4a2-603b90834f86",
Email: "evenlee@samaneng.com",
Name: "이용운",
JobTitle: "부사장",
TenantID: &primaryTenantID,
Metadata: domain.JSONMap{
"additionalAppointments": []any{
map[string]any{
"tenantId": rootTenantID,
"tenantSlug": "saman",
"tenantName": "삼안",
"assignmentSource": "email_domain",
"sourceDomain": "samaneng.com",
},
},
},
}
rootTenant := domain.Tenant{
ID: rootTenantID,
Slug: "saman",
Name: "삼안",
Type: domain.TenantTypeCompany,
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
}
primaryTenant := domain.Tenant{
ID: primaryTenantID,
Slug: "asset-management",
Name: "자산관리",
Type: domain.TenantTypeOrganization,
ParentID: &rootTenantID,
}
payload, err := BuildWorksmobileUserPayloadForDomainTenants(
user,
primaryTenant,
map[string]domain.Tenant{
rootTenantID: rootTenant,
primaryTenantID: primaryTenant,
},
nil,
)
require.NoError(t, err)
require.Len(t, payload.Organizations, 1)
require.Equal(t, int64(1001), payload.Organizations[0].DomainID)
require.True(t, payload.Organizations[0].Primary)
require.Len(t, payload.Organizations[0].OrgUnits, 1)
require.Equal(t, "externalKey:"+primaryTenantID, payload.Organizations[0].OrgUnits[0].OrgUnitID)
require.True(t, payload.Organizations[0].OrgUnits[0].Primary)
}
func TestBuildWorksmobileUserPayloadKeepsFirstAffiliationPrimaryWhenBaronRepresentativeIsGPDTDC(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
t.Setenv("GPDTDC_DOMAIN_ID", "1003")
samanRootID := "11111111-1111-1111-1111-111111111111"
gpdtdcID := "5530ca6e-c5e6-4bf0-84d6-76c6a8fb70ee"
firstTenantID := "33333333-3333-3333-3333-333333333333"
secondTenantID := "55555555-5555-5555-5555-555555555555"
@@ -265,12 +418,20 @@ func TestBuildWorksmobileUserPayloadKeepsFirstAffiliationPrimaryWhenBaronReprese
Slug: "gpdtdc",
Name: "총괄기획&기술개발센터",
}
firstTenant := domain.Tenant{
ID: firstTenantID,
Slug: "rnd-saman",
Name: "삼안기술개발센터",
samanRoot := domain.Tenant{
ID: samanRootID,
Slug: "saman",
Name: "삼안",
Type: domain.TenantTypeCompany,
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
}
firstTenant := domain.Tenant{
ID: firstTenantID,
Slug: "rnd-center",
Name: "삼안기술개발센터",
Type: domain.TenantTypeOrganization,
ParentID: &samanRootID,
}
secondTenant := domain.Tenant{
ID: secondTenantID,
Slug: "tdc",
@@ -282,6 +443,7 @@ func TestBuildWorksmobileUserPayloadKeepsFirstAffiliationPrimaryWhenBaronReprese
user,
gpdtdcTenant,
map[string]domain.Tenant{
samanRootID: samanRoot,
gpdtdcID: gpdtdcTenant,
firstTenantID: firstTenant,
secondTenantID: secondTenant,
@@ -354,17 +516,12 @@ func TestBuildWorksmobileUserPayloadUsesEmailDomainForAccountDomainWhenPrimaryOr
require.NoError(t, err)
require.Equal(t, int64(1001), payload.DomainID)
require.Len(t, payload.Organizations, 2)
require.Equal(t, int64(1001), payload.Organizations[0].DomainID)
require.Len(t, payload.Organizations, 1)
require.Equal(t, int64(1003), payload.Organizations[0].DomainID)
require.True(t, payload.Organizations[0].Primary)
require.Equal(t, "dhlee@samaneng.com", payload.Organizations[0].Email)
require.Equal(t, "externalKey:"+samanID, payload.Organizations[0].OrgUnits[0].OrgUnitID)
require.Equal(t, "dhlee@baroncs.co.kr", payload.Organizations[0].Email)
require.Equal(t, "externalKey:"+leafTenantID, payload.Organizations[0].OrgUnits[0].OrgUnitID)
require.True(t, payload.Organizations[0].OrgUnits[0].Primary)
require.Equal(t, int64(1003), payload.Organizations[1].DomainID)
require.False(t, payload.Organizations[1].Primary)
require.Equal(t, "dhlee@baroncs.co.kr", payload.Organizations[1].Email)
require.Equal(t, "externalKey:"+leafTenantID, payload.Organizations[1].OrgUnits[0].OrgUnitID)
require.True(t, payload.Organizations[1].OrgUnits[0].Primary)
}
func TestWorksmobileUserPayloadJSONIncludesFalsePrimaryFields(t *testing.T) {

View File

@@ -0,0 +1,79 @@
package service
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"os"
"time"
"github.com/go-redis/redis/v8"
)
const (
worksmobileRelayLeaderLockKey = "baron:worksmobile:relay:leader"
worksmobileRelayLeaderLockTTL = 30 * time.Second
)
const worksmobileRelayLeaderRenewScript = `
if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("EXPIRE", KEYS[1], ARGV[2])
end
return 0
`
type WorksmobileRedisRelayLeaderLock struct {
client *redis.Client
key string
ttl time.Duration
ownerID string
}
func NewWorksmobileRedisRelayLeaderLock(redisService *RedisService) *WorksmobileRedisRelayLeaderLock {
if redisService == nil || redisService.Client == nil {
return nil
}
return &WorksmobileRedisRelayLeaderLock{
client: redisService.Client,
key: worksmobileRelayLeaderLockKey,
ttl: worksmobileRelayLeaderLockTTL,
ownerID: newWorksmobileRelayLeaderOwnerID(),
}
}
func (l *WorksmobileRedisRelayLeaderLock) EnsureLeadership(ctx context.Context) (bool, error) {
if l == nil || l.client == nil {
return true, nil
}
acquired, err := l.client.SetNX(ctx, l.key, l.ownerID, l.ttl).Result()
if err != nil {
return false, err
}
if acquired {
return true, nil
}
ttlSeconds := int64(l.ttl / time.Second)
if ttlSeconds <= 0 {
ttlSeconds = 30
}
result, err := l.client.Eval(ctx, worksmobileRelayLeaderRenewScript, []string{l.key}, l.ownerID, ttlSeconds).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func newWorksmobileRelayLeaderOwnerID() string {
hostname, _ := os.Hostname()
if hostname == "" {
hostname = "unknown-host"
}
randomBytes := make([]byte, 8)
if _, err := rand.Read(randomBytes); err != nil {
return fmt.Sprintf("%s:%d:%d", hostname, os.Getpid(), time.Now().UnixNano())
}
return fmt.Sprintf("%s:%d:%s", hostname, os.Getpid(), hex.EncodeToString(randomBytes))
}

View File

@@ -6,6 +6,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"sort"
"strings"
@@ -15,10 +16,15 @@ import (
type WorksmobileRelayWorker struct {
repo repository.WorksmobileOutboxRepository
client WorksmobileDirectoryClient
leaderLock WorksmobileRelayLeaderLock
interval time.Duration
batchLimit int
}
type WorksmobileRelayLeaderLock interface {
EnsureLeadership(ctx context.Context) (bool, error)
}
func NewWorksmobileRelayWorker(repo repository.WorksmobileOutboxRepository, client WorksmobileDirectoryClient) *WorksmobileRelayWorker {
return &WorksmobileRelayWorker{
repo: repo,
@@ -28,6 +34,17 @@ func NewWorksmobileRelayWorker(repo repository.WorksmobileOutboxRepository, clie
}
}
func (w *WorksmobileRelayWorker) SetLeaderLock(lock WorksmobileRelayLeaderLock) {
w.leaderLock = lock
}
func (w *WorksmobileRelayWorker) SetBatchLimit(limit int) {
if limit <= 0 {
return
}
w.batchLimit = limit
}
func (w *WorksmobileRelayWorker) Start(ctx context.Context) {
if w.repo == nil || w.client == nil {
slog.Warn("Worksmobile relay worker disabled")
@@ -49,7 +66,23 @@ func (w *WorksmobileRelayWorker) Start(ctx context.Context) {
}
}
func (w *WorksmobileRelayWorker) ProcessOnce(ctx context.Context) error {
func (w *WorksmobileRelayWorker) ProcessOnce(ctx context.Context) (err error) {
defer func() {
if recovered := recover(); recovered != nil {
err = fmt.Errorf("worksmobile relay panic: %v", recovered)
}
}()
if w.leaderLock != nil {
isLeader, err := w.leaderLock.EnsureLeadership(ctx)
if err != nil {
return err
}
if !isLeader {
return nil
}
}
jobs, err := w.repo.ListReady(ctx, w.batchLimit)
if err != nil {
return err
@@ -109,15 +142,20 @@ func (w *WorksmobileRelayWorker) dispatch(ctx context.Context, job domain.Worksm
aliasEmails := append([]string(nil), payload.AliasEmails...)
payload.AliasEmails = nil
if err := w.client.UpsertUser(ctx, payload); err != nil {
return err
return fmt.Errorf("worksmobile user upsert failed: %w", err)
}
for _, aliasEmail := range aliasEmails {
if err := w.client.AddUserAliasEmail(ctx, payload.Email, aliasEmail); err != nil {
return err
return fmt.Errorf("worksmobile user alias add failed: %w", err)
}
}
if stringValue(job.Payload["baronStatus"]) == domain.UserStatusActive {
return w.client.SetUserActive(ctx, worksmobileOutboxUserIdentifier(job), true)
if err := w.client.SetUserActive(ctx, worksmobileOutboxUserIdentifier(job), true); err != nil {
if isWorksmobileSCIMTokenNotConfiguredError(err) {
return nil
}
return fmt.Errorf("worksmobile user set active failed: %w", err)
}
}
return nil
case domain.WorksmobileActionDelete:
@@ -142,6 +180,10 @@ func (w *WorksmobileRelayWorker) dispatch(ctx context.Context, job domain.Worksm
}
}
func isWorksmobileSCIMTokenNotConfiguredError(err error) bool {
return err != nil && strings.Contains(err.Error(), "worksmobile scim token is not configured")
}
func sortWorksmobileReadyJobs(jobs []domain.WorksmobileOutbox) []domain.WorksmobileOutbox {
sorted := append([]domain.WorksmobileOutbox(nil), jobs...)
depthByID := worksmobileOrgUnitDepths(sorted)

View File

@@ -31,7 +31,7 @@ type WorksmobileAdminService interface {
EnqueueBackfillDryRun(ctx context.Context, tenantID string) (WorksmobileBackfillDryRun, error)
EnqueueOrgUnitSync(ctx context.Context, tenantID, orgUnitID string) (*domain.WorksmobileOutbox, error)
EnqueueOrgUnitDelete(ctx context.Context, tenantID, worksmobileOrgUnitID string) (*domain.WorksmobileOutbox, error)
EnqueueUserSync(ctx context.Context, tenantID, userID, credentialBatchID string) (*domain.WorksmobileOutbox, error)
EnqueueUserSync(ctx context.Context, tenantID, userID, credentialBatchID, initialPassword string) (*domain.WorksmobileOutbox, error)
EnqueueUserPasswordReset(ctx context.Context, tenantID, userID, credentialBatchID string) (*domain.WorksmobileOutbox, error)
RetryJob(ctx context.Context, tenantID, jobID string) (*domain.WorksmobileOutbox, error)
DeletePendingJobs(ctx context.Context, tenantID string) (WorksmobilePendingJobDeleteResult, error)
@@ -510,7 +510,7 @@ func (s *worksmobileSyncService) EnqueueOrgUnitDelete(ctx context.Context, tenan
return item, nil
}
func (s *worksmobileSyncService) EnqueueUserSync(ctx context.Context, tenantID, userID, credentialBatchID string) (*domain.WorksmobileOutbox, error) {
func (s *worksmobileSyncService) EnqueueUserSync(ctx context.Context, tenantID, userID, credentialBatchID, initialPassword string) (*domain.WorksmobileOutbox, error) {
root, err := s.hanmacRoot(ctx, tenantID)
if err != nil {
return nil, err
@@ -556,6 +556,13 @@ func (s *worksmobileSyncService) EnqueueUserSync(ctx context.Context, tenantID,
if err != nil {
return nil, err
}
initialPassword = strings.TrimSpace(initialPassword)
if initialPassword != "" {
payload.PasswordConfig = WorksmobilePasswordConfig{
PasswordCreationType: "ADMIN",
Password: initialPassword,
}
}
if err := s.validateUserAliasLocalParts(ctx, root, *user, payload); err != nil {
return nil, err
}
@@ -1167,10 +1174,12 @@ func normalizeWorksmobileOrgUnitParent(payload WorksmobileOrgUnitPayload, tenant
func worksmobileUserOutboxPayload(rootID string, payload WorksmobileUserPayload, statuses ...string) domain.JSONMap {
outboxPayload := domain.JSONMap{
"request": payload,
"tenantRootId": rootID,
"loginEmail": payload.Email,
"initialPassword": payload.PasswordConfig.Password,
"request": payload,
"tenantRootId": rootID,
"loginEmail": payload.Email,
}
if password := strings.TrimSpace(payload.PasswordConfig.Password); password != "" {
outboxPayload["initialPassword"] = password
}
if len(statuses) > 0 {
if status := strings.TrimSpace(statuses[0]); status != "" {
@@ -1428,7 +1437,7 @@ func compareWorksmobileUsers(localUsers []domain.User, remoteUsers []Worksmobile
excludedLocalIDs := map[string]bool{}
result := make([]WorksmobileComparisonItem, 0)
for _, user := range localUsers {
if !domain.IsWorksProvisionedUserStatus(user.Status) {
if user.DeletedAt.Valid || !domain.IsWorksProvisionedUserStatus(user.Status) {
excludedLocalIDs[user.ID] = true
if remote, ok := remoteByExternalID[user.ID]; ok {
matchedRemoteIDs[remote.ID] = true
@@ -1556,12 +1565,6 @@ func worksmobileUserNeedsUpdate(user domain.User, remote WorksmobileRemoteUser,
if worksmobileUserEmployeeNumberNeedsUpdate(user, remote) {
return true
}
if worksmobileUserOrganizationsNeedUpdate(user, remote, localTenants) {
return true
}
if worksmobileUserManagerNeedsUpdate(user, remote) {
return true
}
return false
}
@@ -1571,22 +1574,24 @@ func worksmobileUserPhoneNeedsUpdate(user domain.User, remote WorksmobileRemoteU
if localPhone == "" && remotePhone == "" {
return false
}
return localPhone != remotePhone
if localPhone != remotePhone {
return true
}
return localPhone != "" && worksmobilePhoneHasDuplicateKoreanCountryCode(remote.CellPhone)
}
func normalizeWorksmobilePhoneForCompare(value string) string {
normalized := strings.TrimSpace(value)
normalized = strings.NewReplacer("-", "", " ", "", "(", "", ")", "").Replace(normalized)
if normalized == "" {
return ""
return domain.NormalizePhoneNumber(value)
}
func worksmobilePhoneHasDuplicateKoreanCountryCode(value string) bool {
digits := strings.Builder{}
for _, r := range strings.TrimSpace(value) {
if r >= '0' && r <= '9' {
digits.WriteRune(r)
}
}
if strings.HasPrefix(normalized, "010") {
return "+82" + normalized[1:]
}
if strings.HasPrefix(normalized, "82") {
return "+" + normalized
}
return normalized
return strings.HasPrefix(digits.String(), "8282")
}
func worksmobileUserEmployeeNumberNeedsUpdate(user domain.User, remote WorksmobileRemoteUser) bool {

View File

@@ -50,7 +50,7 @@ func TestWorksmobileSyncServiceRejectsAliasEmailAlreadyUsedByOtherUser(t *testin
nil,
)
item, err := service.EnqueueUserSync(context.Background(), rootID, target.ID, "")
item, err := service.EnqueueUserSync(context.Background(), rootID, target.ID, "", "")
require.Nil(t, item)
require.Error(t, err)
@@ -90,7 +90,7 @@ func TestWorksmobileSyncServiceEnqueuesSuspendedUserStatusWithOrganizations(t *t
nil,
)
item, err := service.EnqueueUserSync(context.Background(), rootID, target.ID, "")
item, err := service.EnqueueUserSync(context.Background(), rootID, target.ID, "", "")
require.NoError(t, err)
require.NotNil(t, item)
@@ -135,7 +135,7 @@ func TestWorksmobileSyncServiceEnqueuesUserCredentialBatchID(t *testing.T) {
nil,
)
item, err := service.EnqueueUserSync(context.Background(), rootID, target.ID, "batch-1")
item, err := service.EnqueueUserSync(context.Background(), rootID, target.ID, "batch-1", "InputPass1!")
require.NoError(t, err)
require.NotNil(t, item)
@@ -144,6 +144,53 @@ func TestWorksmobileSyncServiceEnqueuesUserCredentialBatchID(t *testing.T) {
require.NotEmpty(t, outboxRepo.created[0].Payload["credentialBatchCreatedAt"])
require.Equal(t, "Target", outboxRepo.created[0].Payload["displayName"])
require.Equal(t, "Saman", outboxRepo.created[0].Payload["primaryLeafOrgName"])
require.Equal(t, "InputPass1!", outboxRepo.created[0].Payload["initialPassword"])
request, ok := outboxRepo.created[0].Payload["request"].(WorksmobileUserPayload)
require.True(t, ok)
require.Equal(t, "ADMIN", request.PasswordConfig.PasswordCreationType)
require.Equal(t, "InputPass1!", request.PasswordConfig.Password)
}
func TestWorksmobileSyncServiceDoesNotAutoGenerateInitialPassword(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
rootID := "root-tenant"
tenantID := "saman-tenant"
root := domain.Tenant{
ID: rootID,
Slug: HanmacFamilyTenantSlug,
Name: "Hanmac Family",
}
tenant := domain.Tenant{
ID: tenantID,
Slug: "saman",
Name: "Saman",
Type: domain.TenantTypeCompany,
ParentID: &rootID,
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
}
target := domain.User{
ID: "target-user",
Email: "target@samaneng.com",
Name: "Target",
Status: domain.UserStatusActive,
TenantID: &tenantID,
}
outboxRepo := &fakeWorksmobileOutboxRepo{}
service := NewWorksmobileSyncService(
&fakeWorksmobileTenantService{tenants: map[string]domain.Tenant{rootID: root, tenantID: tenant}, list: []domain.Tenant{root, tenant}},
&fakeWorksmobileUserRepo{byID: map[string]domain.User{target.ID: target}, byTenant: []domain.User{target}},
outboxRepo,
nil,
)
item, err := service.EnqueueUserSync(context.Background(), rootID, target.ID, "batch-1", "")
require.NoError(t, err)
require.NotNil(t, item)
require.NotContains(t, outboxRepo.created[0].Payload, "initialPassword")
request, ok := outboxRepo.created[0].Payload["request"].(WorksmobileUserPayload)
require.True(t, ok)
require.Empty(t, request.PasswordConfig.Password)
}
func TestWorksmobileSyncServiceEnqueuesUserPasswordResetCredentialBatch(t *testing.T) {
@@ -382,7 +429,7 @@ func TestWorksmobileSyncServiceDeprovisionsArchivedUser(t *testing.T) {
nil,
)
item, err := service.EnqueueUserSync(context.Background(), rootID, target.ID, "")
item, err := service.EnqueueUserSync(context.Background(), rootID, target.ID, "", "")
require.NoError(t, err)
require.NotNil(t, item)
@@ -1548,6 +1595,48 @@ func TestWorksmobileSyncServiceSkipsArchivedUsersInComparison(t *testing.T) {
require.Empty(t, comparison.Users)
}
func TestWorksmobileSyncServiceSkipsSoftDeletedUsersInComparison(t *testing.T) {
rootID := "root-tenant"
companyID := "company-tenant"
root := domain.Tenant{
ID: rootID,
Slug: HanmacFamilyTenantSlug,
Name: "한맥가족",
}
company := domain.Tenant{
ID: companyID,
Name: "계열사",
Type: domain.TenantTypeCompany,
ParentID: &rootID,
}
deleted := domain.User{
ID: "deleted-user",
Email: "deleted@samaneng.com",
Name: "Deleted",
TenantID: &companyID,
Status: domain.UserStatusActive,
DeletedAt: gorm.DeletedAt{
Time: time.Now(),
Valid: true,
},
}
service := NewWorksmobileSyncService(
&fakeWorksmobileTenantService{tenants: map[string]domain.Tenant{rootID: root, companyID: company}, list: []domain.Tenant{root, company}},
&fakeWorksmobileUserRepo{byTenant: []domain.User{deleted}},
&fakeWorksmobileOutboxRepo{},
&fakeWorksmobileDirectoryClient{users: []WorksmobileRemoteUser{{
ID: "works-deleted",
ExternalID: deleted.ID,
Email: deleted.Email,
}}},
)
comparison, err := service.GetComparison(context.Background(), rootID, true)
require.NoError(t, err)
require.Empty(t, comparison.Users)
}
func TestWorksmobileSyncServiceBackfillDryRunSkipsArchivedUsers(t *testing.T) {
rootID := "root-tenant"
companyID := "company-tenant"
@@ -1760,14 +1849,14 @@ func TestWorksmobileSyncServiceSkipsExcludedTenantAndUserEventSync(t *testing.T)
require.NoError(t, service.EnqueueTenantUpsertIfInScope(context.Background(), excludedOrg))
require.NoError(t, service.EnqueueTenantDeleteIfInScope(context.Background(), excludedOrg))
require.NoError(t, service.EnqueueUserUpsertIfInScope(context.Background(), user))
item, err := service.EnqueueUserSync(context.Background(), rootID, user.ID, "")
item, err := service.EnqueueUserSync(context.Background(), rootID, user.ID, "", "")
require.Nil(t, item)
require.ErrorContains(t, err, "excluded from Worksmobile sync")
require.Empty(t, outboxRepo.created)
}
func TestCompareWorksmobileUsersMarksManagerChangeNeedsUpdate(t *testing.T) {
func TestCompareWorksmobileUsersIgnoresManagerChange(t *testing.T) {
tenantID := "tenant-leaf"
user := domain.User{
ID: "user-manager",
@@ -1803,10 +1892,10 @@ func TestCompareWorksmobileUsersMarksManagerChangeNeedsUpdate(t *testing.T) {
)
require.Len(t, items, 1)
require.Equal(t, "needs_update", items[0].Status)
require.Equal(t, "matched", items[0].Status)
}
func TestCompareWorksmobileUsersMarksSecondaryManagerChangeNeedsUpdate(t *testing.T) {
func TestCompareWorksmobileUsersIgnoresSecondaryManagerChange(t *testing.T) {
primaryTenantID := "tenant-company"
secondaryTenantID := "tenant-gpdtdc-leaf"
user := domain.User{
@@ -1853,10 +1942,10 @@ func TestCompareWorksmobileUsersMarksSecondaryManagerChangeNeedsUpdate(t *testin
)
require.Len(t, items, 1)
require.Equal(t, "needs_update", items[0].Status)
require.Equal(t, "matched", items[0].Status)
}
func TestCompareWorksmobileUsersMarksMissingSecondaryOrganizationNeedsUpdate(t *testing.T) {
func TestCompareWorksmobileUsersIgnoresMissingSecondaryOrganization(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
t.Setenv("GPDTDC_DOMAIN_ID", "1003")
rootID := "tenant-root"
@@ -1916,7 +2005,7 @@ func TestCompareWorksmobileUsersMarksMissingSecondaryOrganizationNeedsUpdate(t *
)
require.Len(t, items, 1)
require.Equal(t, "needs_update", items[0].Status)
require.Equal(t, "matched", items[0].Status)
}
func TestCompareWorksmobileUsersMarksPhoneAndEmployeeNumberChangesNeedsUpdate(t *testing.T) {
@@ -1952,6 +2041,35 @@ func TestCompareWorksmobileUsersMarksPhoneAndEmployeeNumberChangesNeedsUpdate(t
require.Equal(t, "needs_update", items[0].Status)
}
func TestCompareWorksmobileUsersMarksMalformedRemoteKoreanPhoneNeedsUpdate(t *testing.T) {
tenantID := "tenant-saman"
user := domain.User{
ID: "user-phone-canonical",
Email: "phone-canonical@samaneng.com",
Name: "Phone Canonical User",
Phone: "+821062836786",
TenantID: &tenantID,
Status: domain.UserStatusActive,
}
items := compareWorksmobileUsers(
[]domain.User{user},
[]WorksmobileRemoteUser{{
ID: "works-user-phone-canonical",
ExternalID: user.ID,
Email: user.Email,
DisplayName: user.Name,
CellPhone: "+82+821062836786",
}},
true,
map[string]domain.Tenant{
tenantID: {ID: tenantID, Name: "삼안", Type: domain.TenantTypeCompany},
},
)
require.Len(t, items, 1)
require.Equal(t, "needs_update", items[0].Status)
}
type fakeWorksmobileTenantService struct {
tenants map[string]domain.Tenant
list []domain.Tenant