forked from baron/baron-sso
chore: consolidate local integration changes
This commit is contained in:
@@ -16,6 +16,7 @@ type UserProjectionRepository interface {
|
||||
IsReady(ctx context.Context) (bool, error)
|
||||
GetStatus(ctx context.Context) (domain.UserProjectionStatus, error)
|
||||
CountTenantMembers(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error)
|
||||
CountTenantMembersRecursive(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error)
|
||||
ReplaceAllFromKratos(ctx context.Context, users []domain.User) error
|
||||
MarkFailed(ctx context.Context, syncErr error) error
|
||||
}
|
||||
@@ -108,10 +109,63 @@ func (r *userProjectionRepository) CountTenantMembers(ctx context.Context, tenan
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (r *userProjectionRepository) CountTenantMembersRecursive(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
|
||||
counts := make(map[string]int64, len(tenants))
|
||||
for _, tenant := range tenants {
|
||||
counts[tenant.ID] = 0
|
||||
}
|
||||
if len(tenants) == 0 {
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
valuePlaceholders := make([]string, 0, len(tenants))
|
||||
args := make([]any, 0, len(tenants))
|
||||
for _, tenant := range tenants {
|
||||
valuePlaceholders = append(valuePlaceholders, "(?)")
|
||||
args = append(args, strings.TrimSpace(tenant.ID))
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
WITH RECURSIVE requested(tenant_id) AS (
|
||||
VALUES %s
|
||||
),
|
||||
descendants(root_tenant_id, tenant_id) AS (
|
||||
SELECT requested.tenant_id, requested.tenant_id
|
||||
FROM requested
|
||||
UNION ALL
|
||||
SELECT descendants.root_tenant_id, child.id::text
|
||||
FROM descendants
|
||||
JOIN tenants child
|
||||
ON child.parent_id::text = descendants.tenant_id
|
||||
AND child.deleted_at IS NULL
|
||||
)
|
||||
SELECT requested.tenant_id, COUNT(DISTINCT users.id) AS count
|
||||
FROM requested
|
||||
LEFT JOIN descendants
|
||||
ON descendants.root_tenant_id = requested.tenant_id
|
||||
LEFT JOIN users
|
||||
ON users.deleted_at IS NULL
|
||||
AND users.tenant_id::text = descendants.tenant_id
|
||||
GROUP BY requested.tenant_id
|
||||
`, strings.Join(valuePlaceholders, ","))
|
||||
|
||||
type result struct {
|
||||
TenantID string
|
||||
Count int64
|
||||
}
|
||||
var rows []result
|
||||
if err := r.db.WithContext(ctx).Raw(query, args...).Scan(&rows).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, row := range rows {
|
||||
counts[row.TenantID] = row.Count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (r *userProjectionRepository) ReplaceAllFromKratos(ctx context.Context, users []domain.User) error {
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
ids := make([]string, 0, len(users))
|
||||
for i := range users {
|
||||
users[i].DeletedAt = gorm.DeletedAt{}
|
||||
if users[i].CreatedAt.IsZero() {
|
||||
@@ -120,7 +174,6 @@ func (r *userProjectionRepository) ReplaceAllFromKratos(ctx context.Context, use
|
||||
if users[i].UpdatedAt.IsZero() {
|
||||
users[i].UpdatedAt = now
|
||||
}
|
||||
ids = append(ids, users[i].ID)
|
||||
}
|
||||
|
||||
if len(users) > 0 {
|
||||
@@ -138,11 +191,6 @@ func (r *userProjectionRepository) ReplaceAllFromKratos(ctx context.Context, use
|
||||
}).Create(&users).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Where("id NOT IN ?", ids).Delete(&domain.User{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := tx.Where("1 = 1").Delete(&domain.User{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return upsertUserProjectionState(tx, domain.UserProjectionStatusReady, &now, "")
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUserProjectionRepository_ReplaceAllFromKratosMarksReadyAndRemovesStaleUsers(t *testing.T) {
|
||||
func TestUserProjectionRepository_ReplaceAllFromKratosMarksReadyWithoutDeletingUsersMissingFromPartialList(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := NewUserProjectionRepository(testDB)
|
||||
|
||||
@@ -28,13 +28,14 @@ func TestUserProjectionRepository_ReplaceAllFromKratosMarksReadyAndRemovesStaleU
|
||||
Type: domain.TenantTypeCompany,
|
||||
Status: domain.TenantStatusActive,
|
||||
}).Error)
|
||||
stale := &domain.User{
|
||||
existing := &domain.User{
|
||||
ID: "00000000-0000-0000-0000-000000000099",
|
||||
Email: "stale@example.com",
|
||||
Name: "Stale",
|
||||
Email: "existing@example.com",
|
||||
Name: "Existing",
|
||||
CompanyCode: tenantSlug,
|
||||
TenantID: &tenantID,
|
||||
}
|
||||
require.NoError(t, NewUserRepository(testDB).Create(ctx, stale))
|
||||
require.NoError(t, NewUserRepository(testDB).Create(ctx, existing))
|
||||
|
||||
users := []domain.User{
|
||||
{
|
||||
@@ -66,11 +67,91 @@ func TestUserProjectionRepository_ReplaceAllFromKratosMarksReadyAndRemovesStaleU
|
||||
{ID: tenantID, Slug: tenantSlug},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(2), counts[tenantID])
|
||||
assert.Equal(t, int64(3), counts[tenantID])
|
||||
|
||||
var activeCount int64
|
||||
require.NoError(t, testDB.Model(&domain.User{}).Count(&activeCount).Error)
|
||||
assert.Equal(t, int64(2), activeCount)
|
||||
assert.Equal(t, int64(3), activeCount)
|
||||
|
||||
var existingCount int64
|
||||
require.NoError(t, testDB.Model(&domain.User{}).Where("id = ?", existing.ID).Count(&existingCount).Error)
|
||||
assert.Equal(t, int64(1), existingCount)
|
||||
|
||||
var existingRow domain.User
|
||||
require.NoError(t, testDB.Unscoped().First(&existingRow, "id = ?", existing.ID).Error)
|
||||
assert.False(t, existingRow.DeletedAt.Valid)
|
||||
}
|
||||
|
||||
func TestUserProjectionRepository_CountTenantMembersRecursiveIncludesDescendantsAndExcludesSoftDeletedUsers(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := NewUserProjectionRepository(testDB)
|
||||
|
||||
parentID := "20000000-0000-0000-0000-000000000001"
|
||||
childID := "20000000-0000-0000-0000-000000000002"
|
||||
grandchildID := "20000000-0000-0000-0000-000000000003"
|
||||
siblingID := "20000000-0000-0000-0000-000000000004"
|
||||
tenantIDs := []string{parentID, childID, grandchildID, siblingID}
|
||||
|
||||
require.NoError(t, testDB.Exec("DELETE FROM user_login_ids").Error)
|
||||
require.NoError(t, testDB.Exec("DELETE FROM users").Error)
|
||||
require.NoError(t, testDB.Unscoped().Where("id IN ?", tenantIDs).Delete(&domain.Tenant{}).Error)
|
||||
|
||||
require.NoError(t, testDB.Create(&domain.Tenant{
|
||||
ID: parentID,
|
||||
Name: "Recursive Parent",
|
||||
Slug: "recursive-parent",
|
||||
Type: domain.TenantTypeCompany,
|
||||
Status: domain.TenantStatusActive,
|
||||
}).Error)
|
||||
require.NoError(t, testDB.Create(&domain.Tenant{
|
||||
ID: childID,
|
||||
Name: "Recursive Child",
|
||||
Slug: "recursive-child",
|
||||
Type: domain.TenantTypeOrganization,
|
||||
Status: domain.TenantStatusActive,
|
||||
ParentID: &parentID,
|
||||
}).Error)
|
||||
require.NoError(t, testDB.Create(&domain.Tenant{
|
||||
ID: grandchildID,
|
||||
Name: "Recursive Grandchild",
|
||||
Slug: "recursive-grandchild",
|
||||
Type: domain.TenantTypeUserGroup,
|
||||
Status: domain.TenantStatusActive,
|
||||
ParentID: &childID,
|
||||
}).Error)
|
||||
require.NoError(t, testDB.Create(&domain.Tenant{
|
||||
ID: siblingID,
|
||||
Name: "Recursive Sibling",
|
||||
Slug: "recursive-sibling",
|
||||
Type: domain.TenantTypeCompany,
|
||||
Status: domain.TenantStatusActive,
|
||||
}).Error)
|
||||
|
||||
users := []domain.User{
|
||||
{ID: "21000000-0000-0000-0000-000000000001", Email: "parent@example.com", Name: "Parent", TenantID: &parentID},
|
||||
{ID: "21000000-0000-0000-0000-000000000002", Email: "child@example.com", Name: "Child", TenantID: &childID},
|
||||
{ID: "21000000-0000-0000-0000-000000000003", Email: "grandchild@example.com", Name: "Grandchild", TenantID: &grandchildID},
|
||||
{ID: "21000000-0000-0000-0000-000000000004", Email: "deleted-grandchild@example.com", Name: "Deleted Grandchild", TenantID: &grandchildID},
|
||||
{ID: "21000000-0000-0000-0000-000000000005", Email: "sibling@example.com", Name: "Sibling", TenantID: &siblingID},
|
||||
}
|
||||
for i := range users {
|
||||
require.NoError(t, testDB.Create(&users[i]).Error)
|
||||
}
|
||||
require.NoError(t, testDB.Delete(&domain.User{}, "id = ?", users[3].ID).Error)
|
||||
|
||||
directCounts, err := repo.CountTenantMembers(ctx, []domain.Tenant{{ID: parentID}, {ID: childID}, {ID: grandchildID}, {ID: siblingID}})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), directCounts[parentID])
|
||||
assert.Equal(t, int64(1), directCounts[childID])
|
||||
assert.Equal(t, int64(1), directCounts[grandchildID])
|
||||
assert.Equal(t, int64(1), directCounts[siblingID])
|
||||
|
||||
recursiveCounts, err := repo.CountTenantMembersRecursive(ctx, []domain.Tenant{{ID: parentID}, {ID: childID}, {ID: grandchildID}, {ID: siblingID}})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(3), recursiveCounts[parentID])
|
||||
assert.Equal(t, int64(2), recursiveCounts[childID])
|
||||
assert.Equal(t, int64(1), recursiveCounts[grandchildID])
|
||||
assert.Equal(t, int64(1), recursiveCounts[siblingID])
|
||||
}
|
||||
|
||||
func TestUserProjectionRepository_MarkFailedMakesProjectionNotReady(t *testing.T) {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/pagination"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -46,6 +47,31 @@ func (r *userRepository) DB() *gorm.DB {
|
||||
return r.db
|
||||
}
|
||||
|
||||
func (r *userRepository) withTenantMembershipFilter(db *gorm.DB, tenantIDs []string) *gorm.DB {
|
||||
if len(tenantIDs) == 0 {
|
||||
return db
|
||||
}
|
||||
clauses := []string{"tenant_id IN ?"}
|
||||
args := []any{tenantIDs}
|
||||
for _, tenantID := range tenantIDs {
|
||||
tenantID = strings.TrimSpace(tenantID)
|
||||
if tenantID == "" {
|
||||
continue
|
||||
}
|
||||
payload, err := json.Marshal(map[string]any{
|
||||
"additionalAppointments": []map[string]string{
|
||||
{"tenantId": tenantID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
clauses = append(clauses, "metadata @> ?::jsonb")
|
||||
args = append(args, string(payload))
|
||||
}
|
||||
return db.Where("("+strings.Join(clauses, " OR ")+")", args...)
|
||||
}
|
||||
|
||||
func (r *userRepository) Create(ctx context.Context, user *domain.User) error {
|
||||
return r.db.WithContext(ctx).Create(user).Error
|
||||
}
|
||||
@@ -124,7 +150,7 @@ func (r *userRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.
|
||||
|
||||
func (r *userRepository) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
|
||||
var users []domain.User
|
||||
if err := r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Find(&users).Error; err != nil {
|
||||
if err := r.withTenantMembershipFilter(r.db.WithContext(ctx), []string{tenantID}).Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return users, nil
|
||||
@@ -132,40 +158,23 @@ func (r *userRepository) ListByTenant(ctx context.Context, tenantID string) ([]d
|
||||
|
||||
func (r *userRepository) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&domain.User{}).Where("tenant_id = ?", tenantID).Count(&count).Error
|
||||
err := r.withTenantMembershipFilter(r.db.WithContext(ctx).Model(&domain.User{}), []string{tenantID}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *userRepository) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
|
||||
type result struct {
|
||||
TenantID *string
|
||||
Count int64
|
||||
}
|
||||
var results []result
|
||||
counts := make(map[string]int64)
|
||||
|
||||
if len(tenantIDs) == 0 {
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Model(&domain.User{}).
|
||||
Select("tenant_id, count(*) as count").
|
||||
Where("tenant_id IN ?", tenantIDs).
|
||||
Group("tenant_id").
|
||||
Find(&results).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, res := range results {
|
||||
if res.TenantID != nil && *res.TenantID != "" {
|
||||
counts[*res.TenantID] = res.Count
|
||||
}
|
||||
}
|
||||
// Ensure all requested tenant IDs are in the map, even if count is 0
|
||||
for _, id := range tenantIDs {
|
||||
if _, ok := counts[id]; !ok {
|
||||
counts[id] = 0
|
||||
for _, tenantID := range tenantIDs {
|
||||
var count int64
|
||||
if err := r.withTenantMembershipFilter(r.db.WithContext(ctx).Model(&domain.User{}), []string{tenantID}).Count(&count).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts[tenantID] = count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
@@ -222,7 +231,7 @@ func (r *userRepository) List(ctx context.Context, offset, limit int, search str
|
||||
db := r.db.WithContext(ctx).Model(&domain.User{})
|
||||
|
||||
if len(tenantIDs) > 0 {
|
||||
db = db.Where("tenant_id IN ?", tenantIDs)
|
||||
db = r.withTenantMembershipFilter(db, tenantIDs)
|
||||
}
|
||||
|
||||
if search != "" {
|
||||
@@ -311,7 +320,7 @@ func (r *userRepository) FindTenantIDByLoginID(ctx context.Context, loginID stri
|
||||
|
||||
func (r *userRepository) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
|
||||
var users []domain.User
|
||||
err := r.db.WithContext(ctx).Where("tenant_id IN ?", tenantIDs).Find(&users).Error
|
||||
err := r.withTenantMembershipFilter(r.db.WithContext(ctx), tenantIDs).Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
|
||||
@@ -204,6 +204,60 @@ func TestUserRepository(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserRepository_ListIncludesAdditionalTenantAppointments(t *testing.T) {
|
||||
repo := NewUserRepository(testDB)
|
||||
ctx := context.Background()
|
||||
require.NoError(t, testDB.Exec("DELETE FROM user_login_ids").Error)
|
||||
require.NoError(t, testDB.Exec("DELETE FROM users").Error)
|
||||
|
||||
primaryTenant := createUserRepositoryTestTenant(t, "repo-primary-tenant")
|
||||
additionalTenant := createUserRepositoryTestTenant(t, "repo-additional-tenant")
|
||||
primaryTenantID := primaryTenant.ID
|
||||
additionalTenantID := additionalTenant.ID
|
||||
users := []domain.User{
|
||||
{
|
||||
ID: uuid.NewString(),
|
||||
Email: "primary-member@example.com",
|
||||
Name: "Primary Member",
|
||||
Role: domain.RoleUser,
|
||||
TenantID: &additionalTenantID,
|
||||
},
|
||||
{
|
||||
ID: uuid.NewString(),
|
||||
Email: "additional-member@example.com",
|
||||
Name: "Additional Member",
|
||||
Role: domain.RoleUser,
|
||||
TenantID: &primaryTenantID,
|
||||
Metadata: domain.JSONMap{
|
||||
"additionalAppointments": []any{
|
||||
map[string]any{
|
||||
"tenantId": additionalTenant.ID,
|
||||
"tenantSlug": additionalTenant.Slug,
|
||||
"tenantName": additionalTenant.Name,
|
||||
"isPrimary": false,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for i := range users {
|
||||
require.NoError(t, repo.Create(ctx, &users[i]))
|
||||
}
|
||||
|
||||
listed, total, _, err := repo.List(ctx, 0, 20, "", []string{additionalTenant.ID}, "")
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), total)
|
||||
require.Len(t, listed, 2)
|
||||
emails := []string{listed[0].Email, listed[1].Email}
|
||||
assert.Contains(t, emails, "primary-member@example.com")
|
||||
assert.Contains(t, emails, "additional-member@example.com")
|
||||
|
||||
counts, err := repo.CountByTenantIDs(ctx, []string{additionalTenant.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(2), counts[additionalTenant.ID])
|
||||
}
|
||||
|
||||
func createUserRepositoryTestTenant(t *testing.T, slug string) domain.Tenant {
|
||||
t.Helper()
|
||||
require.NoError(t, testDB.Unscoped().Where("slug = ?", slug).Delete(&domain.Tenant{}).Error)
|
||||
|
||||
Reference in New Issue
Block a user