1
0
forked from baron/baron-sso

테넌트 목록 및 조직 계층 구조 개선

This commit is contained in:
2026-02-27 10:29:15 +09:00
parent 600961f33d
commit ca45a14bae
27 changed files with 1906 additions and 806 deletions

View File

@@ -59,7 +59,7 @@ func SeedTenants(db *gorm.DB) error {
}
slog.Info("[Bootstrap] Creating default tenant", "name", config.Name, "slug", config.Slug)
tenant, err := svc.RegisterTenant(ctx, config.Name, config.Slug, config.Description, config.Domains, nil)
tenant, err := svc.RegisterTenant(ctx, config.Name, config.Slug, domain.TenantTypeCompany, config.Description, config.Domains, nil)
if err != nil {
slog.Error("Failed to seed tenant", "slug", config.Slug, "error", err)
return err

View File

@@ -19,23 +19,23 @@ const (
type User struct {
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
Email string `gorm:"uniqueIndex;not null" json:"email"`
PasswordHash string `gorm:"not null" json:"-"`
Name string `gorm:"not null" json:"name"`
Phone string `json:"phone"`
Role string `gorm:"default:'user';not null" json:"role"` // super_admin, tenant_admin, rp_admin, user
AffiliationType string `json:"affiliationType"`
CompanyCode string `json:"companyCode"`
TenantID *string `gorm:"type:uuid;index" json:"tenantId,omitempty"`
PasswordHash *string `gorm:"column:password_hash" json:"-"`
Name string `gorm:"column:name;not null" json:"name"`
Phone string `gorm:"column:phone" json:"phone"`
Role string `gorm:"column:role;default:'user';not null" json:"role"` // super_admin, tenant_admin, rp_admin, user
AffiliationType string `gorm:"column:affiliation_type" json:"affiliationType"`
CompanyCode string `gorm:"column:company_code;index" json:"companyCode"`
TenantID *string `gorm:"column:tenant_id;type:uuid;index" json:"tenantId,omitempty"`
Tenant *Tenant `gorm:"foreignKey:TenantID" json:"tenant,omitempty"`
RelyingPartyID *string `gorm:"type:uuid;index" json:"relyingPartyId,omitempty"` // RP Admin용
Department string `json:"department"`
Position string `json:"position"` // 직급 (예: 수석, 책임, 선임)
JobTitle string `json:"jobTitle"` // 직무 (예: 프론트엔드 개발, 기획)
Metadata JSONMap `gorm:"type:jsonb" json:"metadata,omitempty"`
Status string `gorm:"default:'active'" json:"status"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
RelyingPartyID *string `gorm:"column:relying_party_id;type:uuid;index" json:"relyingPartyId,omitempty"` // RP Admin용
Department string `gorm:"column:department" json:"department"`
Position string `gorm:"column:position" json:"position"` // 직급 (예: 수석, 책임, 선임)
JobTitle string `gorm:"column:job_title" json:"jobTitle"` // 직무 (예: 프론트엔드 개발, 기획)
Metadata JSONMap `gorm:"column:metadata;type:jsonb" json:"metadata,omitempty"`
Status string `gorm:"column:status;default:'active'" json:"status"`
CreatedAt time.Time `gorm:"column:created_at" json:"createdAt"`
UpdatedAt time.Time `gorm:"column:updated_at" json:"updatedAt"`
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;index" json:"-"`
}
// BeforeCreate hook to generate UUID if not present

View File

@@ -102,6 +102,15 @@ func (m *AsyncMockUserRepo) List(ctx context.Context, offset, limit int, search
return nil, 0, nil
}
func (m *AsyncMockUserRepo) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
return 0, nil
}
func (m *AsyncMockUserRepo) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
return nil, nil
}
type AsyncMockRedisRepo struct {
mock.Mock
}
@@ -128,7 +137,7 @@ type AsyncMockTenantService struct {
mock.Mock
}
func (m *AsyncMockTenantService) RegisterTenant(ctx context.Context, name, slug, description string, domains []string, parentID *string) (*domain.Tenant, error) {
func (m *AsyncMockTenantService) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string) (*domain.Tenant, error) {
return nil, nil
}

View File

@@ -6,6 +6,7 @@ import (
"baron-sso-backend/internal/service"
"baron-sso-backend/internal/utils"
"errors"
"log/slog"
"strings"
"time"
@@ -16,15 +17,17 @@ import (
type TenantHandler struct {
DB *gorm.DB
Service service.TenantService
UserRepo repository.UserRepository
Keto service.KetoService
KetoOutbox repository.KetoOutboxRepository
KratosAdmin service.KratosAdminService
}
func NewTenantHandler(db *gorm.DB, svc service.TenantService, keto service.KetoService, outbox repository.KetoOutboxRepository, kratos service.KratosAdminService) *TenantHandler {
func NewTenantHandler(db *gorm.DB, svc service.TenantService, userRepo repository.UserRepository, keto service.KetoService, outbox repository.KetoOutboxRepository, kratos service.KratosAdminService) *TenantHandler {
return &TenantHandler{
DB: db,
Service: svc,
UserRepo: userRepo,
Keto: keto,
KetoOutbox: outbox,
KratosAdmin: kratos,
@@ -33,12 +36,15 @@ func NewTenantHandler(db *gorm.DB, svc service.TenantService, keto service.KetoS
type tenantSummary struct {
ID string `json:"id"`
Type string `json:"type"`
ParentID *string `json:"parentId"`
Name string `json:"name"`
Slug string `json:"slug"`
Description string `json:"description"`
Status string `json:"status"`
Domains []string `json:"domains,omitempty"`
Config domain.JSONMap `json:"config,omitempty"`
MemberCount int64 `json:"memberCount"`
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
}
@@ -98,6 +104,8 @@ func (h *TenantHandler) ListTenants(c *fiber.Ctx) error {
limit := c.QueryInt("limit", 50)
offset := c.QueryInt("offset", 0)
parentId := c.Query("parentId")
if limit <= 0 {
limit = 50
}
@@ -105,19 +113,45 @@ func (h *TenantHandler) ListTenants(c *fiber.Ctx) error {
offset = 0
}
// Use separate queries for count and find to avoid GORM statement contamination
countQuery := h.DB.Model(&domain.Tenant{})
if parentId != "" {
countQuery = countQuery.Where("parent_id = ?", parentId)
}
var total int64
if err := h.DB.Model(&domain.Tenant{}).Count(&total).Error; err != nil {
if err := countQuery.Count(&total).Error; err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
findQuery := h.DB.Model(&domain.Tenant{})
if parentId != "" {
findQuery = findQuery.Where("parent_id = ?", parentId)
}
var tenants []domain.Tenant
if err := h.DB.Order("created_at desc").Limit(limit).Offset(offset).Preload("Domains").Find(&tenants).Error; err != nil {
if err := findQuery.Order("created_at desc").Limit(limit).Offset(offset).Preload("Domains").Find(&tenants).Error; err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
// Fetch member counts for all tenants in one query using slugs (company codes)
slugs := make([]string, 0, len(tenants))
for _, t := range tenants {
slugs = append(slugs, t.Slug)
}
memberCounts, err := h.UserRepo.CountByCompanyCodes(c.Context(), slugs)
if err != nil {
slog.Warn("failed to count members for tenants", "error", err)
memberCounts = make(map[string]int64)
}
items := make([]tenantSummary, 0, len(tenants))
for _, t := range tenants {
items = append(items, mapTenantSummary(t))
summary := mapTenantSummary(t)
// Ensure robust matching by trimming and lowercasing the slug key
key := strings.ToLower(strings.TrimSpace(t.Slug))
summary.MemberCount = memberCounts[key]
items = append(items, summary)
}
return c.JSON(tenantListResponse{Items: items, Limit: limit, Offset: offset, Total: total})
@@ -141,7 +175,15 @@ func (h *TenantHandler) GetTenant(c *fiber.Ctx) error {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(mapTenantSummary(tenant))
memberCounts, err := h.UserRepo.CountByCompanyCodes(c.Context(), []string{tenant.Slug})
count := int64(0)
if err == nil {
count = memberCounts[strings.ToLower(tenant.Slug)]
}
summary := mapTenantSummary(tenant)
summary.MemberCount = count
return c.JSON(summary)
}
func (h *TenantHandler) CreateTenant(c *fiber.Ctx) error {
@@ -152,6 +194,7 @@ func (h *TenantHandler) CreateTenant(c *fiber.Ctx) error {
var req struct {
Name string `json:"name"`
Slug string `json:"slug"`
Type string `json:"type"`
Description string `json:"description"`
Status string `json:"status"`
Domains []string `json:"domains"`
@@ -167,6 +210,11 @@ func (h *TenantHandler) CreateTenant(c *fiber.Ctx) error {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "name is required"})
}
tenantType := normalizeTenantType(req.Type)
if tenantType == "" {
tenantType = domain.TenantTypeCompany // Default to COMPANY
}
slug := req.Slug
if slug == "" {
slug = utils.GenerateUniqueSlug(name, func(s string) bool {
@@ -193,7 +241,7 @@ func (h *TenantHandler) CreateTenant(c *fiber.Ctx) error {
parentID = &pid
}
tenant, err := h.Service.RegisterTenant(c.Context(), name, slug, req.Description, req.Domains, parentID)
tenant, err := h.Service.RegisterTenant(c.Context(), name, slug, tenantType, req.Description, req.Domains, parentID)
if err != nil {
if strings.Contains(err.Error(), "already exists") {
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": err.Error()})
@@ -201,12 +249,16 @@ func (h *TenantHandler) CreateTenant(c *fiber.Ctx) error {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
summary := mapTenantSummary(*tenant)
summary.MemberCount = 0
if req.Config != nil {
tenant.Config = req.Config
h.DB.Save(tenant)
summary.Config = tenant.Config
}
return c.Status(fiber.StatusCreated).JSON(mapTenantSummary(*tenant))
return c.Status(fiber.StatusCreated).JSON(summary)
}
func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
@@ -229,9 +281,11 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
var req struct {
Name *string `json:"name"`
Type *string `json:"type"`
Slug *string `json:"slug"`
Description *string `json:"description"`
Status *string `json:"status"`
ParentID *string `json:"parentId"`
Domains []string `json:"domains"`
Config map[string]any `json:"config"`
}
@@ -246,6 +300,13 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
}
tenant.Name = name
}
if req.Type != nil {
tenantType := normalizeTenantType(*req.Type)
if tenantType == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid tenant type"})
}
tenant.Type = tenantType
}
if req.Slug != nil {
slug := utils.GenerateSlug(*req.Slug)
if slug == "" {
@@ -271,6 +332,30 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
}
tenant.Status = status
}
if req.ParentID != nil {
pid := strings.TrimSpace(*req.ParentID)
if pid == "" {
tenant.ParentID = nil
} else {
tenant.ParentID = &pid
}
// [Keto] Sync hierarchy via Outbox
if h.KetoOutbox != nil {
if tenant.ParentID != nil {
_ = h.KetoOutbox.Create(c.Context(), &domain.KetoOutbox{
Namespace: "Tenant",
Object: tenant.ID,
Relation: "parents",
Subject: "Tenant:" + *tenant.ParentID,
Action: domain.KetoOutboxActionCreate,
})
} else {
// We don't have enough info here to delete specific parent if we don't know the old one,
// but for now we focus on adding.
}
}
}
if req.Config != nil {
tenant.Config = req.Config
}
@@ -432,6 +517,8 @@ func mapTenantSummary(t domain.Tenant) tenantSummary {
return tenantSummary{
ID: t.ID,
Type: t.Type,
ParentID: t.ParentID,
Name: t.Name,
Slug: t.Slug,
Description: t.Description,
@@ -453,3 +540,13 @@ func normalizeTenantStatus(value string) string {
}
return value
}
func normalizeTenantType(value string) string {
value = strings.ToUpper(strings.TrimSpace(value))
switch value {
case domain.TenantTypePersonal, domain.TenantTypeCompany, domain.TenantTypeCompanyGroup, domain.TenantTypeUserGroup:
return value
default:
return ""
}
}

View File

@@ -21,8 +21,8 @@ type MockTenantService struct {
mock.Mock
}
func (m *MockTenantService) RegisterTenant(ctx context.Context, name, slug, description string, domains []string, parentID *string) (*domain.Tenant, error) {
args := m.Called(ctx, name, slug, description, domains, parentID)
func (m *MockTenantService) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string) (*domain.Tenant, error) {
args := m.Called(ctx, name, slug, tenantType, description, domains, parentID)
if args.Get(0) == nil {
return nil, args.Error(1)
}
@@ -85,7 +85,7 @@ func TestTenantHandler_CreateTenant(t *testing.T) {
}
body, _ := json.Marshal(input)
mockSvc.On("RegisterTenant", mock.Anything, "Test Tenant", "test-tenant", "", []string{"test.com"}, (*string)(nil)).
mockSvc.On("RegisterTenant", mock.Anything, "Test Tenant", "test-tenant", domain.TenantTypeCompany, "", []string{"test.com"}, (*string)(nil)).
Return(&domain.Tenant{ID: "t1", Name: "Test Tenant", Slug: "test-tenant"}, nil)
req := httptest.NewRequest("POST", "/tenants", bytes.NewReader(body))

View File

@@ -68,6 +68,7 @@ func (h *UserHandler) ListUsers(c *fiber.Ctx) error {
limit := c.QueryInt("limit", 50)
offset := c.QueryInt("offset", 0)
search := strings.TrimSpace(c.Query("search"))
companyCode := strings.TrimSpace(c.Query("companyCode"))
if limit <= 0 {
limit = 50
@@ -89,14 +90,21 @@ func (h *UserHandler) ListUsers(c *fiber.Ctx) error {
// Tenant Admin filtering
if requesterRole == domain.RoleTenantAdmin {
if requesterCompany == "" || compCode != requesterCompany {
if requesterCompany == "" || !strings.EqualFold(compCode, requesterCompany) {
continue
}
}
// Search filtering
// Dedicated companyCode filter
if companyCode != "" && !strings.EqualFold(compCode, companyCode) {
continue
}
// Search filtering (Keyword search in email, name, or companyCode)
if search != "" {
if !strings.Contains(email, searchLower) && !strings.Contains(name, searchLower) {
if !strings.Contains(email, searchLower) &&
!strings.Contains(name, searchLower) &&
!strings.Contains(strings.ToLower(compCode), searchLower) {
continue
}
}
@@ -118,14 +126,27 @@ func (h *UserHandler) ListUsers(c *fiber.Ctx) error {
items = append(items, summary)
}
// [Lazy Sync] Asynchronously update local DB with fresh data from Kratos
// This ensures that member counts (which use local DB) eventually match reality
if h.UserRepo != nil {
go func(ids []service.KratosIdentity) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
for _, identity := range ids {
localUser := h.mapToLocalUser(identity)
_ = h.UserRepo.Update(ctx, localUser)
}
}(filtered)
}
return c.JSON(userListResponse{Items: items, Limit: limit, Offset: offset, Total: total})
}
// 2. Fallback to Local DB if Kratos is down (Development only recommended)
// 2. Fallback to Local DB if Kratos is down
slog.Warn("Kratos unavailable, falling back to local DB for user list", "error", err)
// Fetch from UserRepo
users, total, err := h.UserRepo.List(c.Context(), offset, limit, search)
users, total, err := h.UserRepo.List(c.Context(), offset, limit, search, companyCode)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to fetch users from both kratos and local db"})
}
@@ -289,66 +310,7 @@ func (h *UserHandler) CreateUser(c *fiber.Ctx) error {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
// [New] Local DB Sync
localUser := &domain.User{
ID: identityID,
Email: email,
Name: name,
Phone: normalizePhoneNumber(req.Phone),
AffiliationType: "internal",
CompanyCode: req.CompanyCode,
Department: req.Department,
Role: role,
Status: "active",
Metadata: req.Metadata,
}
if tenantID != "" {
localUser.TenantID = &tenantID
}
// [SoT Policy] Kratos가 SoT이므로 로컬 DB 저장은 비동기 Read-Model 동기화로 처리합니다.
if h.UserRepo != nil {
go func(u *domain.User) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.UserRepo.Create(ctx, u); err != nil {
slog.Error("[UserHandler] Failed to sync user to local DB", "email", u.Email, "error", err)
}
}(localUser)
}
// [Keto] Sync relations via Outbox
if h.KetoOutboxRepo != nil {
// 1. Tenant Membership
if localUser.TenantID != nil {
_ = h.KetoOutboxRepo.Create(c.Context(), &domain.KetoOutbox{
Namespace: "Tenant",
Object: *localUser.TenantID,
Relation: "members",
Subject: "User:" + identityID,
Action: domain.KetoOutboxActionCreate,
})
}
// 2. Role Specifics
if role == domain.RoleSuperAdmin {
_ = h.KetoOutboxRepo.Create(c.Context(), &domain.KetoOutbox{
Namespace: "System",
Object: "global",
Relation: "super_admins",
Subject: "User:" + identityID,
Action: domain.KetoOutboxActionCreate,
})
} else if role == domain.RoleTenantAdmin && localUser.TenantID != nil {
_ = h.KetoOutboxRepo.Create(c.Context(), &domain.KetoOutbox{
Namespace: "Tenant",
Object: *localUser.TenantID,
Relation: "admins",
Subject: "User:" + identityID,
Action: domain.KetoOutboxActionCreate,
})
}
}
// Fetch the newly created identity to ensure we have all traits
identity, err := h.KratosAdmin.GetIdentity(c.Context(), identityID)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
@@ -357,6 +319,28 @@ func (h *UserHandler) CreateUser(c *fiber.Ctx) error {
return c.Status(fiber.StatusCreated).JSON(fiber.Map{"id": identityID, "initialPassword": generatedPassword})
}
// [New] Local DB Sync - Ensure user exists in read-model
if h.UserRepo != nil {
localUser := h.mapToLocalUser(*identity)
// Sync to local DB
go func(u *domain.User, role string, tID *string) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
// Use Update (upsert) instead of Create for robustness
if err := h.UserRepo.Update(ctx, u); err != nil {
slog.Error("[UserHandler] Failed to sync new user to local DB", "email", u.Email, "error", err)
return
}
// [Keto] Sync relations via Outbox
if h.KetoOutboxRepo != nil {
h.syncKetoRole(ctx, u.ID, role, "", "", tID)
}
}(localUser, role, localUser.TenantID)
}
response := h.mapIdentitySummary(c.Context(), *identity)
if generatedPassword != "" {
response.InitialPassword = generatedPassword
@@ -382,6 +366,18 @@ func (h *UserHandler) UpdateUser(c *fiber.Ctx) error {
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "user not found"})
}
// Capture current local state for transition comparison
var oldRole string
var oldTenantID string
if h.UserRepo != nil {
if local, err := h.UserRepo.FindByID(c.Context(), userID); err == nil && local != nil {
oldRole = local.Role
if local.TenantID != nil {
oldTenantID = *local.TenantID
}
}
}
// [New] Check access scope
requester, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
if requester != nil && requester.Role == domain.RoleTenantAdmin {
@@ -420,7 +416,12 @@ func (h *UserHandler) UpdateUser(c *fiber.Ctx) error {
traits["name"] = strings.TrimSpace(*req.Name)
}
if req.Phone != nil {
traits["phone_number"] = normalizePhoneNumber(strings.TrimSpace(*req.Phone))
phone := normalizePhoneNumber(strings.TrimSpace(*req.Phone))
if phone == "" {
delete(traits, "phone_number")
} else {
traits["phone_number"] = phone
}
}
if req.CompanyCode != nil {
code := strings.TrimSpace(*req.CompanyCode)
@@ -471,92 +472,18 @@ func (h *UserHandler) UpdateUser(c *fiber.Ctx) error {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
// [New] Local DB Sync
// [New] Local DB Sync - Sync synchronously to ensure immediate consistency for the caller
if h.UserRepo != nil {
if localUser, err := h.UserRepo.FindByID(c.Context(), userID); err == nil && localUser != nil {
oldRole := localUser.Role
oldTenantID := ""
if localUser.TenantID != nil {
oldTenantID = *localUser.TenantID
}
if req.Name != nil {
localUser.Name = *req.Name
}
if req.Phone != nil {
localUser.Phone = normalizePhoneNumber(*req.Phone)
}
if req.CompanyCode != nil {
localUser.CompanyCode = *req.CompanyCode
if tenant, err := h.TenantService.GetTenantBySlug(c.Context(), *req.CompanyCode); err == nil && tenant != nil {
localUser.TenantID = &tenant.ID
}
}
if req.Department != nil {
localUser.Department = *req.Department
}
if req.Role != nil {
localUser.Role = *req.Role
}
if req.Status != nil {
localUser.Status = *req.Status
}
if req.Metadata != nil {
localUser.Metadata = req.Metadata
}
// [SoT Policy] Kratos가 SoT이므로 로컬 DB 저장은 비동기 Read-Model 동기화로 처리합니다.
// [ReBAC Policy] 로컬 DB와 Keto 간의 정합성을 위해 Outbox를 함께 기록합니다.
go func(u *domain.User, rRole *string, oRole string, oTenantID string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.UserRepo.Update(ctx, u); err == nil {
// [Keto Sync on Role Change] via Outbox
if h.KetoOutboxRepo != nil && rRole != nil && *rRole != oRole {
uID := u.ID
newR := *rRole
if oRole == domain.RoleSuperAdmin {
_ = h.KetoOutboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "System",
Object: "global",
Relation: "super_admins",
Subject: "User:" + uID,
Action: domain.KetoOutboxActionDelete,
})
} else if oRole == domain.RoleTenantAdmin && oTenantID != "" {
_ = h.KetoOutboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: oTenantID,
Relation: "admins",
Subject: "User:" + uID,
Action: domain.KetoOutboxActionDelete,
})
}
if newR == domain.RoleSuperAdmin {
_ = h.KetoOutboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "System",
Object: "global",
Relation: "super_admins",
Subject: "User:" + uID,
Action: domain.KetoOutboxActionCreate,
})
} else if newR == domain.RoleTenantAdmin && u.TenantID != nil {
_ = h.KetoOutboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: *u.TenantID,
Relation: "admins",
Subject: "User:" + uID,
Action: domain.KetoOutboxActionCreate,
})
}
}
} else {
slog.Error("[UserHandler] Failed to sync user update to local DB", "userID", u.ID, "error", err)
}
}(localUser, req.Role, oldRole, oldTenantID)
updatedLocalUser := h.mapToLocalUser(*updated)
ctx := context.Background() // Use request context if appropriate, but sync must finish
if err := h.UserRepo.Update(ctx, updatedLocalUser); err != nil {
slog.Error("[UserHandler] Failed to sync updated user to local DB", "userID", updatedLocalUser.ID, "error", err)
}
// [Keto Sync] asynchronously as it's less critical for immediate UI count
go h.syncKetoRole(context.Background(), updatedLocalUser.ID,
extractTraitString(updated.Traits, "grade"), oldRole, oldTenantID, updatedLocalUser.TenantID)
}
if req.Password != nil && *req.Password != "" {
@@ -654,6 +581,97 @@ func (h *UserHandler) mapIdentitySummary(ctx context.Context, identity service.K
return summary
}
func (h *UserHandler) normalizePhoneNumber(phone string) string {
return normalizePhoneNumber(phone)
}
func (h *UserHandler) mapToLocalUser(identity service.KratosIdentity) *domain.User {
traits := identity.Traits
role := extractTraitString(traits, "grade")
if role == "" {
role = "user"
}
compCode := extractTraitString(traits, "companyCode")
user := &domain.User{
ID: identity.ID,
Email: extractTraitString(traits, "email"),
Name: extractTraitString(traits, "name"),
Phone: extractTraitString(traits, "phone_number"),
Role: role,
Status: normalizeStatus(identity.State),
CompanyCode: compCode,
Department: extractTraitString(traits, "department"),
AffiliationType: extractTraitString(traits, "affiliationType"),
CreatedAt: identity.CreatedAt,
UpdatedAt: identity.UpdatedAt,
}
if compCode != "" && h.TenantService != nil {
// Use a background context or a timeout-limited context for tenant lookup
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if tenant, err := h.TenantService.GetTenantBySlug(ctx, compCode); err == nil && tenant != nil {
user.TenantID = &tenant.ID
}
}
// Metadata
user.Metadata = make(domain.JSONMap)
coreTraits := map[string]bool{
"email": true, "name": true, "phone_number": true,
"grade": true, "companyCode": true, "department": true,
"affiliationType": true, "role": true, "tenant_id": true,
}
for k, v := range traits {
if !coreTraits[k] {
user.Metadata[k] = v
}
}
return user
}
func (h *UserHandler) syncKetoRole(ctx context.Context, userID, newRole, oldRole, oldTenantID string, newTenantID *string) {
// Remove old roles
if oldRole == domain.RoleSuperAdmin {
_ = h.KetoOutboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "System",
Object: "global",
Relation: "super_admins",
Subject: "User:" + userID,
Action: domain.KetoOutboxActionDelete,
})
} else if oldRole == domain.RoleTenantAdmin && oldTenantID != "" {
_ = h.KetoOutboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: oldTenantID,
Relation: "admins",
Subject: "User:" + userID,
Action: domain.KetoOutboxActionDelete,
})
}
// Add new roles
if newRole == domain.RoleSuperAdmin {
_ = h.KetoOutboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "System",
Object: "global",
Relation: "super_admins",
Subject: "User:" + userID,
Action: domain.KetoOutboxActionCreate,
})
} else if newRole == domain.RoleTenantAdmin && newTenantID != nil {
_ = h.KetoOutboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: *newTenantID,
Relation: "admins",
Subject: "User:" + userID,
Action: domain.KetoOutboxActionCreate,
})
}
}
func extractTraitString(traits map[string]interface{}, key string) string {
if traits == nil {
return ""

View File

@@ -3,6 +3,7 @@ package repository
import (
"baron-sso-backend/internal/domain"
"context"
"strings"
"gorm.io/gorm"
)
@@ -14,7 +15,10 @@ type UserRepository interface {
FindByID(ctx context.Context, id string) (*domain.User, error)
FindByIDs(ctx context.Context, ids []string) ([]domain.User, error)
ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error)
List(ctx context.Context, offset, limit int, search string) ([]domain.User, int64, error)
List(ctx context.Context, offset, limit int, search string, companyCode string) ([]domain.User, int64, error)
CountByTenant(ctx context.Context, tenantID string) (int64, error)
CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error)
CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error)
Delete(ctx context.Context, id string) error
}
@@ -69,14 +73,111 @@ func (r *userRepository) ListByTenant(ctx context.Context, tenantID string) ([]d
return users, nil
}
func (r *userRepository) List(ctx context.Context, offset, limit int, search string) ([]domain.User, int64, error) {
func (r *userRepository) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&domain.User{}).Where("tenant_id = ?", tenantID).Count(&count).Error
return count, err
}
func (r *userRepository) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
type result struct {
TenantID string
Count int64
}
var results []result
if len(tenantIDs) == 0 {
return make(map[string]int64), nil
}
if err := r.db.WithContext(ctx).Model(&domain.User{}).
Select("tenant_id, count(*) as count").
Where("tenant_id IN ?", tenantIDs).
Group("tenant_id").
Find(&results).Error; err != nil {
return nil, err
}
counts := make(map[string]int64)
for _, res := range results {
if res.TenantID != "" {
counts[res.TenantID] = res.Count
}
}
// Ensure all requested tenant IDs are in the map, even if count is 0
for _, id := range tenantIDs {
if _, ok := counts[id]; !ok {
counts[id] = 0
}
}
return counts, nil
}
func (r *userRepository) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
if len(codes) == 0 {
return make(map[string]int64), nil
}
// 1. Resolve IDs for these codes to support dual counting (slug or ID)
var tenants []domain.Tenant
_ = r.db.WithContext(ctx).Where("slug IN ?", codes).Find(&tenants).Error
idToSlug := make(map[string]string)
slugToNormalized := make(map[string]string)
for _, code := range codes {
slugToNormalized[strings.ToLower(strings.TrimSpace(code))] = code
}
for _, t := range tenants {
idToSlug[t.ID] = t.Slug
}
type result struct {
CompanyCode string
TenantID string
Count int64
}
var results []result
// Use a more comprehensive aggregation
err := r.db.WithContext(ctx).Model(&domain.User{}).
Select("company_code, tenant_id, count(*) as count").
Where("company_code IN ? OR tenant_id IN (SELECT id FROM tenants WHERE slug IN ?)", codes, codes).
Group("company_code, tenant_id").
Scan(&results).Error
if err != nil {
return nil, err
}
counts := make(map[string]int64)
for _, res := range results {
var slug string
if res.CompanyCode != "" {
slug = res.CompanyCode
} else if res.TenantID != "" {
slug = idToSlug[res.TenantID]
}
if slug != "" {
normalizedSlug := strings.ToLower(strings.TrimSpace(slug))
counts[normalizedSlug] += res.Count
}
}
return counts, nil
}
func (r *userRepository) List(ctx context.Context, offset, limit int, search string, companyCode string) ([]domain.User, int64, error) {
var users []domain.User
var total int64
db := r.db.WithContext(ctx).Model(&domain.User{})
if companyCode != "" {
db = db.Where("company_code = ?", companyCode)
}
if search != "" {
searchTerm := "%" + search + "%"
db = db.Where("email LIKE ? OR name LIKE ?", searchTerm, searchTerm)
db = db.Where("(email LIKE ? OR name LIKE ? OR company_code LIKE ?)", searchTerm, searchTerm, searchTerm)
}
if err := db.Count(&total).Error; err != nil {

View File

@@ -13,7 +13,7 @@ import (
)
type TenantService interface {
RegisterTenant(ctx context.Context, name, slug, description string, domains []string, parentID *string) (*domain.Tenant, error)
RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string) (*domain.Tenant, error)
RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error)
GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error)
GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error)
@@ -89,7 +89,7 @@ func (s *tenantService) ListManageableTenants(ctx context.Context, userID string
return s.repo.FindByIDs(ctx, allIDs)
}
func (s *tenantService) RegisterTenant(ctx context.Context, name, slug, description string, domains []string, parentID *string) (*domain.Tenant, error) {
func (s *tenantService) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string) (*domain.Tenant, error) {
// Validate Slug
if ok, msg := utils.ValidateSlug(slug); !ok {
return nil, errors.New(msg)
@@ -106,7 +106,7 @@ func (s *tenantService) RegisterTenant(ctx context.Context, name, slug, descript
// 2. Create Tenant
tenant := &domain.Tenant{
Type: domain.TenantTypeCompany, // Default to COMPANY for manual registration
Type: tenantType,
Name: name,
Slug: slug,
Description: description,

View File

@@ -21,7 +21,7 @@ func TestTenantService_RegisterTenant_DuplicateSlug(t *testing.T) {
// Mock: slug already exists
mockRepo.On("FindBySlug", ctx, slug).Return(&domain.Tenant{ID: "existing-id", Slug: slug}, nil)
tenant, err := svc.RegisterTenant(ctx, "New Name", slug, "", nil, nil)
tenant, err := svc.RegisterTenant(ctx, "New Name", slug, domain.TenantTypeCompany, "", nil, nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "already exists")
assert.Nil(t, tenant)
@@ -32,11 +32,11 @@ func TestTenantService_RegisterTenant_InvalidSlug(t *testing.T) {
ctx := context.Background()
// Case 1: Too short
_, err := svc.RegisterTenant(ctx, "Name", "a", "", nil, nil)
_, err := svc.RegisterTenant(ctx, "Name", "a", domain.TenantTypeCompany, "", nil, nil)
assert.Error(t, err)
// Case 2: Invalid characters
_, err = svc.RegisterTenant(ctx, "Name", "Invalid Slug!", "", nil, nil)
_, err = svc.RegisterTenant(ctx, "Name", "Invalid Slug!", domain.TenantTypeCompany, "", nil, nil)
assert.Error(t, err)
}

View File

@@ -120,6 +120,21 @@ func (m *MockUserRepoForTenant) List(ctx context.Context, offset, limit int, sea
return nil, 0, nil
}
func (m *MockUserRepoForTenant) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
args := m.Called(tenantID)
return int64(args.Int(0)), args.Error(1)
}
func (m *MockUserRepoForTenant) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
args := m.Called(tenantIDs)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(map[string]int64), args.Error(1)
}
func TestTenantService_RegisterTenant_AutoVerify(t *testing.T) {
mockRepo := new(MockTenantRepoForSvc)
mockOutbox := new(MockKetoOutboxRepositoryShared)
@@ -136,7 +151,7 @@ func TestTenantService_RegisterTenant_AutoVerify(t *testing.T) {
mockRepo.On("AddDomain", ctx, mock.Anything, "example.com", true).Return(nil)
mockRepo.On("FindBySlug", ctx, slug).Return(&domain.Tenant{ID: "t1", Slug: slug}, nil).Once()
tenant, err := svc.RegisterTenant(ctx, name, slug, "", domains, nil)
tenant, err := svc.RegisterTenant(ctx, name, slug, domain.TenantTypeCompany, "", domains, nil)
assert.NoError(t, err)
assert.NotNil(t, tenant)
assert.Equal(t, "t1", tenant.ID)

View File

@@ -81,6 +81,20 @@ func (m *MockUserRepository) List(ctx context.Context, offset, limit int, search
return nil, 0, nil
}
func (m *MockUserRepository) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
args := m.Called(tenantID)
return int64(args.Int(0)), args.Error(1)
}
func (m *MockUserRepository) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
args := m.Called(tenantIDs)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(map[string]int64), args.Error(1)
}
type MockTenantRepository struct {
mock.Mock
}