1
0
forked from baron/baron-sso

Implement tenant import and RP auto login policies

This commit is contained in:
2026-04-30 15:45:34 +09:00
parent 24807eab0f
commit f7e4d43b16
76 changed files with 5307 additions and 441 deletions

View File

@@ -29,6 +29,10 @@ func Run(db *gorm.DB) error {
func migrateSchemas(db *gorm.DB) error {
slog.Info("[Bootstrap] Migrating database schemas...")
if err := dropLegacyTenantDomainUniqueIndex(db); err != nil {
return err
}
// Add all domain models here
return db.AutoMigrate(
&domain.Tenant{},
@@ -43,6 +47,20 @@ func migrateSchemas(db *gorm.DB) error {
&domain.KetoOutbox{},
&domain.SharedLink{},
&domain.DeveloperRequest{},
&domain.RPUserMetadata{},
// &domain.RelyingParty{}, // Removed: SSOT is Hydra + Keto
)
}
func dropLegacyTenantDomainUniqueIndex(db *gorm.DB) error {
if !db.Migrator().HasTable(&domain.TenantDomain{}) {
return nil
}
if !db.Migrator().HasIndex(&domain.TenantDomain{}, "idx_tenant_domains_domain") {
return nil
}
if err := db.Migrator().DropIndex(&domain.TenantDomain{}, "idx_tenant_domains_domain"); err != nil {
return fmt.Errorf("failed to drop legacy tenant domain unique index: %w", err)
}
return nil
}

View File

@@ -4,15 +4,33 @@ import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"baron-sso-backend/internal/service"
"baron-sso-backend/internal/utils"
"bytes"
"context"
"encoding/csv"
"errors"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"strings"
"gorm.io/gorm"
)
const seedTenantCSVPathEnv = "SEED_TENANT_CSV_PATH"
var seedTenantCSVPathCandidates = []string{
"adminfront/seed-tenant.csv",
"../adminfront/seed-tenant.csv",
"../../adminfront/seed-tenant.csv",
"../../../adminfront/seed-tenant.csv",
"/app/adminfront/seed-tenant.csv",
}
type InitialTenantConfig struct {
TenantID string
Name string
Slug string
Type string
@@ -21,32 +39,31 @@ type InitialTenantConfig struct {
Domains []string
}
// Hardcoded for now, can be moved to config file or env later
var defaultTenants = []InitialTenantConfig{
{
Name: "한맥가족",
Slug: "hanmac-family",
Type: domain.TenantTypeCompanyGroup,
},
{
Name: "한맥기술",
Slug: "hanmac",
Type: domain.TenantTypeCompany,
ParentSlug: "hanmac-family",
Description: "Primary Family Company",
Domains: []string{"hanmaceng.co.kr", "hmac.kr"},
},
{
Name: "삼안",
Slug: "saman",
Type: domain.TenantTypeCompany,
ParentSlug: "hanmac-family",
Domains: []string{"samaneng.com"},
},
func SeedTenants(db *gorm.DB) error {
slog.Info("[Bootstrap] Checking initial tenant seed...")
var tenantCount int64
if err := db.Model(&domain.Tenant{}).Count(&tenantCount).Error; err != nil {
return fmt.Errorf("count tenants before seed: %w", err)
}
if tenantCount > 0 {
slog.Info("[Bootstrap] Tenant seed skipped because tenants already exist", "count", tenantCount)
return nil
}
configs, err := loadSeedTenantConfigs()
if err != nil {
return err
}
if len(configs) == 0 {
return errors.New("seed tenant csv has no tenant rows")
}
return seedTenantConfigs(db, configs)
}
func SeedTenants(db *gorm.DB) error {
slog.Info("[Bootstrap] Seeding initial tenants...")
func seedTenantConfigs(db *gorm.DB, configs []InitialTenantConfig) error {
slog.Info("[Bootstrap] Seeding initial tenants from CSV...", "count", len(configs))
repo := repository.NewTenantRepository(db)
userRepo := repository.NewUserRepository(db)
userGroupRepo := repository.NewUserGroupRepository(db)
@@ -54,7 +71,7 @@ func SeedTenants(db *gorm.DB) error {
svc := service.NewTenantService(repo, userRepo, userGroupRepo, outboxRepo)
ctx := context.Background()
for _, config := range defaultTenants {
for _, config := range orderSeedTenantConfigsByParentSlug(configs) {
tenantType := config.Type
if tenantType == "" {
tenantType = domain.TenantTypeCompany
@@ -73,75 +90,273 @@ func SeedTenants(db *gorm.DB) error {
parentID = &parent.ID
}
existing, err := repo.FindBySlug(ctx, config.Slug)
if err == nil && existing != nil {
slog.Info("[Bootstrap] Tenant already exists, checking domains...", "slug", config.Slug)
changed := false
if existing.Name != config.Name {
existing.Name = config.Name
changed = true
}
if existing.Type != tenantType {
existing.Type = tenantType
changed = true
}
if existing.Status != domain.TenantStatusActive {
existing.Status = domain.TenantStatusActive
changed = true
}
if config.ParentSlug != "" {
if existing.ParentID == nil || *existing.ParentID != *parentID {
existing.ParentID = parentID
changed = true
if err := outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: existing.ID,
Relation: "parents",
Subject: "Tenant:" + *parentID,
Action: domain.KetoOutboxActionCreate,
}); err != nil {
slog.Error("Failed to create outbox entry for seeded tenant hierarchy", "tenant", existing.ID, "error", err)
return err
}
}
} else if existing.ParentID != nil {
existing.ParentID = nil
changed = true
}
if changed {
if err := repo.Update(ctx, existing); err != nil {
slog.Error("Failed to update seeded tenant", "slug", config.Slug, "error", err)
return err
}
}
// Optional: Check and add missing domains
for _, d := range config.Domains {
found := false
for _, ed := range existing.Domains {
if ed.Domain == d {
found = true
break
}
}
if !found {
slog.Info("[Bootstrap] Adding missing domain to tenant", "slug", config.Slug, "domain", d)
if err := repo.AddDomain(ctx, existing.ID, d, true); err != nil {
slog.Error("Failed to add domain", "error", err)
}
}
}
continue
slog.Info("[Bootstrap] Creating seed tenant", "name", config.Name, "slug", config.Slug)
var tenant *domain.Tenant
var err error
if config.TenantID != "" {
tenant, err = createSeedTenant(ctx, repo, outboxRepo, config, tenantType, parentID)
} else {
tenant, err = svc.RegisterTenant(ctx, config.Name, config.Slug, tenantType, config.Description, config.Domains, parentID, "")
}
slog.Info("[Bootstrap] Creating default tenant", "name", config.Name, "slug", config.Slug)
tenant, err := svc.RegisterTenant(ctx, config.Name, config.Slug, tenantType, config.Description, config.Domains, parentID, "")
if err != nil {
slog.Error("Failed to seed tenant", "slug", config.Slug, "error", err)
return err
}
// Explicitly set to active during seed
tenant.Status = domain.TenantStatusActive
db.Save(tenant)
if err := db.Save(tenant).Error; err != nil {
return err
}
}
return nil
}
func loadSeedTenantConfigs() ([]InitialTenantConfig, error) {
path, err := findSeedTenantCSVPath()
if err != nil {
return nil, err
}
file, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open seed tenant csv %q: %w", path, err)
}
defer file.Close()
configs, err := parseSeedTenantCSV(file)
if err != nil {
return nil, fmt.Errorf("parse seed tenant csv %q: %w", path, err)
}
return configs, nil
}
func findSeedTenantCSVPath() (string, error) {
if configured := strings.TrimSpace(os.Getenv(seedTenantCSVPathEnv)); configured != "" {
return configured, nil
}
for _, candidate := range seedTenantCSVPathCandidates {
cleaned := filepath.Clean(candidate)
if _, err := os.Stat(cleaned); err == nil {
return cleaned, nil
}
}
return "", fmt.Errorf("seed tenant csv not found; set %s or add adminfront/seed-tenant.csv", seedTenantCSVPathEnv)
}
func parseSeedTenantCSV(r io.Reader) ([]InitialTenantConfig, error) {
data, err := io.ReadAll(r)
if err != nil {
return nil, errors.New("failed to read csv")
}
data = bytes.TrimPrefix(data, []byte{0xEF, 0xBB, 0xBF})
reader := csv.NewReader(bytes.NewReader(data))
reader.FieldsPerRecord = -1
rows, err := reader.ReadAll()
if err != nil {
return nil, fmt.Errorf("invalid csv: %w", err)
}
if len(rows) == 0 {
return nil, errors.New("csv is empty")
}
header := seedTenantCSVHeaderIndex(rows[0])
for _, key := range []string{"name", "type", "slug"} {
if _, ok := header[key]; !ok {
return nil, fmt.Errorf("missing required column: %s", key)
}
}
configs := make([]InitialTenantConfig, 0, len(rows)-1)
for i, row := range rows[1:] {
if seedTenantCSVRowIsEmpty(row) {
continue
}
name := seedTenantCSVValue(row, header, "name")
if name == "" {
return nil, fmt.Errorf("row %d: name is required", i+2)
}
tenantType := normalizeSeedTenantType(seedTenantCSVValue(row, header, "type"))
if tenantType == "" {
return nil, fmt.Errorf("row %d: invalid tenant type", i+2)
}
slug := utils.GenerateSlug(seedTenantCSVValue(row, header, "slug"))
if slug == "" {
return nil, fmt.Errorf("row %d: slug is required", i+2)
}
configs = append(configs, InitialTenantConfig{
TenantID: seedTenantCSVValue(row, header, "tenant_id"),
Name: name,
Type: tenantType,
ParentSlug: seedTenantCSVValue(row, header, "parent_tenant_slug"),
Slug: slug,
Description: seedTenantCSVValue(row, header, "memo"),
Domains: splitSeedTenantCSVDomains(seedTenantCSVValue(row, header, "email_domain")),
})
}
return configs, nil
}
func seedTenantCSVHeaderIndex(header []string) map[string]int {
index := make(map[string]int, len(header))
aliases := map[string]string{
"id": "tenant_id",
"tenantid": "tenant_id",
"tenant_id": "tenant_id",
"name": "name",
"type": "type",
"parenttenantslug": "parent_tenant_slug",
"parent_tenant_slug": "parent_tenant_slug",
"parent_slug": "parent_tenant_slug",
"slug": "slug",
"memo": "memo",
"description": "memo",
"email-domain": "email_domain",
"emaildomain": "email_domain",
"email_domain": "email_domain",
"domain": "email_domain",
"domains": "email_domain",
}
for i, column := range header {
key := strings.ToLower(strings.TrimSpace(column))
key = strings.ReplaceAll(key, " ", "_")
if canonical, ok := aliases[key]; ok {
index[canonical] = i
}
}
return index
}
func seedTenantCSVValue(row []string, header map[string]int, key string) string {
idx, ok := header[key]
if !ok || idx >= len(row) {
return ""
}
return strings.TrimSpace(row[idx])
}
func seedTenantCSVRowIsEmpty(row []string) bool {
for _, value := range row {
if strings.TrimSpace(value) != "" {
return false
}
}
return true
}
func normalizeSeedTenantType(value string) string {
switch strings.ToUpper(strings.TrimSpace(value)) {
case domain.TenantTypePersonal:
return domain.TenantTypePersonal
case domain.TenantTypeCompany:
return domain.TenantTypeCompany
case domain.TenantTypeCompanyGroup:
return domain.TenantTypeCompanyGroup
case domain.TenantTypeUserGroup:
return domain.TenantTypeUserGroup
default:
return ""
}
}
func splitSeedTenantCSVDomains(value string) []string {
value = strings.ReplaceAll(value, "\n", ";")
value = strings.ReplaceAll(value, ",", ";")
parts := strings.Split(value, ";")
domains := make([]string, 0, len(parts))
seen := make(map[string]bool, len(parts))
for _, part := range parts {
domainName := strings.ToLower(strings.TrimSpace(part))
if domainName == "" || seen[domainName] {
continue
}
seen[domainName] = true
domains = append(domains, domainName)
}
return domains
}
func orderSeedTenantConfigsByParentSlug(configs []InitialTenantConfig) []InitialTenantConfig {
bySlug := make(map[string]InitialTenantConfig, len(configs))
for _, config := range configs {
bySlug[strings.ToLower(config.Slug)] = config
}
ordered := make([]InitialTenantConfig, 0, len(configs))
visited := make(map[string]bool, len(configs))
var visit func(config InitialTenantConfig)
visit = func(config InitialTenantConfig) {
key := strings.ToLower(config.Slug)
if visited[key] {
return
}
if config.ParentSlug != "" {
if parent, ok := bySlug[strings.ToLower(config.ParentSlug)]; ok {
visit(parent)
}
}
visited[key] = true
ordered = append(ordered, config)
}
for _, config := range configs {
visit(config)
}
return ordered
}
func createSeedTenant(
ctx context.Context,
repo repository.TenantRepository,
outboxRepo repository.KetoOutboxRepository,
config InitialTenantConfig,
tenantType string,
parentID *string,
) (*domain.Tenant, error) {
tenant := &domain.Tenant{
ID: config.TenantID,
Type: tenantType,
Name: config.Name,
Slug: config.Slug,
Description: config.Description,
Status: domain.TenantStatusActive,
ParentID: parentID,
}
if err := repo.Create(ctx, tenant); err != nil {
return nil, err
}
if err := outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: tenant.ID,
Relation: "admins",
Subject: "System:global#super_admins",
Action: domain.KetoOutboxActionCreate,
}); err != nil {
return nil, err
}
if tenant.ParentID != nil {
if err := outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: tenant.ID,
Relation: "parents",
Subject: "Tenant:" + *tenant.ParentID,
Action: domain.KetoOutboxActionCreate,
}); err != nil {
return nil, err
}
}
for _, domainName := range config.Domains {
if err := repo.AddDomain(ctx, tenant.ID, domainName, true); err != nil {
slog.Error("Failed to add domain to seeded tenant", "tenant", config.Slug, "domain", domainName, "error", err)
}
}
return repo.FindBySlug(ctx, config.Slug)
}

View File

@@ -2,17 +2,21 @@ package bootstrap
import (
"baron-sso-backend/internal/domain"
"reflect"
"os"
"path/filepath"
"testing"
)
func TestDefaultTenantsSeedOrderAndHierarchy(t *testing.T) {
func TestSeedTenantCSVDefinesOnlyRequiredRootTenants(t *testing.T) {
configs, err := loadSeedTenantConfigs()
if err != nil {
t.Fatalf("loadSeedTenantConfigs returned error: %v", err)
}
expected := []struct {
name string
slug string
tenantType string
parentSlug string
domains []string
}{
{
name: "한맥가족",
@@ -20,55 +24,61 @@ func TestDefaultTenantsSeedOrderAndHierarchy(t *testing.T) {
tenantType: domain.TenantTypeCompanyGroup,
},
{
name: "한맥기술",
slug: "hanmac",
tenantType: domain.TenantTypeCompany,
parentSlug: "hanmac-family",
domains: []string{"hanmaceng.co.kr", "hmac.kr"},
},
{
name: "삼안",
slug: "saman",
tenantType: domain.TenantTypeCompany,
parentSlug: "hanmac-family",
domains: []string{"samaneng.com"},
name: "Personal",
slug: "personal",
tenantType: domain.TenantTypePersonal,
},
}
if len(defaultTenants) != len(expected) {
t.Fatalf("expected %d default tenants, got %d", len(expected), len(defaultTenants))
if len(configs) != len(expected) {
t.Fatalf("expected %d seed tenants, got %d", len(expected), len(configs))
}
for i, want := range expected {
got := defaultTenants[i]
got := configs[i]
if got.Name != want.name {
t.Fatalf("tenant[%d] name = %q, want %q", i, got.Name, want.name)
}
if got.Slug != want.slug {
t.Fatalf("tenant[%d] slug = %q, want %q", i, got.Slug, want.slug)
}
if tenantType := stringField(t, got, "Type"); tenantType != want.tenantType {
t.Fatalf("tenant[%d] type = %q, want %q", i, tenantType, want.tenantType)
if got.Type != want.tenantType {
t.Fatalf("tenant[%d] type = %q, want %q", i, got.Type, want.tenantType)
}
if parentSlug := stringField(t, got, "ParentSlug"); parentSlug != want.parentSlug {
t.Fatalf("tenant[%d] parent slug = %q, want %q", i, parentSlug, want.parentSlug)
if got.ParentSlug != "" {
t.Fatalf("tenant[%d] parent slug = %q, want empty root tenant", i, got.ParentSlug)
}
if !reflect.DeepEqual(got.Domains, want.domains) {
t.Fatalf("tenant[%d] domains = %#v, want %#v", i, got.Domains, want.domains)
}
for _, tenant := range configs {
if tenant.Slug == "system" || tenant.Slug == "hanmac" || tenant.Slug == "saman" {
t.Fatalf("tenant %q must be configured by import, not seed CSV", tenant.Slug)
}
}
}
func stringField(t *testing.T, target InitialTenantConfig, name string) string {
t.Helper()
func TestLoadSeedTenantConfigsUsesConfiguredCSVPath(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "seed-tenant.csv")
csv := "name,type,parent_tenant_slug,slug,memo,email_domain\n" +
"Root,COMPANY_GROUP,,root,Root memo,\n" +
"Child,COMPANY,root,child,Child memo,child.example.com\n"
if err := os.WriteFile(path, []byte(csv), 0o600); err != nil {
t.Fatalf("failed to write seed csv: %v", err)
}
t.Setenv(seedTenantCSVPathEnv, path)
value := reflect.ValueOf(target)
field := value.FieldByName(name)
if !field.IsValid() {
t.Fatalf("InitialTenantConfig.%s is required", name)
configs, err := loadSeedTenantConfigs()
if err != nil {
t.Fatalf("loadSeedTenantConfigs returned error: %v", err)
}
if field.Kind() != reflect.String {
t.Fatalf("InitialTenantConfig.%s must be a string", name)
if len(configs) != 2 {
t.Fatalf("expected 2 configs, got %d", len(configs))
}
if configs[1].ParentSlug != "root" {
t.Fatalf("child parent slug = %q, want root", configs[1].ParentSlug)
}
if len(configs[1].Domains) != 1 || configs[1].Domains[0] != "child.example.com" {
t.Fatalf("child domains = %#v, want child.example.com", configs[1].Domains)
}
return field.String()
}