1
0
forked from baron/baron-sso

Implement tenant import and RP auto login policies

This commit is contained in:
2026-04-30 15:45:34 +09:00
parent 24807eab0f
commit f7e4d43b16
76 changed files with 5307 additions and 441 deletions

View File

@@ -6,6 +6,7 @@ import (
"baron-sso-backend/internal/service"
"baron-sso-backend/internal/utils"
"bytes"
"context"
"encoding/csv"
"errors"
"fmt"
@@ -68,14 +69,22 @@ type tenantImportResult struct {
Errors []string `json:"errors"`
}
type tenantDomainConflict struct {
Domain string `json:"domain"`
TenantID string `json:"tenantId"`
TenantName string `json:"tenantName"`
TenantSlug string `json:"tenantSlug"`
}
type tenantCSVRecord struct {
TenantID string
Name string
Type string
ParentTenantID *string
Slug string
Memo string
Domains []string
TenantID string
Name string
Type string
ParentTenantID *string
ParentTenantSlug string
Slug string
Memo string
Domains []string
}
func (h *TenantHandler) RegisterTenantPublic(c *fiber.Ctx) error {
@@ -258,13 +267,24 @@ func (h *TenantHandler) ExportTenantsCSV(c *fiber.Ctx) error {
var buf bytes.Buffer
writer := csv.NewWriter(&buf)
if err := writer.Write([]string{"tenant_id", "name", "type", "parent_tenant_id", "slug", "memo", "email_domain"}); err != nil {
includeIDs := includeCSVIds(c)
if includeIDs {
if err := writer.Write([]string{"tenant_id", "name", "type", "parent_tenant_id", "parent_tenant_slug", "slug", "memo", "email_domain"}); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
} else if err := writer.Write([]string{"name", "type", "parent_tenant_slug", "slug", "memo", "email_domain"}); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
slugByID := make(map[string]string, len(tenants))
for _, tenant := range tenants {
slugByID[tenant.ID] = tenant.Slug
}
for _, tenant := range tenants {
parentID := ""
parentSlug := ""
if tenant.ParentID != nil {
parentID = *tenant.ParentID
parentSlug = slugByID[parentID]
}
domains := make([]string, 0, len(tenant.Domains))
for _, domainName := range tenant.Domains {
@@ -273,15 +293,27 @@ func (h *TenantHandler) ExportTenantsCSV(c *fiber.Ctx) error {
domains = append(domains, domainName)
}
}
if err := writer.Write([]string{
tenant.ID,
row := []string{
tenant.Name,
tenant.Type,
parentID,
parentSlug,
tenant.Slug,
tenant.Description,
strings.Join(domains, ";"),
}); err != nil {
}
if includeIDs {
row = []string{
tenant.ID,
tenant.Name,
tenant.Type,
parentID,
parentSlug,
tenant.Slug,
tenant.Description,
strings.Join(domains, ";"),
}
}
if err := writer.Write(row); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
}
@@ -305,33 +337,60 @@ func (h *TenantHandler) ImportTenantsCSV(c *fiber.Ctx) error {
if err != nil {
return errorJSON(c, fiber.StatusBadRequest, err.Error())
}
records = orderTenantCSVRecordsByParentSlug(records)
creatorID := ""
if profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse); ok && profile != nil {
creatorID = profile.ID
}
tenantIDBySlug := make(map[string]string)
if h.Service != nil {
if tenants, _, err := h.Service.ListTenants(c.Context(), 10000, 0, ""); err == nil {
for _, tenant := range tenants {
tenantIDBySlug[strings.ToLower(tenant.Slug)] = tenant.ID
}
}
}
result := tenantImportResult{Errors: make([]string, 0)}
for i, record := range records {
rowNumber := i + 2
if record.ParentTenantID == nil && record.ParentTenantSlug != "" {
parentID := tenantIDBySlug[strings.ToLower(record.ParentTenantSlug)]
if parentID == "" {
result.Failed++
result.Errors = append(result.Errors, fmt.Sprintf("row %d: parent tenant slug not found: %s", rowNumber, record.ParentTenantSlug))
continue
}
record.ParentTenantID = &parentID
}
if record.TenantID != "" || (h.DB != nil && record.Slug != "") {
updated, err := h.upsertTenantCSVRecord(c, record)
tenant, updated, err := h.upsertTenantCSVRecord(c, record)
if err != nil {
result.Failed++
result.Errors = append(result.Errors, fmt.Sprintf("row %d: %s", rowNumber, err.Error()))
continue
}
if updated {
tenantIDBySlug[strings.ToLower(record.Slug)] = tenant.ID
result.Updated++
continue
}
}
if err := h.createTenantCSVRecord(c, record, creatorID); err != nil {
tenant, err := h.createTenantCSVRecord(c, record, creatorID)
if err != nil {
result.Failed++
result.Errors = append(result.Errors, fmt.Sprintf("row %d: %s", rowNumber, err.Error()))
continue
}
if tenant == nil {
result.Failed++
result.Errors = append(result.Errors, fmt.Sprintf("row %d: tenant creation returned empty result", rowNumber))
continue
}
tenantIDBySlug[strings.ToLower(record.Slug)] = tenant.ID
result.Created++
}
@@ -414,13 +473,14 @@ func parseTenantCSVRecords(r io.Reader) ([]tenantCSVRecord, error) {
}
records = append(records, tenantCSVRecord{
TenantID: tenantCSVValue(row, header, "tenant_id"),
Name: name,
Type: tenantType,
ParentTenantID: parentID,
Slug: slug,
Memo: tenantCSVValue(row, header, "memo"),
Domains: splitTenantCSVDomains(tenantCSVValue(row, header, "email_domain")),
TenantID: tenantCSVValue(row, header, "tenant_id"),
Name: name,
Type: tenantType,
ParentTenantID: parentID,
ParentTenantSlug: tenantCSVValue(row, header, "parent_tenant_slug"),
Slug: slug,
Memo: tenantCSVValue(row, header, "memo"),
Domains: splitTenantCSVDomains(tenantCSVValue(row, header, "email_domain")),
})
}
@@ -430,23 +490,25 @@ func parseTenantCSVRecords(r io.Reader) ([]tenantCSVRecord, error) {
func tenantCSVHeaderIndex(header []string) map[string]int {
index := make(map[string]int, len(header))
aliases := map[string]string{
"id": "tenant_id",
"tenantid": "tenant_id",
"tenant_id": "tenant_id",
"name": "name",
"type": "type",
"parentid": "parent_tenant_id",
"parent_id": "parent_tenant_id",
"parenttenantid": "parent_tenant_id",
"parent_tenant_id": "parent_tenant_id",
"slug": "slug",
"memo": "memo",
"description": "memo",
"email-domain": "email_domain",
"emaildomain": "email_domain",
"email_domain": "email_domain",
"domain": "email_domain",
"domains": "email_domain",
"id": "tenant_id",
"tenantid": "tenant_id",
"tenant_id": "tenant_id",
"name": "name",
"type": "type",
"parentid": "parent_tenant_id",
"parent_id": "parent_tenant_id",
"parenttenantid": "parent_tenant_id",
"parent_tenant_id": "parent_tenant_id",
"parenttenantslug": "parent_tenant_slug",
"parent_tenant_slug": "parent_tenant_slug",
"slug": "slug",
"memo": "memo",
"description": "memo",
"email-domain": "email_domain",
"emaildomain": "email_domain",
"email_domain": "email_domain",
"domain": "email_domain",
"domains": "email_domain",
}
for i, column := range header {
key := strings.ToLower(strings.TrimSpace(column))
@@ -475,6 +537,40 @@ func tenantCSVRowIsEmpty(row []string) bool {
return true
}
func includeCSVIds(c *fiber.Ctx) bool {
value := strings.ToLower(strings.TrimSpace(c.Query("includeIds")))
return value == "true" || value == "1" || value == "yes"
}
func orderTenantCSVRecordsByParentSlug(records []tenantCSVRecord) []tenantCSVRecord {
bySlug := make(map[string]tenantCSVRecord, len(records))
for _, record := range records {
bySlug[strings.ToLower(record.Slug)] = record
}
ordered := make([]tenantCSVRecord, 0, len(records))
visited := make(map[string]bool, len(records))
var visit func(record tenantCSVRecord)
visit = func(record tenantCSVRecord) {
key := strings.ToLower(record.Slug)
if visited[key] {
return
}
if record.ParentTenantSlug != "" {
if parent, ok := bySlug[strings.ToLower(record.ParentTenantSlug)]; ok {
visit(parent)
}
}
visited[key] = true
ordered = append(ordered, record)
}
for _, record := range records {
visit(record)
}
return ordered
}
func splitTenantCSVDomains(value string) []string {
value = strings.ReplaceAll(value, "\n", ";")
value = strings.ReplaceAll(value, ",", ";")
@@ -492,12 +588,203 @@ func splitTenantCSVDomains(value string) []string {
return domains
}
func (h *TenantHandler) upsertTenantCSVRecord(c *fiber.Ctx, record tenantCSVRecord) (bool, error) {
func normalizeTenantDomainInputs(values []string) []string {
seen := make(map[string]bool, len(values))
domains := make([]string, 0, len(values))
for _, value := range values {
for _, part := range strings.FieldsFunc(value, func(r rune) bool {
return r == ',' || r == ';' || r == '\n' || r == '\r' || r == '\t' || r == ' '
}) {
domainName := strings.ToLower(strings.TrimSpace(part))
if domainName == "" || seen[domainName] {
continue
}
seen[domainName] = true
domains = append(domains, domainName)
}
}
return domains
}
func normalizeTenantConfig(config map[string]any) (domain.JSONMap, error) {
normalized := make(domain.JSONMap, len(config))
for key, value := range config {
if key == "userSchema" {
fields, err := normalizeTenantUserSchema(value)
if err != nil {
return nil, err
}
normalized[key] = fields
continue
}
normalized[key] = value
}
return normalized, nil
}
func normalizeTenantUserSchema(value any) ([]any, error) {
if value == nil {
return nil, nil
}
rawFields, ok := value.([]any)
if !ok {
return nil, fmt.Errorf("userSchema must be an array")
}
fields := make([]any, 0, len(rawFields))
for _, raw := range rawFields {
field, ok := raw.(map[string]any)
if !ok {
return nil, fmt.Errorf("userSchema fields must be objects")
}
normalized := make(map[string]any, len(field))
for key, value := range field {
if key == "maxLength" {
continue
}
normalized[key] = value
}
isLoginID, _ := normalized["isLoginId"].(bool)
if isLoginID {
fieldType, _ := normalized["type"].(string)
if fieldType != "" && fieldType != "text" {
return nil, fmt.Errorf("login ID fields must be text")
}
normalized["type"] = "text"
normalized["indexed"] = true
} else if indexed, ok := normalized["indexed"].(bool); !ok || !indexed {
normalized["indexed"] = false
}
fields = append(fields, normalized)
}
return fields, nil
}
func normalizeTenantDomainForceSet(values []string) map[string]bool {
domains := normalizeTenantDomainInputs(values)
force := make(map[string]bool, len(domains))
for _, domainName := range domains {
force[domainName] = true
}
return force
}
func tenantDomainConflictJSON(c *fiber.Ctx, conflicts []tenantDomainConflict) error {
return c.Status(fiber.StatusConflict).JSON(fiber.Map{
"code": "tenant_domain_conflict",
"error": "domain is already assigned to another tenant",
"conflicts": conflicts,
})
}
func (h *TenantHandler) findTenantDomainConflicts(ctx context.Context, tenantID string, domains []string, forceDomains []string) ([]tenantDomainConflict, error) {
if h.DB == nil || h.DB.Config == nil || len(domains) == 0 {
return nil, nil
}
force := normalizeTenantDomainForceSet(forceDomains)
var rows []domain.TenantDomain
query := h.DB.WithContext(ctx).Where("domain IN ?", domains)
if tenantID != "" {
query = query.Where("tenant_id <> ?", tenantID)
}
if err := query.Find(&rows).Error; err != nil {
return nil, err
}
conflicts := make([]tenantDomainConflict, 0, len(rows))
tenantIDs := make([]string, 0, len(rows))
seenTenantIDs := make(map[string]bool, len(rows))
for _, row := range rows {
if force[row.Domain] {
continue
}
if !seenTenantIDs[row.TenantID] {
seenTenantIDs[row.TenantID] = true
tenantIDs = append(tenantIDs, row.TenantID)
}
}
tenantsByID := make(map[string]domain.Tenant, len(tenantIDs))
if len(tenantIDs) > 0 {
var tenants []domain.Tenant
if err := h.DB.WithContext(ctx).Where("id IN ?", tenantIDs).Find(&tenants).Error; err != nil {
return nil, err
}
for _, tenant := range tenants {
tenantsByID[tenant.ID] = tenant
}
}
for _, row := range rows {
if force[row.Domain] {
continue
}
conflict := tenantDomainConflict{
Domain: row.Domain,
TenantID: row.TenantID,
}
if tenant, ok := tenantsByID[row.TenantID]; ok {
conflict.TenantName = tenant.Name
conflict.TenantSlug = tenant.Slug
}
conflicts = append(conflicts, conflict)
}
return conflicts, nil
}
func (h *TenantHandler) replaceTenantDomains(ctx context.Context, tenantID string, domains []string, forceDomains []string) error {
if h.DB == nil {
return errors.New("database not available")
}
if h.DB.Config == nil {
return nil
}
deleteQuery := h.DB.WithContext(ctx).Where("tenant_id = ?", tenantID)
if len(domains) > 0 {
deleteQuery = deleteQuery.Where("domain NOT IN ?", domains)
}
if err := deleteQuery.Delete(&domain.TenantDomain{}).Error; err != nil {
return fmt.Errorf("failed to clear old domains: %w", err)
}
for _, domainName := range domains {
var existing domain.TenantDomain
err := h.DB.WithContext(ctx).Unscoped().
Where("tenant_id = ? AND domain = ?", tenantID, domainName).
First(&existing).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
if err := repository.NewTenantRepository(h.DB).AddDomain(ctx, tenantID, domainName, true); err != nil {
return fmt.Errorf("failed to add domain: %s", domainName)
}
continue
}
if err != nil {
return err
}
if err := h.DB.WithContext(ctx).Unscoped().Model(&existing).Updates(map[string]any{
"verified": true,
"deleted_at": nil,
}).Error; err != nil {
return fmt.Errorf("failed to add domain: %s", domainName)
}
}
return nil
}
func (h *TenantHandler) upsertTenantCSVRecord(c *fiber.Ctx, record tenantCSVRecord) (*domain.Tenant, bool, error) {
if h.DB == nil {
if record.TenantID != "" {
return false, errors.New("database not available for tenant update")
return nil, false, errors.New("database not available for tenant update")
}
return false, nil
return nil, false, nil
}
var tenant domain.Tenant
@@ -510,10 +797,10 @@ func (h *TenantHandler) upsertTenantCSVRecord(c *fiber.Ctx, record tenantCSVReco
}
if errors.Is(err, gorm.ErrRecordNotFound) {
return false, nil
return nil, false, nil
}
if err != nil {
return false, err
return nil, false, err
}
tenant.Name = record.Name
@@ -526,29 +813,29 @@ func (h *TenantHandler) upsertTenantCSVRecord(c *fiber.Ctx, record tenantCSVReco
}
if err := h.DB.Save(&tenant).Error; err != nil {
return false, err
return nil, false, err
}
if err := h.DB.Delete(&domain.TenantDomain{}, "tenant_id = ?", tenant.ID).Error; err != nil {
return false, err
return nil, false, err
}
repo := repository.NewTenantRepository(h.DB)
for _, domainName := range record.Domains {
if err := repo.AddDomain(c.Context(), tenant.ID, domainName, true); err != nil {
return false, err
return nil, false, err
}
}
return true, nil
return &tenant, true, nil
}
func (h *TenantHandler) createTenantCSVRecord(c *fiber.Ctx, record tenantCSVRecord, creatorID string) error {
func (h *TenantHandler) createTenantCSVRecord(c *fiber.Ctx, record tenantCSVRecord, creatorID string) (*domain.Tenant, error) {
if h.DB != nil && record.TenantID != "" {
var exists int64
if err := h.DB.Unscoped().Model(&domain.Tenant{}).Where("slug = ?", record.Slug).Count(&exists).Error; err != nil {
return err
return nil, err
}
if exists > 0 {
return errors.New("tenant slug already exists")
return nil, errors.New("tenant slug already exists")
}
tenant := domain.Tenant{
@@ -561,7 +848,7 @@ func (h *TenantHandler) createTenantCSVRecord(c *fiber.Ctx, record tenantCSVReco
Status: domain.TenantStatusActive,
}
if err := h.DB.Create(&tenant).Error; err != nil {
return err
return nil, err
}
if h.KetoOutbox != nil {
_ = h.KetoOutbox.Create(c.Context(), &domain.KetoOutbox{
@@ -595,14 +882,14 @@ func (h *TenantHandler) createTenantCSVRecord(c *fiber.Ctx, record tenantCSVReco
repo := repository.NewTenantRepository(h.DB)
for _, domainName := range record.Domains {
if err := repo.AddDomain(c.Context(), tenant.ID, domainName, true); err != nil {
return err
return nil, err
}
}
return nil
return &tenant, nil
}
_, err := h.Service.RegisterTenant(c.Context(), record.Name, record.Slug, record.Type, record.Memo, record.Domains, record.ParentTenantID, creatorID)
return err
tenant, err := h.Service.RegisterTenant(c.Context(), record.Name, record.Slug, record.Type, record.Memo, record.Domains, record.ParentTenantID, creatorID)
return tenant, err
}
func (h *TenantHandler) GetTenant(c *fiber.Ctx) error {
@@ -646,14 +933,15 @@ func (h *TenantHandler) CreateTenant(c *fiber.Ctx) error {
}
var req struct {
Name string `json:"name"`
Slug string `json:"slug"`
Type string `json:"type"`
Description string `json:"description"`
Status string `json:"status"`
Domains []string `json:"domains"`
ParentID *string `json:"parentId"`
Config map[string]any `json:"config"`
Name string `json:"name"`
Slug string `json:"slug"`
Type string `json:"type"`
Description string `json:"description"`
Status string `json:"status"`
Domains []string `json:"domains"`
ForceDomains []string `json:"forceDomainConflicts"`
ParentID *string `json:"parentId"`
Config map[string]any `json:"config"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
@@ -701,7 +989,16 @@ func (h *TenantHandler) CreateTenant(c *fiber.Ctx) error {
creatorID = profile.ID
}
tenant, err := h.Service.RegisterTenant(c.Context(), name, slug, tenantType, req.Description, req.Domains, parentID, creatorID)
normalizedDomains := normalizeTenantDomainInputs(req.Domains)
conflicts, err := h.findTenantDomainConflicts(c.Context(), "", normalizedDomains, req.ForceDomains)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
if len(conflicts) > 0 {
return tenantDomainConflictJSON(c, conflicts)
}
tenant, err := h.Service.RegisterTenant(c.Context(), name, slug, tenantType, req.Description, nil, parentID, creatorID)
if err != nil {
if strings.Contains(err.Error(), "already exists") {
return errorJSON(c, fiber.StatusConflict, err.Error())
@@ -713,10 +1010,20 @@ func (h *TenantHandler) CreateTenant(c *fiber.Ctx) error {
summary.MemberCount = 0
if req.Config != nil {
tenant.Config = req.Config
config, err := normalizeTenantConfig(req.Config)
if err != nil {
return errorJSON(c, fiber.StatusBadRequest, err.Error())
}
tenant.Config = config
h.DB.Save(tenant)
summary.Config = tenant.Config
}
if err := h.replaceTenantDomains(c.Context(), tenant.ID, normalizedDomains, req.ForceDomains); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
if len(normalizedDomains) > 0 {
summary.Domains = normalizedDomains
}
return c.Status(fiber.StatusCreated).JSON(summary)
}
@@ -740,14 +1047,15 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
}
var req struct {
Name *string `json:"name"`
Type *string `json:"type"`
Slug *string `json:"slug"`
Description *string `json:"description"`
Status *string `json:"status"`
ParentID *string `json:"parentId"`
Domains []string `json:"domains"`
Config map[string]any `json:"config"`
Name *string `json:"name"`
Type *string `json:"type"`
Slug *string `json:"slug"`
Description *string `json:"description"`
Status *string `json:"status"`
ParentID *string `json:"parentId"`
Domains []string `json:"domains"`
ForceDomains []string `json:"forceDomainConflicts"`
Config map[string]any `json:"config"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
@@ -835,7 +1143,11 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
}
}
if req.Config != nil {
tenant.Config = req.Config
config, err := normalizeTenantConfig(req.Config)
if err != nil {
return errorJSON(c, fiber.StatusBadRequest, err.Error())
}
tenant.Config = config
}
if err := h.DB.Save(&tenant).Error; err != nil {
@@ -844,18 +1156,16 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
// Update domains if provided
if req.Domains != nil {
// Simple approach: Delete existing and recreate
if err := h.DB.Delete(&domain.TenantDomain{}, "tenant_id = ?", tenant.ID).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, "failed to clear old domains")
normalizedDomains := normalizeTenantDomainInputs(req.Domains)
conflicts, err := h.findTenantDomainConflicts(c.Context(), tenant.ID, normalizedDomains, req.ForceDomains)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
for _, d := range req.Domains {
if strings.TrimSpace(d) == "" {
continue
}
// Use repository for consistency
if err := repository.NewTenantRepository(h.DB).AddDomain(c.Context(), tenant.ID, strings.TrimSpace(d), true); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, "failed to add domain: "+d)
}
if len(conflicts) > 0 {
return tenantDomainConflictJSON(c, conflicts)
}
if err := h.replaceTenantDomains(c.Context(), tenant.ID, normalizedDomains, req.ForceDomains); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
}