forked from baron/baron-sso
368 lines
12 KiB
Go
368 lines
12 KiB
Go
package handler
|
|
|
|
import (
|
|
"baron-sso-backend/internal/domain"
|
|
"baron-sso-backend/internal/service"
|
|
"baron-sso-backend/internal/testsupport"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"log"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/testcontainers/testcontainers-go"
|
|
postgres_module "github.com/testcontainers/testcontainers-go/modules/postgres"
|
|
"github.com/testcontainers/testcontainers-go/wait"
|
|
gorm_postgres "gorm.io/driver/postgres"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
func newTenantHandlerSeedDeleteDB(t *testing.T) *gorm.DB {
|
|
t.Helper()
|
|
if !testsupport.DockerAvailable() {
|
|
t.Skip("Docker provider is unavailable in this environment")
|
|
}
|
|
|
|
ctx := context.Background()
|
|
postgresContainer, err := postgres_module.Run(ctx,
|
|
"postgres:16-alpine",
|
|
postgres_module.WithDatabase("testdb"),
|
|
postgres_module.WithUsername("user"),
|
|
postgres_module.WithPassword("password"),
|
|
testcontainers.WithWaitStrategy(
|
|
wait.ForLog("database system is ready to accept connections").
|
|
WithOccurrence(2).
|
|
WithStartupTimeout(30*time.Second)),
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("failed to start postgres container: %v", err)
|
|
}
|
|
t.Cleanup(func() {
|
|
if err := postgresContainer.Terminate(ctx); err != nil {
|
|
log.Printf("failed to terminate postgres container: %v", err)
|
|
}
|
|
})
|
|
|
|
connStr, err := postgresContainer.ConnectionString(ctx, "sslmode=disable")
|
|
if err != nil {
|
|
t.Fatalf("failed to get postgres connection string: %v", err)
|
|
}
|
|
db, err := gorm.Open(gorm_postgres.Open(connStr), &gorm.Config{})
|
|
if err != nil {
|
|
t.Fatalf("failed to open postgres connection: %v", err)
|
|
}
|
|
if err := db.AutoMigrate(&domain.Tenant{}, &domain.User{}, &domain.UserLoginID{}); err != nil {
|
|
t.Fatalf("failed to migrate tenant delete models: %v", err)
|
|
}
|
|
return db
|
|
}
|
|
|
|
// setSeedTenantCSVForDeleteGuard writes a one-row seed tenant CSV (header plus a
// "Protected" COMPANY_GROUP row carrying the given slug) into a per-test
// temporary directory and points SEED_TENANT_CSV_PATH at it for the duration of
// the test. The test fails immediately if the file cannot be written.
func setSeedTenantCSVForDeleteGuard(t *testing.T, slug string) {
	t.Helper()

	const header = "name,type,parent_tenant_slug,slug,memo,email_domain\n"
	row := "Protected,COMPANY_GROUP,," + slug + ",Protected seed,\n"
	csvPath := filepath.Join(t.TempDir(), "seed-tenant.csv")

	writeErr := os.WriteFile(csvPath, []byte(header+row), 0o600)
	if writeErr != nil {
		t.Fatalf("failed to write seed csv: %v", writeErr)
	}

	// t.Setenv restores the previous value automatically when the test ends.
	t.Setenv("SEED_TENANT_CSV_PATH", csvPath)
}
|
|
|
|
func TestTenantHandlerDeleteTenantRejectsSeedTenant(t *testing.T) {
|
|
setSeedTenantCSVForDeleteGuard(t, "protected-root")
|
|
db := newTenantHandlerSeedDeleteDB(t)
|
|
tenant := domain.Tenant{
|
|
ID: "00000000-0000-0000-0000-000000000001",
|
|
Name: "Protected",
|
|
Slug: "protected-root",
|
|
Type: domain.TenantTypeCompanyGroup,
|
|
Status: domain.TenantStatusActive,
|
|
}
|
|
if err := db.Create(&tenant).Error; err != nil {
|
|
t.Fatalf("failed to create tenant: %v", err)
|
|
}
|
|
|
|
app := fiber.New()
|
|
app.Delete("/tenants/:id", (&TenantHandler{DB: db}).DeleteTenant)
|
|
req := httptest.NewRequest(http.MethodDelete, "/tenants/"+tenant.ID, nil)
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("request failed: %v", err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusConflict {
|
|
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusConflict)
|
|
}
|
|
var count int64
|
|
if err := db.Model(&domain.Tenant{}).Where("id = ?", tenant.ID).Count(&count).Error; err != nil {
|
|
t.Fatalf("count tenant: %v", err)
|
|
}
|
|
if count != 1 {
|
|
t.Fatalf("seed tenant count = %d, want 1", count)
|
|
}
|
|
}
|
|
|
|
func TestTenantHandlerDeleteTenantReassignsUserMembershipsToParentTenant(t *testing.T) {
|
|
db := newTenantHandlerSeedDeleteDB(t)
|
|
parent := domain.Tenant{
|
|
ID: "10000000-0000-0000-0000-000000000001",
|
|
Name: "Parent",
|
|
Slug: "delete-policy-parent",
|
|
Type: domain.TenantTypeCompany,
|
|
Status: domain.TenantStatusActive,
|
|
}
|
|
child := domain.Tenant{
|
|
ID: "10000000-0000-0000-0000-000000000002",
|
|
Name: "Collaboration",
|
|
Slug: "delete-policy-collaboration",
|
|
Type: domain.TenantTypeUserGroup,
|
|
ParentID: &parent.ID,
|
|
Status: domain.TenantStatusActive,
|
|
}
|
|
user := domain.User{
|
|
ID: "10000000-0000-0000-0000-000000000101",
|
|
Email: "delete-policy-user@example.com",
|
|
Name: "Delete Policy User",
|
|
Role: domain.RoleUser,
|
|
TenantID: &child.ID,
|
|
}
|
|
loginID := domain.UserLoginID{
|
|
ID: "10000000-0000-0000-0000-000000000201",
|
|
UserID: user.ID,
|
|
TenantID: child.ID,
|
|
FieldKey: "employee_number",
|
|
LoginID: "delete-policy-user",
|
|
}
|
|
if err := db.Create(&parent).Error; err != nil {
|
|
t.Fatalf("failed to create parent tenant: %v", err)
|
|
}
|
|
if err := db.Create(&child).Error; err != nil {
|
|
t.Fatalf("failed to create child tenant: %v", err)
|
|
}
|
|
if err := db.Create(&user).Error; err != nil {
|
|
t.Fatalf("failed to create user: %v", err)
|
|
}
|
|
if err := db.Create(&loginID).Error; err != nil {
|
|
t.Fatalf("failed to create login id: %v", err)
|
|
}
|
|
|
|
staleIdentity := service.KratosIdentity{
|
|
ID: user.ID,
|
|
State: "active",
|
|
Traits: map[string]any{
|
|
"email": user.Email,
|
|
"name": user.Name,
|
|
"tenant_id": child.ID,
|
|
"primaryTenantId": child.ID,
|
|
"primaryTenantSlug": child.Slug,
|
|
"primaryTenantName": child.Name,
|
|
"additionalAppointments": []any{
|
|
map[string]any{
|
|
"tenantId": child.ID,
|
|
"tenantSlug": child.Slug,
|
|
"tenantName": child.Name,
|
|
"isPrimary": true,
|
|
"grade": "G5",
|
|
},
|
|
},
|
|
},
|
|
}
|
|
updatedIdentity := staleIdentity
|
|
mockKratos := new(MockKratosAdmin)
|
|
mockKratos.On("GetIdentity", mock.Anything, user.ID).Return(&staleIdentity, nil).Once()
|
|
mockKratos.On("UpdateIdentity", mock.Anything, user.ID, mock.MatchedBy(func(traits map[string]any) bool {
|
|
if traits["tenant_id"] != parent.ID || traits["primaryTenantId"] != parent.ID {
|
|
return false
|
|
}
|
|
if traits["primaryTenantSlug"] != parent.Slug || traits["primaryTenantName"] != parent.Name {
|
|
return false
|
|
}
|
|
appointments, ok := traits["additionalAppointments"].([]any)
|
|
if !ok || len(appointments) != 1 {
|
|
return false
|
|
}
|
|
appointment, ok := appointments[0].(map[string]any)
|
|
return ok &&
|
|
appointment["tenantId"] == parent.ID &&
|
|
appointment["tenantSlug"] == parent.Slug &&
|
|
appointment["tenantName"] == parent.Name &&
|
|
appointment["grade"] == "G5" &&
|
|
appointment["isPrimary"] == true
|
|
}), "active").Run(func(args mock.Arguments) {
|
|
updatedIdentity.Traits = args.Get(2).(map[string]any)
|
|
}).Return(&updatedIdentity, nil).Once()
|
|
redis := &mockRedisRepo{data: map[string]string{}}
|
|
staleRaw, _ := json.Marshal(staleIdentity)
|
|
if err := redis.Set(identityMirrorKey(user.ID), string(staleRaw), 0); err != nil {
|
|
t.Fatalf("failed to seed identity mirror: %v", err)
|
|
}
|
|
|
|
app := fiber.New()
|
|
app.Delete("/tenants/:id", (&TenantHandler{DB: db, KratosAdmin: mockKratos, IdentityCache: redis}).DeleteTenant)
|
|
req := httptest.NewRequest(http.MethodDelete, "/tenants/"+child.ID, nil)
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("request failed: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusNoContent {
|
|
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusNoContent)
|
|
}
|
|
|
|
var foundUser domain.User
|
|
if err := db.First(&foundUser, "id = ?", user.ID).Error; err != nil {
|
|
t.Fatalf("failed to reload user: %v", err)
|
|
}
|
|
if foundUser.TenantID == nil || *foundUser.TenantID != parent.ID {
|
|
t.Fatalf("user tenant_id = %v, want %s", foundUser.TenantID, parent.ID)
|
|
}
|
|
var foundLogin domain.UserLoginID
|
|
if err := db.First(&foundLogin, "id = ?", loginID.ID).Error; err != nil {
|
|
t.Fatalf("failed to reload login id: %v", err)
|
|
}
|
|
if foundLogin.TenantID != parent.ID {
|
|
t.Fatalf("login tenant_id = %s, want %s", foundLogin.TenantID, parent.ID)
|
|
}
|
|
mockKratos.AssertExpectations(t)
|
|
|
|
var mirrored service.KratosIdentity
|
|
if err := json.Unmarshal([]byte(redis.data[identityMirrorKey(user.ID)]), &mirrored); err != nil {
|
|
t.Fatalf("failed to decode mirrored identity: %v", err)
|
|
}
|
|
if mirrored.Traits["tenant_id"] != parent.ID || mirrored.Traits["primaryTenantSlug"] != parent.Slug {
|
|
t.Fatalf("mirrored traits = %#v, want promoted tenant %s/%s", mirrored.Traits, parent.ID, parent.Slug)
|
|
}
|
|
}
|
|
|
|
func TestTenantHandlerDeleteTenantsBulkRejectsSeedTenant(t *testing.T) {
|
|
setSeedTenantCSVForDeleteGuard(t, "protected-root")
|
|
db := newTenantHandlerSeedDeleteDB(t)
|
|
seed := domain.Tenant{
|
|
ID: "00000000-0000-0000-0000-000000000011",
|
|
Name: "Protected",
|
|
Slug: "protected-root",
|
|
Type: domain.TenantTypeCompanyGroup,
|
|
Status: domain.TenantStatusActive,
|
|
}
|
|
normal := domain.Tenant{
|
|
ID: "00000000-0000-0000-0000-000000000012",
|
|
Name: "Normal",
|
|
Slug: "normal",
|
|
Type: domain.TenantTypeCompany,
|
|
Status: domain.TenantStatusActive,
|
|
}
|
|
if err := db.Create(&seed).Error; err != nil {
|
|
t.Fatalf("failed to create seed tenant: %v", err)
|
|
}
|
|
if err := db.Create(&normal).Error; err != nil {
|
|
t.Fatalf("failed to create normal tenant: %v", err)
|
|
}
|
|
|
|
app := fiber.New()
|
|
app.Use(func(c *fiber.Ctx) error {
|
|
c.Locals("user_profile", &domain.UserProfileResponse{Role: domain.RoleSuperAdmin})
|
|
return c.Next()
|
|
})
|
|
app.Delete("/tenants/bulk", (&TenantHandler{DB: db}).DeleteTenantsBulk)
|
|
body, _ := json.Marshal(map[string][]string{"ids": {seed.ID, normal.ID}})
|
|
req := httptest.NewRequest(http.MethodDelete, "/tenants/bulk", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("request failed: %v", err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusConflict {
|
|
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusConflict)
|
|
}
|
|
var count int64
|
|
if err := db.Model(&domain.Tenant{}).Where("id IN ?", []string{seed.ID, normal.ID}).Count(&count).Error; err != nil {
|
|
t.Fatalf("count tenants: %v", err)
|
|
}
|
|
if count != 2 {
|
|
t.Fatalf("remaining tenant count = %d, want 2", count)
|
|
}
|
|
}
|
|
|
|
func TestTenantHandlerDeleteTenantsBulkReassignsUsersToNearestRemainingAncestor(t *testing.T) {
|
|
db := newTenantHandlerSeedDeleteDB(t)
|
|
root := domain.Tenant{
|
|
ID: "10000000-0000-0000-0000-000000000011",
|
|
Name: "Root",
|
|
Slug: "delete-policy-root",
|
|
Type: domain.TenantTypeCompanyGroup,
|
|
Status: domain.TenantStatusActive,
|
|
}
|
|
parent := domain.Tenant{
|
|
ID: "10000000-0000-0000-0000-000000000012",
|
|
Name: "Parent",
|
|
Slug: "delete-policy-bulk-parent",
|
|
Type: domain.TenantTypeCompany,
|
|
ParentID: &root.ID,
|
|
Status: domain.TenantStatusActive,
|
|
}
|
|
child := domain.Tenant{
|
|
ID: "10000000-0000-0000-0000-000000000013",
|
|
Name: "Collaboration",
|
|
Slug: "delete-policy-bulk-collaboration",
|
|
Type: domain.TenantTypeUserGroup,
|
|
ParentID: &parent.ID,
|
|
Status: domain.TenantStatusActive,
|
|
}
|
|
user := domain.User{
|
|
ID: "10000000-0000-0000-0000-000000000111",
|
|
Email: "bulk-delete-policy-user@example.com",
|
|
Name: "Bulk Delete Policy User",
|
|
Role: domain.RoleUser,
|
|
TenantID: &child.ID,
|
|
}
|
|
if err := db.Create(&root).Error; err != nil {
|
|
t.Fatalf("failed to create root tenant: %v", err)
|
|
}
|
|
if err := db.Create(&parent).Error; err != nil {
|
|
t.Fatalf("failed to create parent tenant: %v", err)
|
|
}
|
|
if err := db.Create(&child).Error; err != nil {
|
|
t.Fatalf("failed to create child tenant: %v", err)
|
|
}
|
|
if err := db.Create(&user).Error; err != nil {
|
|
t.Fatalf("failed to create user: %v", err)
|
|
}
|
|
|
|
mockSvc := new(MockTenantService)
|
|
mockSvc.On("DeleteTenantsBulk", mock.Anything, []string{parent.ID, child.ID}).Return(nil).Once()
|
|
|
|
app := fiber.New()
|
|
app.Use(func(c *fiber.Ctx) error {
|
|
c.Locals("user_profile", &domain.UserProfileResponse{Role: domain.RoleSuperAdmin})
|
|
return c.Next()
|
|
})
|
|
app.Delete("/tenants/bulk", (&TenantHandler{DB: db, Service: mockSvc}).DeleteTenantsBulk)
|
|
body, _ := json.Marshal(map[string][]string{"ids": {parent.ID, child.ID}})
|
|
req := httptest.NewRequest(http.MethodDelete, "/tenants/bulk", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("request failed: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK)
|
|
}
|
|
|
|
var foundUser domain.User
|
|
if err := db.First(&foundUser, "id = ?", user.ID).Error; err != nil {
|
|
t.Fatalf("failed to reload user: %v", err)
|
|
}
|
|
if foundUser.TenantID == nil || *foundUser.TenantID != root.ID {
|
|
t.Fatalf("user tenant_id = %v, want %s", foundUser.TenantID, root.ID)
|
|
}
|
|
mockSvc.AssertExpectations(t)
|
|
}
|