forked from baron/baron-sso
Merge branch 'dev' into feature/tenant-user-list-ui-improvement
This commit is contained in:
@@ -359,6 +359,8 @@ func mustHeadlessClientAssertionWithAlgorithm(t *testing.T, privateKey any, alg
|
||||
}
|
||||
|
||||
func runHeadlessPasswordLoginWithAssertion(t *testing.T, jwks map[string]any, clientAssertion string) *http.Response {
|
||||
t.Helper()
|
||||
t.Setenv("BACKEND_PUBLIC_URL", "")
|
||||
return runHeadlessPasswordLoginWithAssertionRequest(t, jwks, clientAssertion, "http://example.com/api/v1/auth/headless/password/login", nil)
|
||||
}
|
||||
|
||||
@@ -454,6 +456,8 @@ func runHeadlessPasswordLoginWithAssertionAndLogger(
|
||||
clientAssertion string,
|
||||
logger *slog.Logger,
|
||||
) *http.Response {
|
||||
t.Helper()
|
||||
t.Setenv("BACKEND_PUBLIC_URL", "")
|
||||
return runHeadlessPasswordLoginWithAssertionAndLoggerRequest(
|
||||
t,
|
||||
jwks,
|
||||
@@ -799,6 +803,8 @@ func TestPasswordLogin_UserFront_AuditIncludesDefaultClientMetadata(t *testing.T
|
||||
}
|
||||
|
||||
func TestHeadlessPasswordLogin_HeadlessLoginClientSuccess(t *testing.T) {
|
||||
t.Setenv("BACKEND_PUBLIC_URL", "")
|
||||
|
||||
if !testsupport.PortBindingAvailable() {
|
||||
t.Skip("skipping headless login tests because this environment cannot bind local TCP listeners")
|
||||
}
|
||||
@@ -1019,6 +1025,8 @@ func TestHeadlessPasswordLogin_AuditIncludesClientMetadata(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHeadlessPasswordLogin_IgnoresInlineHeadlessJWKSWhenJWKSURIIsConfigured(t *testing.T) {
|
||||
t.Setenv("BACKEND_PUBLIC_URL", "")
|
||||
|
||||
if !testsupport.PortBindingAvailable() {
|
||||
t.Skip("skipping headless login tests because this environment cannot bind local TCP listeners")
|
||||
}
|
||||
@@ -1106,6 +1114,8 @@ func TestHeadlessPasswordLogin_IgnoresInlineHeadlessJWKSWhenJWKSURIIsConfigured(
|
||||
}
|
||||
|
||||
func TestHeadlessPasswordLogin_RefreshesJWKSWhenSignatureFailsForCachedKid(t *testing.T) {
|
||||
t.Setenv("BACKEND_PUBLIC_URL", "")
|
||||
|
||||
if !testsupport.PortBindingAvailable() {
|
||||
t.Skip("skipping headless login tests because this environment cannot bind local TCP listeners")
|
||||
}
|
||||
@@ -1418,6 +1428,8 @@ func TestHeadlessPasswordLogin_AudienceMismatchReturnsDetailedCode(t *testing.T)
|
||||
}
|
||||
|
||||
func TestHeadlessPasswordLogin_AcceptsForwardedHTTPSAudience(t *testing.T) {
|
||||
t.Setenv("BACKEND_PUBLIC_URL", "")
|
||||
|
||||
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||
clientAssertion := mustHeadlessClientAssertion(
|
||||
t,
|
||||
|
||||
@@ -3,11 +3,15 @@ package repository
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestCheckDataIntegrityDetectsTenantAndUserProblems(t *testing.T) {
|
||||
@@ -60,7 +64,18 @@ func TestCheckDataIntegrityDetectsTenantAndUserProblems(t *testing.T) {
|
||||
CreatedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
deletedLoginUser := domain.User{
|
||||
ID: uuid.NewString(),
|
||||
Email: "deleted-login-user-" + suffix + "@example.com",
|
||||
Name: "Deleted Login User",
|
||||
Role: domain.RoleUser,
|
||||
TenantID: &child.ID,
|
||||
Status: domain.UserStatusActive,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
require.NoError(t, testDB.Create(&orphanUser).Error)
|
||||
require.NoError(t, testDB.Create(&deletedLoginUser).Error)
|
||||
require.NoError(t, testDB.Create(&domain.UserLoginID{
|
||||
ID: uuid.NewString(),
|
||||
UserID: orphanUser.ID,
|
||||
@@ -68,8 +83,14 @@ func TestCheckDataIntegrityDetectsTenantAndUserProblems(t *testing.T) {
|
||||
FieldKey: "emp_id",
|
||||
LoginID: "EMP-" + suffix,
|
||||
}).Error)
|
||||
// Missing UserID for UserLoginID cannot be inserted due to FK constraint fk_users_user_login_ids.
|
||||
// So we don't test orphan_user_login_id_users here.
|
||||
require.NoError(t, testDB.Create(&domain.UserLoginID{
|
||||
ID: uuid.NewString(),
|
||||
UserID: deletedLoginUser.ID,
|
||||
TenantID: child.ID,
|
||||
FieldKey: "emp_id",
|
||||
LoginID: "MISSING-" + suffix,
|
||||
}).Error)
|
||||
require.NoError(t, testDB.Delete(&domain.User{}, "id = ?", deletedLoginUser.ID).Error)
|
||||
|
||||
report, err := CheckDataIntegrity(ctx, testDB)
|
||||
require.NoError(t, err)
|
||||
@@ -80,7 +101,69 @@ func TestCheckDataIntegrityDetectsTenantAndUserProblems(t *testing.T) {
|
||||
requireIntegrityCheck(t, report, "tenant_integrity", "orphan_tenant_parents", domain.DataIntegrityStatusFail, 1)
|
||||
requireIntegrityCheck(t, report, "user_integrity", "orphan_user_tenant_memberships", domain.DataIntegrityStatusFail, 1)
|
||||
requireIntegrityCheck(t, report, "user_integrity", "orphan_user_login_id_tenants", domain.DataIntegrityStatusFail, 1)
|
||||
requireIntegrityCheck(t, report, "user_integrity", "orphan_user_login_id_users", domain.DataIntegrityStatusPass, 0)
|
||||
requireIntegrityCheck(t, report, "user_integrity", "orphan_user_login_id_users", domain.DataIntegrityStatusFail, 1)
|
||||
}
|
||||
|
||||
func TestCheckDataIntegrityDetectsHardOrphanUserLoginIDRows(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
suffix := uuid.NewString()
|
||||
rollback := errors.New("rollback hard orphan fixture")
|
||||
|
||||
err := testDB.Transaction(func(tx *gorm.DB) error {
|
||||
var constraintNames []string
|
||||
if err := tx.Raw(`
|
||||
SELECT conname
|
||||
FROM pg_constraint
|
||||
WHERE conrelid = 'user_login_ids'::regclass
|
||||
AND contype = 'f'
|
||||
`).Scan(&constraintNames).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, constraintName := range constraintNames {
|
||||
statement := fmt.Sprintf("ALTER TABLE user_login_ids DROP CONSTRAINT %s", pq.QuoteIdentifier(constraintName))
|
||||
if err := tx.Exec(statement).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
before, err := CheckDataIntegrity(ctx, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
beforeTenantCount, err := integrityCheckCount(before, "user_integrity", "orphan_user_login_id_tenants")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
beforeUserCount, err := integrityCheckCount(before, "user_integrity", "orphan_user_login_id_users")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Create(&domain.UserLoginID{
|
||||
ID: uuid.NewString(),
|
||||
UserID: uuid.NewString(),
|
||||
TenantID: uuid.NewString(),
|
||||
FieldKey: "emp_id",
|
||||
LoginID: "HARD-ORPHAN-" + suffix,
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
report, err := CheckDataIntegrity(ctx, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := expectIntegrityCheck(report, "user_integrity", "orphan_user_login_id_tenants", domain.DataIntegrityStatusFail, beforeTenantCount+1); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := expectIntegrityCheck(report, "user_integrity", "orphan_user_login_id_users", domain.DataIntegrityStatusFail, beforeUserCount+1); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return rollback
|
||||
})
|
||||
require.ErrorIs(t, err, rollback)
|
||||
}
|
||||
|
||||
func TestListAndDeleteOrphanUserLoginIDsOnlyDeletesRevalidatedTargets(t *testing.T) {
|
||||
@@ -189,17 +272,41 @@ func TestListAndDeleteOrphanUserLoginIDsOnlyDeletesRevalidatedTargets(t *testing
|
||||
|
||||
func requireIntegrityCheck(t *testing.T, report domain.DataIntegrityReport, sectionKey, checkKey string, status domain.DataIntegrityStatus, count int64) {
|
||||
t.Helper()
|
||||
require.NoError(t, expectIntegrityCheck(report, sectionKey, checkKey, status, count))
|
||||
}
|
||||
|
||||
func expectIntegrityCheck(report domain.DataIntegrityReport, sectionKey, checkKey string, status domain.DataIntegrityStatus, count int64) error {
|
||||
check, ok := findIntegrityCheck(report, sectionKey, checkKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("integrity check %s/%s not found", sectionKey, checkKey)
|
||||
}
|
||||
if check.Status != status {
|
||||
return fmt.Errorf("integrity check %s/%s status = %s, want %s", sectionKey, checkKey, check.Status, status)
|
||||
}
|
||||
if check.Count != count {
|
||||
return fmt.Errorf("integrity check %s/%s count = %d, want %d", sectionKey, checkKey, check.Count, count)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func integrityCheckCount(report domain.DataIntegrityReport, sectionKey, checkKey string) (int64, error) {
|
||||
check, ok := findIntegrityCheck(report, sectionKey, checkKey)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("integrity check %s/%s not found", sectionKey, checkKey)
|
||||
}
|
||||
return check.Count, nil
|
||||
}
|
||||
|
||||
func findIntegrityCheck(report domain.DataIntegrityReport, sectionKey, checkKey string) (domain.DataIntegrityCheck, bool) {
|
||||
for _, section := range report.Sections {
|
||||
if section.Key != sectionKey {
|
||||
continue
|
||||
}
|
||||
for _, check := range section.Checks {
|
||||
if check.Key == checkKey {
|
||||
require.Equal(t, status, check.Status)
|
||||
require.Equal(t, count, check.Count)
|
||||
return
|
||||
return check, true
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Fatalf("integrity check %s/%s not found", sectionKey, checkKey)
|
||||
return domain.DataIntegrityCheck{}, false
|
||||
}
|
||||
|
||||
@@ -26,20 +26,16 @@ func TestClearOrphanUserTenantMemberships(t *testing.T) {
|
||||
require.NoError(t, testDB.Delete(&domain.Tenant{}, "id = ?", deletedTenant.ID).Error)
|
||||
|
||||
activeUser := &domain.User{
|
||||
Email: "active-membership@example.com",
|
||||
Name: "Active Membership",
|
||||
Role: "user",
|
||||
TenantID: &activeTenant.ID,
|
||||
CompanyCode: activeTenant.Slug,
|
||||
CompanyCodes: []string{activeTenant.Slug},
|
||||
Email: "active-membership@example.com",
|
||||
Name: "Active Membership",
|
||||
Role: "user",
|
||||
TenantID: &activeTenant.ID,
|
||||
}
|
||||
orphanUser := &domain.User{
|
||||
Email: "orphan-membership@example.com",
|
||||
Name: "Orphan Membership",
|
||||
Role: "user",
|
||||
TenantID: &deletedTenant.ID,
|
||||
CompanyCode: deletedTenant.Slug,
|
||||
CompanyCodes: []string{deletedTenant.Slug},
|
||||
Email: "orphan-membership@example.com",
|
||||
Name: "Orphan Membership",
|
||||
Role: "user",
|
||||
TenantID: &deletedTenant.ID,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, activeUser))
|
||||
require.NoError(t, repo.Create(ctx, orphanUser))
|
||||
@@ -57,12 +53,10 @@ func TestClearOrphanUserTenantMemberships(t *testing.T) {
|
||||
require.NotNil(t, foundActive.TenantID)
|
||||
require.NotNil(t, foundActive.Tenant)
|
||||
assert.Equal(t, activeTenant.ID, *foundActive.TenantID)
|
||||
assert.Equal(t, activeTenant.Slug, foundActive.Tenant.Slug)
|
||||
|
||||
foundOrphan, err := repo.FindByEmail(ctx, orphanUser.Email)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, foundOrphan.TenantID)
|
||||
assert.Nil(t, foundOrphan.Tenant)
|
||||
|
||||
count, err = CountOrphanUserTenantMemberships(ctx, testDB)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -47,13 +47,12 @@ func TestUserProjectionRepository_ReplaceAllFromKratosMarksReadyAndRemovesStaleU
|
||||
UpdatedAt: time.Now(),
|
||||
},
|
||||
{
|
||||
ID: "00000000-0000-0000-0000-000000000102",
|
||||
Email: "two@example.com",
|
||||
Name: "Two",
|
||||
TenantID: &tenantID,
|
||||
CompanyCodes: []string{tenantSlug},
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
ID: "00000000-0000-0000-0000-000000000102",
|
||||
Email: "two@example.com",
|
||||
Name: "Two",
|
||||
TenantID: &tenantID,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUserRepository(t *testing.T) {
|
||||
@@ -76,19 +78,17 @@ func TestUserRepository(t *testing.T) {
|
||||
|
||||
t.Run("CountByCompanyCodes", func(t *testing.T) {
|
||||
// Clean start for this subtest
|
||||
testDB.Exec("DELETE FROM user_login_ids")
|
||||
testDB.Exec("DELETE FROM users")
|
||||
testDB.Exec("DELETE FROM tenants")
|
||||
|
||||
tenantA := &domain.Tenant{Name: "Tenant A", Slug: "tenant-a", Type: domain.TenantTypeCompany}
|
||||
tenantB := &domain.Tenant{Name: "Tenant B", Slug: "tenant-b", Type: domain.TenantTypeCompany}
|
||||
_ = testDB.Create(tenantA)
|
||||
_ = testDB.Create(tenantB)
|
||||
testDB.Exec("DELETE FROM tenant_domains")
|
||||
tenantA := createUserRepositoryTestTenant(t, "tenant-a")
|
||||
tenantB := createUserRepositoryTestTenant(t, "tenant-b")
|
||||
|
||||
users := []domain.User{
|
||||
{Email: "u1@a.com", Name: "U1", TenantID: &tenantA.ID},
|
||||
{Email: "u2@a.com", Name: "U2", TenantID: &tenantA.ID},
|
||||
{Email: "u3@b.com", Name: "U3", TenantID: &tenantB.ID},
|
||||
{Email: "u4@none.com", Name: "U4", TenantID: nil},
|
||||
{Email: "u4@none.com", Name: "U4"},
|
||||
}
|
||||
for _, u := range users {
|
||||
_ = repo.Create(ctx, &u)
|
||||
@@ -102,21 +102,20 @@ func TestUserRepository(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("CountByCompanyCodes excludes soft deleted cache rows", func(t *testing.T) {
|
||||
testDB.Exec("DELETE FROM user_login_ids")
|
||||
testDB.Exec("DELETE FROM users")
|
||||
testDB.Exec("DELETE FROM tenants")
|
||||
|
||||
tenantA := &domain.Tenant{Name: "Tenant A", Slug: "tenant-a", Type: domain.TenantTypeCompany}
|
||||
_ = testDB.Create(tenantA)
|
||||
testDB.Exec("DELETE FROM tenant_domains")
|
||||
tenantA := createUserRepositoryTestTenant(t, "tenant-a")
|
||||
|
||||
active := &domain.User{Email: "active@a.com", Name: "Active", TenantID: &tenantA.ID}
|
||||
deleted := &domain.User{Email: "deleted@a.com", Name: "Deleted", TenantID: &tenantA.ID}
|
||||
arrayDeleted := &domain.User{Email: "array-deleted@a.com", Name: "Array Deleted", TenantID: &tenantA.ID}
|
||||
secondDeleted := &domain.User{Email: "second-deleted@a.com", Name: "Second Deleted", TenantID: &tenantA.ID}
|
||||
|
||||
assert.NoError(t, repo.Create(ctx, active))
|
||||
assert.NoError(t, repo.Create(ctx, deleted))
|
||||
assert.NoError(t, repo.Create(ctx, arrayDeleted))
|
||||
assert.NoError(t, repo.Create(ctx, secondDeleted))
|
||||
assert.NoError(t, repo.Delete(ctx, deleted.ID))
|
||||
assert.NoError(t, repo.Delete(ctx, arrayDeleted.ID))
|
||||
assert.NoError(t, repo.Delete(ctx, secondDeleted.ID))
|
||||
|
||||
counts, err := repo.CountByCompanyCodes(ctx, []string{"tenant-a"})
|
||||
|
||||
@@ -174,3 +173,17 @@ func TestUserRepository(t *testing.T) {
|
||||
assert.Equal(t, "E002", saved[0].LoginID)
|
||||
})
|
||||
}
|
||||
|
||||
func createUserRepositoryTestTenant(t *testing.T, slug string) domain.Tenant {
|
||||
t.Helper()
|
||||
require.NoError(t, testDB.Unscoped().Where("slug = ?", slug).Delete(&domain.Tenant{}).Error)
|
||||
tenant := domain.Tenant{
|
||||
ID: uuid.NewString(),
|
||||
Name: "Tenant " + slug,
|
||||
Slug: slug,
|
||||
Type: domain.TenantTypeCompany,
|
||||
Status: domain.TenantStatusActive,
|
||||
}
|
||||
require.NoError(t, testDB.Create(&tenant).Error)
|
||||
return tenant
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user