package handler import ( "baron-sso-backend/internal/domain" "baron-sso-backend/internal/testsupport" "bytes" "context" "encoding/json" "log" "net/http" "net/http/httptest" "os" "path/filepath" "testing" "time" "github.com/gofiber/fiber/v2" "github.com/testcontainers/testcontainers-go" postgres_module "github.com/testcontainers/testcontainers-go/modules/postgres" "github.com/testcontainers/testcontainers-go/wait" gorm_postgres "gorm.io/driver/postgres" "gorm.io/gorm" ) func newTenantHandlerSeedDeleteDB(t *testing.T) *gorm.DB { t.Helper() if !testsupport.DockerAvailable() { t.Skip("Docker provider is unavailable in this environment") } ctx := context.Background() postgresContainer, err := postgres_module.Run(ctx, "postgres:16-alpine", postgres_module.WithDatabase("testdb"), postgres_module.WithUsername("user"), postgres_module.WithPassword("password"), testcontainers.WithWaitStrategy( wait.ForLog("database system is ready to accept connections"). WithOccurrence(2). WithStartupTimeout(30*time.Second)), ) if err != nil { t.Fatalf("failed to start postgres container: %v", err) } t.Cleanup(func() { if err := postgresContainer.Terminate(ctx); err != nil { log.Printf("failed to terminate postgres container: %v", err) } }) connStr, err := postgresContainer.ConnectionString(ctx, "sslmode=disable") if err != nil { t.Fatalf("failed to get postgres connection string: %v", err) } db, err := gorm.Open(gorm_postgres.Open(connStr), &gorm.Config{}) if err != nil { t.Fatalf("failed to open postgres connection: %v", err) } if err := db.AutoMigrate(&domain.Tenant{}); err != nil { t.Fatalf("failed to migrate tenants: %v", err) } return db } func setSeedTenantCSVForDeleteGuard(t *testing.T, slug string) { t.Helper() dir := t.TempDir() path := filepath.Join(dir, "seed-tenant.csv") csv := "name,type,parent_tenant_slug,slug,memo,email_domain\n" + "Protected,COMPANY_GROUP,," + slug + ",Protected seed,\n" if err := os.WriteFile(path, []byte(csv), 0o600); err != nil { t.Fatalf("failed to write seed csv: %v", err) } t.Setenv("SEED_TENANT_CSV_PATH", path) } func TestTenantHandlerDeleteTenantRejectsSeedTenant(t *testing.T) { setSeedTenantCSVForDeleteGuard(t, "protected-root") db := newTenantHandlerSeedDeleteDB(t) tenant := domain.Tenant{ ID: "00000000-0000-0000-0000-000000000001", Name: "Protected", Slug: "protected-root", Type: domain.TenantTypeCompanyGroup, Status: domain.TenantStatusActive, } if err := db.Create(&tenant).Error; err != nil { t.Fatalf("failed to create tenant: %v", err) } app := fiber.New() app.Delete("/tenants/:id", (&TenantHandler{DB: db}).DeleteTenant) req := httptest.NewRequest(http.MethodDelete, "/tenants/"+tenant.ID, nil) resp, err := app.Test(req) if err != nil { t.Fatalf("request failed: %v", err) } if resp.StatusCode != http.StatusConflict { t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusConflict) } var count int64 if err := db.Model(&domain.Tenant{}).Where("id = ?", tenant.ID).Count(&count).Error; err != nil { t.Fatalf("count tenant: %v", err) } if count != 1 { t.Fatalf("seed tenant count = %d, want 1", count) } } func TestTenantHandlerDeleteTenantsBulkRejectsSeedTenant(t *testing.T) { setSeedTenantCSVForDeleteGuard(t, "protected-root") db := newTenantHandlerSeedDeleteDB(t) seed := domain.Tenant{ ID: "00000000-0000-0000-0000-000000000011", Name: "Protected", Slug: "protected-root", Type: domain.TenantTypeCompanyGroup, Status: domain.TenantStatusActive, } normal := domain.Tenant{ ID: "00000000-0000-0000-0000-000000000012", Name: "Normal", Slug: "normal", Type: domain.TenantTypeCompany, Status: domain.TenantStatusActive, } if err := db.Create(&seed).Error; err != nil { t.Fatalf("failed to create seed tenant: %v", err) } if err := db.Create(&normal).Error; err != nil { t.Fatalf("failed to create normal tenant: %v", err) } app := fiber.New() app.Use(func(c *fiber.Ctx) error { c.Locals("user_profile", &domain.UserProfileResponse{Role: domain.RoleSuperAdmin}) return c.Next() }) app.Delete("/tenants/bulk", (&TenantHandler{DB: db}).DeleteTenantsBulk) body, _ := json.Marshal(map[string][]string{"ids": {seed.ID, normal.ID}}) req := httptest.NewRequest(http.MethodDelete, "/tenants/bulk", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") resp, err := app.Test(req) if err != nil { t.Fatalf("request failed: %v", err) } if resp.StatusCode != http.StatusConflict { t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusConflict) } var count int64 if err := db.Model(&domain.Tenant{}).Where("id IN ?", []string{seed.ID, normal.ID}).Count(&count).Error; err != nil { t.Fatalf("count tenants: %v", err) } if count != 2 { t.Fatalf("remaining tenant count = %d, want 2", count) } }