1
0
forked from baron/baron-sso

feat: implement dynamic tenant provisioning and remove hardcoded company mappings

This commit is contained in:
2026-04-06 16:13:03 +09:00
parent 003f12f008
commit c78604df06
9 changed files with 125 additions and 67 deletions

View File

@@ -521,6 +521,14 @@ func (h *AuthHandler) Signup(c *fiber.Ctx) error {
slog.Warn("[Signup] Attempted to join non-active tenant by domain", "email", req.Email, "tenant", tenant.Slug, "status", tenant.Status)
return errorJSON(c, fiber.StatusForbidden, "Your organization's tenant is currently not active.")
}
} else {
// [New Policy] Try dynamic provisioning via Group Policies if tenant doesn't exist
tenant, err := h.TenantService.ProvisionTenantByDomain(c.Context(), domainName)
if err == nil && tenant != nil {
slog.Info("[Signup] Auto-provisioned tenant via group policy", "email", req.Email, "tenant", tenant.Slug)
companyCode = tenant.Slug
tenantID = &tenant.ID
}
}
}
@@ -529,8 +537,6 @@ func (h *AuthHandler) Signup(c *fiber.Ctx) error {
tenant, err := h.TenantService.GetTenantBySlug(c.Context(), req.CompanyCode)
if err == nil && tenant != nil {
if tenant.Status == domain.TenantStatusActive {
// Policy: Should we allow manual joining by Slug?
// For now, let's allow it but log it as manual.
slog.Info("[Signup] Assigning tenant by manual slug", "email", req.Email, "tenant", tenant.Slug)
companyCode = tenant.Slug
tenantID = &tenant.ID
@@ -538,54 +544,18 @@ func (h *AuthHandler) Signup(c *fiber.Ctx) error {
return errorJSON(c, fiber.StatusForbidden, "The specified organization is not active.")
}
} else {
// If companyCode provided but not found, we automatically create one
// [New Policy] 자동 생성 로직 추가
slog.Info("[Signup] CompanyCode not found, creating new tenant automatically", "slug", req.CompanyCode)
// Determine name from CompanyCode
tenantName := req.CompanyCode
// Map slug to localized name if possible
slugToName := map[string]string{
"HANMAC": "한맥",
"SAMAN": "삼안",
"JANGHEON": "장헌",
"HALLA": "한라",
"PTC": "PTC",
"BARON": "바론",
}
if name, ok := slugToName[strings.ToUpper(req.CompanyCode)]; ok {
tenantName = name
}
// Create the tenant
// Note: creatorID is unknown at this point, will be set via Read-Model sync later
newTenant, err := h.TenantService.RegisterTenant(c.Context(),
tenantName,
req.CompanyCode,
domain.TenantTypeCompany,
"Automatically created during signup",
nil, // domains
nil, // parentID
"", // creatorID (will sync later)
)
if err != nil {
// Handle race condition: if tenant was created by another request just now
if strings.Contains(err.Error(), "already exists") {
newTenant, err = h.TenantService.GetTenantBySlug(c.Context(), req.CompanyCode)
}
if err != nil || newTenant == nil {
slog.Error("[Signup] Failed to create tenant automatically", "slug", req.CompanyCode, "error", err)
return errorJSON(c, fiber.StatusInternalServerError, "Failed to initialize organization.")
}
}
slog.Info("[Signup] Successfully created missing tenant", "slug", req.CompanyCode, "id", newTenant.ID)
tenantID = &newTenant.ID
companyCode = newTenant.Slug
// [New Policy] Do NOT create tenants automatically with hardcoded names.
// Only allow joining existing tenants.
slog.Warn("[Signup] Attempted to join non-existent organization", "slug", req.CompanyCode, "email", req.Email)
return errorJSON(c, fiber.StatusNotFound, "The specified organization code was not found. Please contact your administrator.")
}
}
if tenantID == nil {
slog.Warn("[Signup] No tenant assigned to user", "email", req.Email)
return errorJSON(c, fiber.StatusBadRequest, "We couldn't identify your organization. Please provide a company code or use your corporate email.")
}
// Normalize Phone (E.164 형태로 보관)
normalizedPhone := strings.ReplaceAll(req.Phone, "-", "")
normalizedPhone = strings.ReplaceAll(normalizedPhone, " ", "")

View File

@@ -198,7 +198,10 @@ func (m *AsyncMockTenantService) IsDomainAllowed(ctx context.Context, domainName
return false, nil
}
func (m *AsyncMockTenantService) ApproveTenant(ctx context.Context, id string) error { return nil }
func (m *AsyncMockTenantService) SetKetoService(keto service.KetoService) {}
func (m *AsyncMockTenantService) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) SetKetoService(keto service.KetoService) {}
func (m *AsyncMockTenantService) AddTenantAdmin(ctx context.Context, tenantID, userID string) error {
return nil
}
@@ -269,7 +272,9 @@ func TestSignup_AsyncDB_Isolation(t *testing.T) {
mockRedis.On("Delete", phoneKey).Return(nil)
// Tenant Mocks
mockTenant.On("GetTenantByDomain", mock.Anything, "example.com").Return(nil, errors.New("not found"))
validTenant := &domain.Tenant{ID: "t1", Slug: "example", Status: domain.TenantStatusActive}
mockTenant.On("GetTenantByDomain", mock.Anything, "example.com").Return(validTenant, nil)
mockTenant.On("GetTenant", mock.Anything, "t1").Return(validTenant, nil)
// Kratos Mocks (Success)
mockIdp.On("CreateUser", mock.Anything, "Password123!").Return("new-user-uuid", nil)

View File

@@ -5,6 +5,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
@@ -98,7 +99,7 @@ func TestSignup_CompanyCodeValidation(t *testing.T) {
})
mockRedis.On("Get", mock.Anything).Return(string(verifiedState), nil)
t.Run("Create Tenant if CompanyCode Missing", func(t *testing.T) {
t.Run("Fail - Tenant not found for CompanyCode", func(t *testing.T) {
reqBody := domain.SignupRequest{
Email: "user@gmail.com",
Password: "StrongPass123!",
@@ -109,20 +110,15 @@ func TestSignup_CompanyCodeValidation(t *testing.T) {
}
body, _ := json.Marshal(reqBody)
newTenant := &domain.Tenant{ID: "t_new", Slug: "new-slug", Status: domain.TenantStatusActive}
mockTenantSvc.On("GetTenantByDomain", mock.Anything, "gmail.com").Return(nil, nil)
mockTenantSvc.On("GetTenantBySlug", mock.Anything, "new-slug").Return(nil, nil)
mockTenantSvc.On("RegisterTenant", mock.Anything, "new-slug", "new-slug", domain.TenantTypeCompany, mock.Anything, mock.Anything, mock.Anything, "").Return(newTenant, nil)
mockTenantSvc.On("GetTenant", mock.Anything, "t_new").Return(newTenant, nil)
mockIdp.On("CreateUser", mock.Anything, mock.Anything).Return("user-id", nil)
mockRedis.On("Delete", mock.Anything).Return(nil)
mockTenantSvc.On("GetTenantByDomain", mock.Anything, "gmail.com").Return(nil, nil).Once()
mockTenantSvc.On("ProvisionTenantByDomain", mock.Anything, "gmail.com").Return(nil, errors.New("not found")).Once()
mockTenantSvc.On("GetTenantBySlug", mock.Anything, "new-slug").Return(nil, nil).Once()
req := httptest.NewRequest("POST", "/signup", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
})
t.Run("Active Company Code", func(t *testing.T) {
@@ -137,10 +133,11 @@ func TestSignup_CompanyCodeValidation(t *testing.T) {
body, _ := json.Marshal(reqBody)
validTenant := &domain.Tenant{ID: "t1", Slug: "valid-slug", Status: domain.TenantStatusActive}
mockTenantSvc.On("GetTenantByDomain", mock.Anything, "gmail.com").Return(nil, nil)
mockTenantSvc.On("GetTenantBySlug", mock.Anything, "valid-slug").Return(validTenant, nil)
mockTenantSvc.On("GetTenant", mock.Anything, "t1").Return(validTenant, nil)
mockIdp.On("CreateUser", mock.Anything, mock.Anything).Return("user-id", nil)
mockTenantSvc.On("GetTenantByDomain", mock.Anything, "gmail.com").Return(nil, nil).Once()
mockTenantSvc.On("ProvisionTenantByDomain", mock.Anything, "gmail.com").Return(nil, errors.New("not found")).Once()
mockTenantSvc.On("GetTenantBySlug", mock.Anything, "valid-slug").Return(validTenant, nil).Once()
mockTenantSvc.On("GetTenant", mock.Anything, "t1").Return(validTenant, nil).Once()
mockIdp.On("CreateUser", mock.Anything, mock.Anything).Return("user-id", nil).Once()
mockRedis.On("Delete", mock.Anything).Return(nil)
req := httptest.NewRequest("POST", "/signup", bytes.NewReader(body))

View File

@@ -88,6 +88,14 @@ func (m *MockTenantService) SetKetoService(keto service.KetoService) {
m.Called(keto)
}
func (m *MockTenantService) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
args := m.Called(ctx, domainName)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.Tenant), args.Error(1)
}
type MockUserRepoForHandler struct {
mock.Mock
}

View File

@@ -103,6 +103,14 @@ func (m *MockTenantServiceForUser) ListManageableTenants(ctx context.Context, us
return args.Get(0).([]domain.Tenant), args.Error(1)
}
func (m *MockTenantServiceForUser) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
args := m.Called(ctx, domainName)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.Tenant), args.Error(1)
}
// --- Tests ---
func TestUserHandler_BulkCreateUsers(t *testing.T) {

View File

@@ -18,6 +18,7 @@ type TenantRepository interface {
FindByIDs(ctx context.Context, ids []string) ([]domain.Tenant, error)
AddDomain(ctx context.Context, tenantID string, domainName string, verified bool) error
List(ctx context.Context, limit, offset int, parentID string) ([]domain.Tenant, int64, error)
ListByType(ctx context.Context, tenantType string) ([]domain.Tenant, error)
}
type tenantRepository struct {
@@ -112,3 +113,11 @@ func (r *tenantRepository) List(ctx context.Context, limit, offset int, parentID
return tenants, total, nil
}
func (r *tenantRepository) ListByType(ctx context.Context, tenantType string) ([]domain.Tenant, error) {
var tenants []domain.Tenant
if err := r.db.WithContext(ctx).Where("type = ?", tenantType).Preload("Domains").Find(&tenants).Error; err != nil {
return nil, err
}
return tenants, nil
}

View File

@@ -22,6 +22,7 @@ type TenantService interface {
ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error)
IsDomainAllowed(ctx context.Context, domainName string) (bool, error)
ApproveTenant(ctx context.Context, id string) error
ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) // 추가
SetKetoService(keto KetoService) // 추가
}
@@ -311,3 +312,51 @@ func (s *tenantService) IsDomainAllowed(ctx context.Context, domainName string)
}
return tenant != nil && tenant.Status == domain.TenantStatusActive, nil
}
func (s *tenantService) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
// 1. Find all COMPANY_GROUP tenants
groups, err := s.repo.ListByType(ctx, domain.TenantTypeCompanyGroup)
if err != nil {
return nil, err
}
for _, g := range groups {
// 2. Check autoProvisioning config
rawConfig, ok := g.Config["autoProvisioning"].(map[string]interface{})
if !ok {
continue
}
enabled, _ := rawConfig["enabled"].(bool)
if !enabled {
continue
}
mapping, ok := rawConfig["mappingRules"].(map[string]interface{})
if !ok {
continue
}
// 3. Find rule for this domain
rule, ok := mapping[domainName].(map[string]interface{})
if !ok {
continue
}
slug, _ := rule["slug"].(string)
name, _ := rule["name"].(string)
if slug == "" || name == "" {
continue
}
// 4. Create new sub-tenant under this group
slog.Info("[Provisioning] Found rule for domain, creating sub-tenant", "domain", domainName, "parent", g.Slug, "newTenant", slug)
// Use RegisterTenant to handle DB creation and Keto Outbox sync
// creatorID is empty as per security policy (manual delegation later)
return s.RegisterTenant(ctx, name, slug, domain.TenantTypeCompany, "Automatically provisioned via group policy", []string{domainName}, &g.ID, "")
}
return nil, gorm.ErrRecordNotFound
}

View File

@@ -64,6 +64,14 @@ func (m *MockTenantRepoForSvc) List(ctx context.Context, limit, offset int, pare
return args.Get(0).([]domain.Tenant), int64(args.Int(1)), args.Error(2)
}
func (m *MockTenantRepoForSvc) ListByType(ctx context.Context, tenantType string) ([]domain.Tenant, error) {
args := m.Called(ctx, tenantType)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).([]domain.Tenant), args.Error(1)
}
type MockKetoSvcForTenant struct {
mock.Mock
}

View File

@@ -158,14 +158,18 @@ func (m *MockTenantRepository) FindByDomain(ctx context.Context, domainName stri
return nil, nil
}
func (m *MockTenantRepository) AddDomain(ctx context.Context, tenantID string, domainName string, verified bool) error {
return nil
}
func (m *MockTenantRepository) List(ctx context.Context, limit, offset int, parentID string) ([]domain.Tenant, int64, error) {
return nil, 0, nil
}
func (m *MockTenantRepository) ListByType(ctx context.Context, tenantType string) ([]domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantRepository) AddDomain(ctx context.Context, tenantID string, domainName string, verified bool) error {
return nil
}
func TestUserGroupService_Create(t *testing.T) {
mockRepo := new(MockUserGroupRepository)
mockTenantRepo := new(MockTenantRepository)