forked from baron/baron-sso
Implement tenant import and RP auto login policies
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"baron-sso-backend/internal/service"
|
||||
"baron-sso-backend/internal/utils"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/csv"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -68,14 +69,22 @@ type tenantImportResult struct {
|
||||
Errors []string `json:"errors"`
|
||||
}
|
||||
|
||||
type tenantDomainConflict struct {
|
||||
Domain string `json:"domain"`
|
||||
TenantID string `json:"tenantId"`
|
||||
TenantName string `json:"tenantName"`
|
||||
TenantSlug string `json:"tenantSlug"`
|
||||
}
|
||||
|
||||
type tenantCSVRecord struct {
|
||||
TenantID string
|
||||
Name string
|
||||
Type string
|
||||
ParentTenantID *string
|
||||
Slug string
|
||||
Memo string
|
||||
Domains []string
|
||||
TenantID string
|
||||
Name string
|
||||
Type string
|
||||
ParentTenantID *string
|
||||
ParentTenantSlug string
|
||||
Slug string
|
||||
Memo string
|
||||
Domains []string
|
||||
}
|
||||
|
||||
func (h *TenantHandler) RegisterTenantPublic(c *fiber.Ctx) error {
|
||||
@@ -258,13 +267,24 @@ func (h *TenantHandler) ExportTenantsCSV(c *fiber.Ctx) error {
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer := csv.NewWriter(&buf)
|
||||
if err := writer.Write([]string{"tenant_id", "name", "type", "parent_tenant_id", "slug", "memo", "email_domain"}); err != nil {
|
||||
includeIDs := includeCSVIds(c)
|
||||
if includeIDs {
|
||||
if err := writer.Write([]string{"tenant_id", "name", "type", "parent_tenant_id", "parent_tenant_slug", "slug", "memo", "email_domain"}); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
} else if err := writer.Write([]string{"name", "type", "parent_tenant_slug", "slug", "memo", "email_domain"}); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
slugByID := make(map[string]string, len(tenants))
|
||||
for _, tenant := range tenants {
|
||||
slugByID[tenant.ID] = tenant.Slug
|
||||
}
|
||||
for _, tenant := range tenants {
|
||||
parentID := ""
|
||||
parentSlug := ""
|
||||
if tenant.ParentID != nil {
|
||||
parentID = *tenant.ParentID
|
||||
parentSlug = slugByID[parentID]
|
||||
}
|
||||
domains := make([]string, 0, len(tenant.Domains))
|
||||
for _, domainName := range tenant.Domains {
|
||||
@@ -273,15 +293,27 @@ func (h *TenantHandler) ExportTenantsCSV(c *fiber.Ctx) error {
|
||||
domains = append(domains, domainName)
|
||||
}
|
||||
}
|
||||
if err := writer.Write([]string{
|
||||
tenant.ID,
|
||||
row := []string{
|
||||
tenant.Name,
|
||||
tenant.Type,
|
||||
parentID,
|
||||
parentSlug,
|
||||
tenant.Slug,
|
||||
tenant.Description,
|
||||
strings.Join(domains, ";"),
|
||||
}); err != nil {
|
||||
}
|
||||
if includeIDs {
|
||||
row = []string{
|
||||
tenant.ID,
|
||||
tenant.Name,
|
||||
tenant.Type,
|
||||
parentID,
|
||||
parentSlug,
|
||||
tenant.Slug,
|
||||
tenant.Description,
|
||||
strings.Join(domains, ";"),
|
||||
}
|
||||
}
|
||||
if err := writer.Write(row); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
}
|
||||
@@ -305,33 +337,60 @@ func (h *TenantHandler) ImportTenantsCSV(c *fiber.Ctx) error {
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, err.Error())
|
||||
}
|
||||
records = orderTenantCSVRecordsByParentSlug(records)
|
||||
|
||||
creatorID := ""
|
||||
if profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse); ok && profile != nil {
|
||||
creatorID = profile.ID
|
||||
}
|
||||
|
||||
tenantIDBySlug := make(map[string]string)
|
||||
if h.Service != nil {
|
||||
if tenants, _, err := h.Service.ListTenants(c.Context(), 10000, 0, ""); err == nil {
|
||||
for _, tenant := range tenants {
|
||||
tenantIDBySlug[strings.ToLower(tenant.Slug)] = tenant.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := tenantImportResult{Errors: make([]string, 0)}
|
||||
for i, record := range records {
|
||||
rowNumber := i + 2
|
||||
if record.ParentTenantID == nil && record.ParentTenantSlug != "" {
|
||||
parentID := tenantIDBySlug[strings.ToLower(record.ParentTenantSlug)]
|
||||
if parentID == "" {
|
||||
result.Failed++
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("row %d: parent tenant slug not found: %s", rowNumber, record.ParentTenantSlug))
|
||||
continue
|
||||
}
|
||||
record.ParentTenantID = &parentID
|
||||
}
|
||||
if record.TenantID != "" || (h.DB != nil && record.Slug != "") {
|
||||
updated, err := h.upsertTenantCSVRecord(c, record)
|
||||
tenant, updated, err := h.upsertTenantCSVRecord(c, record)
|
||||
if err != nil {
|
||||
result.Failed++
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("row %d: %s", rowNumber, err.Error()))
|
||||
continue
|
||||
}
|
||||
if updated {
|
||||
tenantIDBySlug[strings.ToLower(record.Slug)] = tenant.ID
|
||||
result.Updated++
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.createTenantCSVRecord(c, record, creatorID); err != nil {
|
||||
tenant, err := h.createTenantCSVRecord(c, record, creatorID)
|
||||
if err != nil {
|
||||
result.Failed++
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("row %d: %s", rowNumber, err.Error()))
|
||||
continue
|
||||
}
|
||||
if tenant == nil {
|
||||
result.Failed++
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("row %d: tenant creation returned empty result", rowNumber))
|
||||
continue
|
||||
}
|
||||
tenantIDBySlug[strings.ToLower(record.Slug)] = tenant.ID
|
||||
result.Created++
|
||||
}
|
||||
|
||||
@@ -414,13 +473,14 @@ func parseTenantCSVRecords(r io.Reader) ([]tenantCSVRecord, error) {
|
||||
}
|
||||
|
||||
records = append(records, tenantCSVRecord{
|
||||
TenantID: tenantCSVValue(row, header, "tenant_id"),
|
||||
Name: name,
|
||||
Type: tenantType,
|
||||
ParentTenantID: parentID,
|
||||
Slug: slug,
|
||||
Memo: tenantCSVValue(row, header, "memo"),
|
||||
Domains: splitTenantCSVDomains(tenantCSVValue(row, header, "email_domain")),
|
||||
TenantID: tenantCSVValue(row, header, "tenant_id"),
|
||||
Name: name,
|
||||
Type: tenantType,
|
||||
ParentTenantID: parentID,
|
||||
ParentTenantSlug: tenantCSVValue(row, header, "parent_tenant_slug"),
|
||||
Slug: slug,
|
||||
Memo: tenantCSVValue(row, header, "memo"),
|
||||
Domains: splitTenantCSVDomains(tenantCSVValue(row, header, "email_domain")),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -430,23 +490,25 @@ func parseTenantCSVRecords(r io.Reader) ([]tenantCSVRecord, error) {
|
||||
func tenantCSVHeaderIndex(header []string) map[string]int {
|
||||
index := make(map[string]int, len(header))
|
||||
aliases := map[string]string{
|
||||
"id": "tenant_id",
|
||||
"tenantid": "tenant_id",
|
||||
"tenant_id": "tenant_id",
|
||||
"name": "name",
|
||||
"type": "type",
|
||||
"parentid": "parent_tenant_id",
|
||||
"parent_id": "parent_tenant_id",
|
||||
"parenttenantid": "parent_tenant_id",
|
||||
"parent_tenant_id": "parent_tenant_id",
|
||||
"slug": "slug",
|
||||
"memo": "memo",
|
||||
"description": "memo",
|
||||
"email-domain": "email_domain",
|
||||
"emaildomain": "email_domain",
|
||||
"email_domain": "email_domain",
|
||||
"domain": "email_domain",
|
||||
"domains": "email_domain",
|
||||
"id": "tenant_id",
|
||||
"tenantid": "tenant_id",
|
||||
"tenant_id": "tenant_id",
|
||||
"name": "name",
|
||||
"type": "type",
|
||||
"parentid": "parent_tenant_id",
|
||||
"parent_id": "parent_tenant_id",
|
||||
"parenttenantid": "parent_tenant_id",
|
||||
"parent_tenant_id": "parent_tenant_id",
|
||||
"parenttenantslug": "parent_tenant_slug",
|
||||
"parent_tenant_slug": "parent_tenant_slug",
|
||||
"slug": "slug",
|
||||
"memo": "memo",
|
||||
"description": "memo",
|
||||
"email-domain": "email_domain",
|
||||
"emaildomain": "email_domain",
|
||||
"email_domain": "email_domain",
|
||||
"domain": "email_domain",
|
||||
"domains": "email_domain",
|
||||
}
|
||||
for i, column := range header {
|
||||
key := strings.ToLower(strings.TrimSpace(column))
|
||||
@@ -475,6 +537,40 @@ func tenantCSVRowIsEmpty(row []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func includeCSVIds(c *fiber.Ctx) bool {
|
||||
value := strings.ToLower(strings.TrimSpace(c.Query("includeIds")))
|
||||
return value == "true" || value == "1" || value == "yes"
|
||||
}
|
||||
|
||||
func orderTenantCSVRecordsByParentSlug(records []tenantCSVRecord) []tenantCSVRecord {
|
||||
bySlug := make(map[string]tenantCSVRecord, len(records))
|
||||
for _, record := range records {
|
||||
bySlug[strings.ToLower(record.Slug)] = record
|
||||
}
|
||||
|
||||
ordered := make([]tenantCSVRecord, 0, len(records))
|
||||
visited := make(map[string]bool, len(records))
|
||||
var visit func(record tenantCSVRecord)
|
||||
visit = func(record tenantCSVRecord) {
|
||||
key := strings.ToLower(record.Slug)
|
||||
if visited[key] {
|
||||
return
|
||||
}
|
||||
if record.ParentTenantSlug != "" {
|
||||
if parent, ok := bySlug[strings.ToLower(record.ParentTenantSlug)]; ok {
|
||||
visit(parent)
|
||||
}
|
||||
}
|
||||
visited[key] = true
|
||||
ordered = append(ordered, record)
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
visit(record)
|
||||
}
|
||||
return ordered
|
||||
}
|
||||
|
||||
func splitTenantCSVDomains(value string) []string {
|
||||
value = strings.ReplaceAll(value, "\n", ";")
|
||||
value = strings.ReplaceAll(value, ",", ";")
|
||||
@@ -492,12 +588,203 @@ func splitTenantCSVDomains(value string) []string {
|
||||
return domains
|
||||
}
|
||||
|
||||
func (h *TenantHandler) upsertTenantCSVRecord(c *fiber.Ctx, record tenantCSVRecord) (bool, error) {
|
||||
func normalizeTenantDomainInputs(values []string) []string {
|
||||
seen := make(map[string]bool, len(values))
|
||||
domains := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
for _, part := range strings.FieldsFunc(value, func(r rune) bool {
|
||||
return r == ',' || r == ';' || r == '\n' || r == '\r' || r == '\t' || r == ' '
|
||||
}) {
|
||||
domainName := strings.ToLower(strings.TrimSpace(part))
|
||||
if domainName == "" || seen[domainName] {
|
||||
continue
|
||||
}
|
||||
seen[domainName] = true
|
||||
domains = append(domains, domainName)
|
||||
}
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
func normalizeTenantConfig(config map[string]any) (domain.JSONMap, error) {
|
||||
normalized := make(domain.JSONMap, len(config))
|
||||
for key, value := range config {
|
||||
if key == "userSchema" {
|
||||
fields, err := normalizeTenantUserSchema(value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalized[key] = fields
|
||||
continue
|
||||
}
|
||||
normalized[key] = value
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func normalizeTenantUserSchema(value any) ([]any, error) {
|
||||
if value == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rawFields, ok := value.([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("userSchema must be an array")
|
||||
}
|
||||
|
||||
fields := make([]any, 0, len(rawFields))
|
||||
for _, raw := range rawFields {
|
||||
field, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("userSchema fields must be objects")
|
||||
}
|
||||
|
||||
normalized := make(map[string]any, len(field))
|
||||
for key, value := range field {
|
||||
if key == "maxLength" {
|
||||
continue
|
||||
}
|
||||
normalized[key] = value
|
||||
}
|
||||
|
||||
isLoginID, _ := normalized["isLoginId"].(bool)
|
||||
if isLoginID {
|
||||
fieldType, _ := normalized["type"].(string)
|
||||
if fieldType != "" && fieldType != "text" {
|
||||
return nil, fmt.Errorf("login ID fields must be text")
|
||||
}
|
||||
normalized["type"] = "text"
|
||||
normalized["indexed"] = true
|
||||
} else if indexed, ok := normalized["indexed"].(bool); !ok || !indexed {
|
||||
normalized["indexed"] = false
|
||||
}
|
||||
|
||||
fields = append(fields, normalized)
|
||||
}
|
||||
|
||||
return fields, nil
|
||||
}
|
||||
|
||||
func normalizeTenantDomainForceSet(values []string) map[string]bool {
|
||||
domains := normalizeTenantDomainInputs(values)
|
||||
force := make(map[string]bool, len(domains))
|
||||
for _, domainName := range domains {
|
||||
force[domainName] = true
|
||||
}
|
||||
return force
|
||||
}
|
||||
|
||||
func tenantDomainConflictJSON(c *fiber.Ctx, conflicts []tenantDomainConflict) error {
|
||||
return c.Status(fiber.StatusConflict).JSON(fiber.Map{
|
||||
"code": "tenant_domain_conflict",
|
||||
"error": "domain is already assigned to another tenant",
|
||||
"conflicts": conflicts,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *TenantHandler) findTenantDomainConflicts(ctx context.Context, tenantID string, domains []string, forceDomains []string) ([]tenantDomainConflict, error) {
|
||||
if h.DB == nil || h.DB.Config == nil || len(domains) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
force := normalizeTenantDomainForceSet(forceDomains)
|
||||
var rows []domain.TenantDomain
|
||||
query := h.DB.WithContext(ctx).Where("domain IN ?", domains)
|
||||
if tenantID != "" {
|
||||
query = query.Where("tenant_id <> ?", tenantID)
|
||||
}
|
||||
if err := query.Find(&rows).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conflicts := make([]tenantDomainConflict, 0, len(rows))
|
||||
tenantIDs := make([]string, 0, len(rows))
|
||||
seenTenantIDs := make(map[string]bool, len(rows))
|
||||
for _, row := range rows {
|
||||
if force[row.Domain] {
|
||||
continue
|
||||
}
|
||||
if !seenTenantIDs[row.TenantID] {
|
||||
seenTenantIDs[row.TenantID] = true
|
||||
tenantIDs = append(tenantIDs, row.TenantID)
|
||||
}
|
||||
}
|
||||
|
||||
tenantsByID := make(map[string]domain.Tenant, len(tenantIDs))
|
||||
if len(tenantIDs) > 0 {
|
||||
var tenants []domain.Tenant
|
||||
if err := h.DB.WithContext(ctx).Where("id IN ?", tenantIDs).Find(&tenants).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, tenant := range tenants {
|
||||
tenantsByID[tenant.ID] = tenant
|
||||
}
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
if force[row.Domain] {
|
||||
continue
|
||||
}
|
||||
conflict := tenantDomainConflict{
|
||||
Domain: row.Domain,
|
||||
TenantID: row.TenantID,
|
||||
}
|
||||
if tenant, ok := tenantsByID[row.TenantID]; ok {
|
||||
conflict.TenantName = tenant.Name
|
||||
conflict.TenantSlug = tenant.Slug
|
||||
}
|
||||
conflicts = append(conflicts, conflict)
|
||||
}
|
||||
|
||||
return conflicts, nil
|
||||
}
|
||||
|
||||
func (h *TenantHandler) replaceTenantDomains(ctx context.Context, tenantID string, domains []string, forceDomains []string) error {
|
||||
if h.DB == nil {
|
||||
return errors.New("database not available")
|
||||
}
|
||||
if h.DB.Config == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
deleteQuery := h.DB.WithContext(ctx).Where("tenant_id = ?", tenantID)
|
||||
if len(domains) > 0 {
|
||||
deleteQuery = deleteQuery.Where("domain NOT IN ?", domains)
|
||||
}
|
||||
if err := deleteQuery.Delete(&domain.TenantDomain{}).Error; err != nil {
|
||||
return fmt.Errorf("failed to clear old domains: %w", err)
|
||||
}
|
||||
|
||||
for _, domainName := range domains {
|
||||
var existing domain.TenantDomain
|
||||
err := h.DB.WithContext(ctx).Unscoped().
|
||||
Where("tenant_id = ? AND domain = ?", tenantID, domainName).
|
||||
First(&existing).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
if err := repository.NewTenantRepository(h.DB).AddDomain(ctx, tenantID, domainName, true); err != nil {
|
||||
return fmt.Errorf("failed to add domain: %s", domainName)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.DB.WithContext(ctx).Unscoped().Model(&existing).Updates(map[string]any{
|
||||
"verified": true,
|
||||
"deleted_at": nil,
|
||||
}).Error; err != nil {
|
||||
return fmt.Errorf("failed to add domain: %s", domainName)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *TenantHandler) upsertTenantCSVRecord(c *fiber.Ctx, record tenantCSVRecord) (*domain.Tenant, bool, error) {
|
||||
if h.DB == nil {
|
||||
if record.TenantID != "" {
|
||||
return false, errors.New("database not available for tenant update")
|
||||
return nil, false, errors.New("database not available for tenant update")
|
||||
}
|
||||
return false, nil
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
var tenant domain.Tenant
|
||||
@@ -510,10 +797,10 @@ func (h *TenantHandler) upsertTenantCSVRecord(c *fiber.Ctx, record tenantCSVReco
|
||||
}
|
||||
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return false, nil
|
||||
return nil, false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
tenant.Name = record.Name
|
||||
@@ -526,29 +813,29 @@ func (h *TenantHandler) upsertTenantCSVRecord(c *fiber.Ctx, record tenantCSVReco
|
||||
}
|
||||
|
||||
if err := h.DB.Save(&tenant).Error; err != nil {
|
||||
return false, err
|
||||
return nil, false, err
|
||||
}
|
||||
if err := h.DB.Delete(&domain.TenantDomain{}, "tenant_id = ?", tenant.ID).Error; err != nil {
|
||||
return false, err
|
||||
return nil, false, err
|
||||
}
|
||||
repo := repository.NewTenantRepository(h.DB)
|
||||
for _, domainName := range record.Domains {
|
||||
if err := repo.AddDomain(c.Context(), tenant.ID, domainName, true); err != nil {
|
||||
return false, err
|
||||
return nil, false, err
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
return &tenant, true, nil
|
||||
}
|
||||
|
||||
func (h *TenantHandler) createTenantCSVRecord(c *fiber.Ctx, record tenantCSVRecord, creatorID string) error {
|
||||
func (h *TenantHandler) createTenantCSVRecord(c *fiber.Ctx, record tenantCSVRecord, creatorID string) (*domain.Tenant, error) {
|
||||
if h.DB != nil && record.TenantID != "" {
|
||||
var exists int64
|
||||
if err := h.DB.Unscoped().Model(&domain.Tenant{}).Where("slug = ?", record.Slug).Count(&exists).Error; err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if exists > 0 {
|
||||
return errors.New("tenant slug already exists")
|
||||
return nil, errors.New("tenant slug already exists")
|
||||
}
|
||||
|
||||
tenant := domain.Tenant{
|
||||
@@ -561,7 +848,7 @@ func (h *TenantHandler) createTenantCSVRecord(c *fiber.Ctx, record tenantCSVReco
|
||||
Status: domain.TenantStatusActive,
|
||||
}
|
||||
if err := h.DB.Create(&tenant).Error; err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if h.KetoOutbox != nil {
|
||||
_ = h.KetoOutbox.Create(c.Context(), &domain.KetoOutbox{
|
||||
@@ -595,14 +882,14 @@ func (h *TenantHandler) createTenantCSVRecord(c *fiber.Ctx, record tenantCSVReco
|
||||
repo := repository.NewTenantRepository(h.DB)
|
||||
for _, domainName := range record.Domains {
|
||||
if err := repo.AddDomain(c.Context(), tenant.ID, domainName, true); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return &tenant, nil
|
||||
}
|
||||
|
||||
_, err := h.Service.RegisterTenant(c.Context(), record.Name, record.Slug, record.Type, record.Memo, record.Domains, record.ParentTenantID, creatorID)
|
||||
return err
|
||||
tenant, err := h.Service.RegisterTenant(c.Context(), record.Name, record.Slug, record.Type, record.Memo, record.Domains, record.ParentTenantID, creatorID)
|
||||
return tenant, err
|
||||
}
|
||||
|
||||
func (h *TenantHandler) GetTenant(c *fiber.Ctx) error {
|
||||
@@ -646,14 +933,15 @@ func (h *TenantHandler) CreateTenant(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Status string `json:"status"`
|
||||
Domains []string `json:"domains"`
|
||||
ParentID *string `json:"parentId"`
|
||||
Config map[string]any `json:"config"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Status string `json:"status"`
|
||||
Domains []string `json:"domains"`
|
||||
ForceDomains []string `json:"forceDomainConflicts"`
|
||||
ParentID *string `json:"parentId"`
|
||||
Config map[string]any `json:"config"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
|
||||
@@ -701,7 +989,16 @@ func (h *TenantHandler) CreateTenant(c *fiber.Ctx) error {
|
||||
creatorID = profile.ID
|
||||
}
|
||||
|
||||
tenant, err := h.Service.RegisterTenant(c.Context(), name, slug, tenantType, req.Description, req.Domains, parentID, creatorID)
|
||||
normalizedDomains := normalizeTenantDomainInputs(req.Domains)
|
||||
conflicts, err := h.findTenantDomainConflicts(c.Context(), "", normalizedDomains, req.ForceDomains)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
if len(conflicts) > 0 {
|
||||
return tenantDomainConflictJSON(c, conflicts)
|
||||
}
|
||||
|
||||
tenant, err := h.Service.RegisterTenant(c.Context(), name, slug, tenantType, req.Description, nil, parentID, creatorID)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "already exists") {
|
||||
return errorJSON(c, fiber.StatusConflict, err.Error())
|
||||
@@ -713,10 +1010,20 @@ func (h *TenantHandler) CreateTenant(c *fiber.Ctx) error {
|
||||
summary.MemberCount = 0
|
||||
|
||||
if req.Config != nil {
|
||||
tenant.Config = req.Config
|
||||
config, err := normalizeTenantConfig(req.Config)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, err.Error())
|
||||
}
|
||||
tenant.Config = config
|
||||
h.DB.Save(tenant)
|
||||
summary.Config = tenant.Config
|
||||
}
|
||||
if err := h.replaceTenantDomains(c.Context(), tenant.ID, normalizedDomains, req.ForceDomains); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
if len(normalizedDomains) > 0 {
|
||||
summary.Domains = normalizedDomains
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusCreated).JSON(summary)
|
||||
}
|
||||
@@ -740,14 +1047,15 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name *string `json:"name"`
|
||||
Type *string `json:"type"`
|
||||
Slug *string `json:"slug"`
|
||||
Description *string `json:"description"`
|
||||
Status *string `json:"status"`
|
||||
ParentID *string `json:"parentId"`
|
||||
Domains []string `json:"domains"`
|
||||
Config map[string]any `json:"config"`
|
||||
Name *string `json:"name"`
|
||||
Type *string `json:"type"`
|
||||
Slug *string `json:"slug"`
|
||||
Description *string `json:"description"`
|
||||
Status *string `json:"status"`
|
||||
ParentID *string `json:"parentId"`
|
||||
Domains []string `json:"domains"`
|
||||
ForceDomains []string `json:"forceDomainConflicts"`
|
||||
Config map[string]any `json:"config"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
|
||||
@@ -835,7 +1143,11 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
|
||||
}
|
||||
}
|
||||
if req.Config != nil {
|
||||
tenant.Config = req.Config
|
||||
config, err := normalizeTenantConfig(req.Config)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, err.Error())
|
||||
}
|
||||
tenant.Config = config
|
||||
}
|
||||
|
||||
if err := h.DB.Save(&tenant).Error; err != nil {
|
||||
@@ -844,18 +1156,16 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
|
||||
|
||||
// Update domains if provided
|
||||
if req.Domains != nil {
|
||||
// Simple approach: Delete existing and recreate
|
||||
if err := h.DB.Delete(&domain.TenantDomain{}, "tenant_id = ?", tenant.ID).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "failed to clear old domains")
|
||||
normalizedDomains := normalizeTenantDomainInputs(req.Domains)
|
||||
conflicts, err := h.findTenantDomainConflicts(c.Context(), tenant.ID, normalizedDomains, req.ForceDomains)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
for _, d := range req.Domains {
|
||||
if strings.TrimSpace(d) == "" {
|
||||
continue
|
||||
}
|
||||
// Use repository for consistency
|
||||
if err := repository.NewTenantRepository(h.DB).AddDomain(c.Context(), tenant.ID, strings.TrimSpace(d), true); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "failed to add domain: "+d)
|
||||
}
|
||||
if len(conflicts) > 0 {
|
||||
return tenantDomainConflictJSON(c, conflicts)
|
||||
}
|
||||
if err := h.replaceTenantDomains(c.Context(), tenant.ID, normalizedDomains, req.ForceDomains); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user