forked from baron/baron-sso
Implement tenant import and RP auto login policies
This commit is contained in:
@@ -4,15 +4,33 @@ import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"baron-sso-backend/internal/service"
|
||||
"baron-sso-backend/internal/utils"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/csv"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const seedTenantCSVPathEnv = "SEED_TENANT_CSV_PATH"
|
||||
|
||||
var seedTenantCSVPathCandidates = []string{
|
||||
"adminfront/seed-tenant.csv",
|
||||
"../adminfront/seed-tenant.csv",
|
||||
"../../adminfront/seed-tenant.csv",
|
||||
"../../../adminfront/seed-tenant.csv",
|
||||
"/app/adminfront/seed-tenant.csv",
|
||||
}
|
||||
|
||||
type InitialTenantConfig struct {
|
||||
TenantID string
|
||||
Name string
|
||||
Slug string
|
||||
Type string
|
||||
@@ -21,32 +39,31 @@ type InitialTenantConfig struct {
|
||||
Domains []string
|
||||
}
|
||||
|
||||
// Hardcoded for now, can be moved to config file or env later
|
||||
var defaultTenants = []InitialTenantConfig{
|
||||
{
|
||||
Name: "한맥가족",
|
||||
Slug: "hanmac-family",
|
||||
Type: domain.TenantTypeCompanyGroup,
|
||||
},
|
||||
{
|
||||
Name: "한맥기술",
|
||||
Slug: "hanmac",
|
||||
Type: domain.TenantTypeCompany,
|
||||
ParentSlug: "hanmac-family",
|
||||
Description: "Primary Family Company",
|
||||
Domains: []string{"hanmaceng.co.kr", "hmac.kr"},
|
||||
},
|
||||
{
|
||||
Name: "삼안",
|
||||
Slug: "saman",
|
||||
Type: domain.TenantTypeCompany,
|
||||
ParentSlug: "hanmac-family",
|
||||
Domains: []string{"samaneng.com"},
|
||||
},
|
||||
func SeedTenants(db *gorm.DB) error {
|
||||
slog.Info("[Bootstrap] Checking initial tenant seed...")
|
||||
|
||||
var tenantCount int64
|
||||
if err := db.Model(&domain.Tenant{}).Count(&tenantCount).Error; err != nil {
|
||||
return fmt.Errorf("count tenants before seed: %w", err)
|
||||
}
|
||||
if tenantCount > 0 {
|
||||
slog.Info("[Bootstrap] Tenant seed skipped because tenants already exist", "count", tenantCount)
|
||||
return nil
|
||||
}
|
||||
|
||||
configs, err := loadSeedTenantConfigs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(configs) == 0 {
|
||||
return errors.New("seed tenant csv has no tenant rows")
|
||||
}
|
||||
|
||||
return seedTenantConfigs(db, configs)
|
||||
}
|
||||
|
||||
func SeedTenants(db *gorm.DB) error {
|
||||
slog.Info("[Bootstrap] Seeding initial tenants...")
|
||||
func seedTenantConfigs(db *gorm.DB, configs []InitialTenantConfig) error {
|
||||
slog.Info("[Bootstrap] Seeding initial tenants from CSV...", "count", len(configs))
|
||||
repo := repository.NewTenantRepository(db)
|
||||
userRepo := repository.NewUserRepository(db)
|
||||
userGroupRepo := repository.NewUserGroupRepository(db)
|
||||
@@ -54,7 +71,7 @@ func SeedTenants(db *gorm.DB) error {
|
||||
svc := service.NewTenantService(repo, userRepo, userGroupRepo, outboxRepo)
|
||||
ctx := context.Background()
|
||||
|
||||
for _, config := range defaultTenants {
|
||||
for _, config := range orderSeedTenantConfigsByParentSlug(configs) {
|
||||
tenantType := config.Type
|
||||
if tenantType == "" {
|
||||
tenantType = domain.TenantTypeCompany
|
||||
@@ -73,75 +90,273 @@ func SeedTenants(db *gorm.DB) error {
|
||||
parentID = &parent.ID
|
||||
}
|
||||
|
||||
existing, err := repo.FindBySlug(ctx, config.Slug)
|
||||
if err == nil && existing != nil {
|
||||
slog.Info("[Bootstrap] Tenant already exists, checking domains...", "slug", config.Slug)
|
||||
changed := false
|
||||
if existing.Name != config.Name {
|
||||
existing.Name = config.Name
|
||||
changed = true
|
||||
}
|
||||
if existing.Type != tenantType {
|
||||
existing.Type = tenantType
|
||||
changed = true
|
||||
}
|
||||
if existing.Status != domain.TenantStatusActive {
|
||||
existing.Status = domain.TenantStatusActive
|
||||
changed = true
|
||||
}
|
||||
if config.ParentSlug != "" {
|
||||
if existing.ParentID == nil || *existing.ParentID != *parentID {
|
||||
existing.ParentID = parentID
|
||||
changed = true
|
||||
if err := outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: existing.ID,
|
||||
Relation: "parents",
|
||||
Subject: "Tenant:" + *parentID,
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
}); err != nil {
|
||||
slog.Error("Failed to create outbox entry for seeded tenant hierarchy", "tenant", existing.ID, "error", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else if existing.ParentID != nil {
|
||||
existing.ParentID = nil
|
||||
changed = true
|
||||
}
|
||||
if changed {
|
||||
if err := repo.Update(ctx, existing); err != nil {
|
||||
slog.Error("Failed to update seeded tenant", "slug", config.Slug, "error", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Optional: Check and add missing domains
|
||||
for _, d := range config.Domains {
|
||||
found := false
|
||||
for _, ed := range existing.Domains {
|
||||
if ed.Domain == d {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
slog.Info("[Bootstrap] Adding missing domain to tenant", "slug", config.Slug, "domain", d)
|
||||
if err := repo.AddDomain(ctx, existing.ID, d, true); err != nil {
|
||||
slog.Error("Failed to add domain", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
slog.Info("[Bootstrap] Creating seed tenant", "name", config.Name, "slug", config.Slug)
|
||||
var tenant *domain.Tenant
|
||||
var err error
|
||||
if config.TenantID != "" {
|
||||
tenant, err = createSeedTenant(ctx, repo, outboxRepo, config, tenantType, parentID)
|
||||
} else {
|
||||
tenant, err = svc.RegisterTenant(ctx, config.Name, config.Slug, tenantType, config.Description, config.Domains, parentID, "")
|
||||
}
|
||||
|
||||
slog.Info("[Bootstrap] Creating default tenant", "name", config.Name, "slug", config.Slug)
|
||||
tenant, err := svc.RegisterTenant(ctx, config.Name, config.Slug, tenantType, config.Description, config.Domains, parentID, "")
|
||||
if err != nil {
|
||||
slog.Error("Failed to seed tenant", "slug", config.Slug, "error", err)
|
||||
return err
|
||||
}
|
||||
// Explicitly set to active during seed
|
||||
tenant.Status = domain.TenantStatusActive
|
||||
db.Save(tenant)
|
||||
if err := db.Save(tenant).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadSeedTenantConfigs() ([]InitialTenantConfig, error) {
|
||||
path, err := findSeedTenantCSVPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open seed tenant csv %q: %w", path, err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
configs, err := parseSeedTenantCSV(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse seed tenant csv %q: %w", path, err)
|
||||
}
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
func findSeedTenantCSVPath() (string, error) {
|
||||
if configured := strings.TrimSpace(os.Getenv(seedTenantCSVPathEnv)); configured != "" {
|
||||
return configured, nil
|
||||
}
|
||||
|
||||
for _, candidate := range seedTenantCSVPathCandidates {
|
||||
cleaned := filepath.Clean(candidate)
|
||||
if _, err := os.Stat(cleaned); err == nil {
|
||||
return cleaned, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("seed tenant csv not found; set %s or add adminfront/seed-tenant.csv", seedTenantCSVPathEnv)
|
||||
}
|
||||
|
||||
func parseSeedTenantCSV(r io.Reader) ([]InitialTenantConfig, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to read csv")
|
||||
}
|
||||
data = bytes.TrimPrefix(data, []byte{0xEF, 0xBB, 0xBF})
|
||||
|
||||
reader := csv.NewReader(bytes.NewReader(data))
|
||||
reader.FieldsPerRecord = -1
|
||||
rows, err := reader.ReadAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid csv: %w", err)
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
return nil, errors.New("csv is empty")
|
||||
}
|
||||
|
||||
header := seedTenantCSVHeaderIndex(rows[0])
|
||||
for _, key := range []string{"name", "type", "slug"} {
|
||||
if _, ok := header[key]; !ok {
|
||||
return nil, fmt.Errorf("missing required column: %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
configs := make([]InitialTenantConfig, 0, len(rows)-1)
|
||||
for i, row := range rows[1:] {
|
||||
if seedTenantCSVRowIsEmpty(row) {
|
||||
continue
|
||||
}
|
||||
|
||||
name := seedTenantCSVValue(row, header, "name")
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("row %d: name is required", i+2)
|
||||
}
|
||||
|
||||
tenantType := normalizeSeedTenantType(seedTenantCSVValue(row, header, "type"))
|
||||
if tenantType == "" {
|
||||
return nil, fmt.Errorf("row %d: invalid tenant type", i+2)
|
||||
}
|
||||
|
||||
slug := utils.GenerateSlug(seedTenantCSVValue(row, header, "slug"))
|
||||
if slug == "" {
|
||||
return nil, fmt.Errorf("row %d: slug is required", i+2)
|
||||
}
|
||||
|
||||
configs = append(configs, InitialTenantConfig{
|
||||
TenantID: seedTenantCSVValue(row, header, "tenant_id"),
|
||||
Name: name,
|
||||
Type: tenantType,
|
||||
ParentSlug: seedTenantCSVValue(row, header, "parent_tenant_slug"),
|
||||
Slug: slug,
|
||||
Description: seedTenantCSVValue(row, header, "memo"),
|
||||
Domains: splitSeedTenantCSVDomains(seedTenantCSVValue(row, header, "email_domain")),
|
||||
})
|
||||
}
|
||||
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
func seedTenantCSVHeaderIndex(header []string) map[string]int {
|
||||
index := make(map[string]int, len(header))
|
||||
aliases := map[string]string{
|
||||
"id": "tenant_id",
|
||||
"tenantid": "tenant_id",
|
||||
"tenant_id": "tenant_id",
|
||||
"name": "name",
|
||||
"type": "type",
|
||||
"parenttenantslug": "parent_tenant_slug",
|
||||
"parent_tenant_slug": "parent_tenant_slug",
|
||||
"parent_slug": "parent_tenant_slug",
|
||||
"slug": "slug",
|
||||
"memo": "memo",
|
||||
"description": "memo",
|
||||
"email-domain": "email_domain",
|
||||
"emaildomain": "email_domain",
|
||||
"email_domain": "email_domain",
|
||||
"domain": "email_domain",
|
||||
"domains": "email_domain",
|
||||
}
|
||||
for i, column := range header {
|
||||
key := strings.ToLower(strings.TrimSpace(column))
|
||||
key = strings.ReplaceAll(key, " ", "_")
|
||||
if canonical, ok := aliases[key]; ok {
|
||||
index[canonical] = i
|
||||
}
|
||||
}
|
||||
return index
|
||||
}
|
||||
|
||||
func seedTenantCSVValue(row []string, header map[string]int, key string) string {
|
||||
idx, ok := header[key]
|
||||
if !ok || idx >= len(row) {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(row[idx])
|
||||
}
|
||||
|
||||
func seedTenantCSVRowIsEmpty(row []string) bool {
|
||||
for _, value := range row {
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func normalizeSeedTenantType(value string) string {
|
||||
switch strings.ToUpper(strings.TrimSpace(value)) {
|
||||
case domain.TenantTypePersonal:
|
||||
return domain.TenantTypePersonal
|
||||
case domain.TenantTypeCompany:
|
||||
return domain.TenantTypeCompany
|
||||
case domain.TenantTypeCompanyGroup:
|
||||
return domain.TenantTypeCompanyGroup
|
||||
case domain.TenantTypeUserGroup:
|
||||
return domain.TenantTypeUserGroup
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func splitSeedTenantCSVDomains(value string) []string {
|
||||
value = strings.ReplaceAll(value, "\n", ";")
|
||||
value = strings.ReplaceAll(value, ",", ";")
|
||||
parts := strings.Split(value, ";")
|
||||
domains := make([]string, 0, len(parts))
|
||||
seen := make(map[string]bool, len(parts))
|
||||
for _, part := range parts {
|
||||
domainName := strings.ToLower(strings.TrimSpace(part))
|
||||
if domainName == "" || seen[domainName] {
|
||||
continue
|
||||
}
|
||||
seen[domainName] = true
|
||||
domains = append(domains, domainName)
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
func orderSeedTenantConfigsByParentSlug(configs []InitialTenantConfig) []InitialTenantConfig {
|
||||
bySlug := make(map[string]InitialTenantConfig, len(configs))
|
||||
for _, config := range configs {
|
||||
bySlug[strings.ToLower(config.Slug)] = config
|
||||
}
|
||||
|
||||
ordered := make([]InitialTenantConfig, 0, len(configs))
|
||||
visited := make(map[string]bool, len(configs))
|
||||
var visit func(config InitialTenantConfig)
|
||||
visit = func(config InitialTenantConfig) {
|
||||
key := strings.ToLower(config.Slug)
|
||||
if visited[key] {
|
||||
return
|
||||
}
|
||||
if config.ParentSlug != "" {
|
||||
if parent, ok := bySlug[strings.ToLower(config.ParentSlug)]; ok {
|
||||
visit(parent)
|
||||
}
|
||||
}
|
||||
visited[key] = true
|
||||
ordered = append(ordered, config)
|
||||
}
|
||||
|
||||
for _, config := range configs {
|
||||
visit(config)
|
||||
}
|
||||
return ordered
|
||||
}
|
||||
|
||||
func createSeedTenant(
|
||||
ctx context.Context,
|
||||
repo repository.TenantRepository,
|
||||
outboxRepo repository.KetoOutboxRepository,
|
||||
config InitialTenantConfig,
|
||||
tenantType string,
|
||||
parentID *string,
|
||||
) (*domain.Tenant, error) {
|
||||
tenant := &domain.Tenant{
|
||||
ID: config.TenantID,
|
||||
Type: tenantType,
|
||||
Name: config.Name,
|
||||
Slug: config.Slug,
|
||||
Description: config.Description,
|
||||
Status: domain.TenantStatusActive,
|
||||
ParentID: parentID,
|
||||
}
|
||||
|
||||
if err := repo.Create(ctx, tenant); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: tenant.ID,
|
||||
Relation: "admins",
|
||||
Subject: "System:global#super_admins",
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tenant.ParentID != nil {
|
||||
if err := outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: tenant.ID,
|
||||
Relation: "parents",
|
||||
Subject: "Tenant:" + *tenant.ParentID,
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
for _, domainName := range config.Domains {
|
||||
if err := repo.AddDomain(ctx, tenant.ID, domainName, true); err != nil {
|
||||
slog.Error("Failed to add domain to seeded tenant", "tenant", config.Slug, "domain", domainName, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return repo.FindBySlug(ctx, config.Slug)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user