1
0
forked from baron/baron-sso
Files
baron-sso/backend/internal/repository/user_projection_repository.go

228 lines
6.5 KiB
Go

package repository
import (
"baron-sso-backend/internal/domain"
"context"
"errors"
"fmt"
"strings"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// UserProjectionRepository maintains the local users table as a projection of
// Kratos identities, together with a state row recording whether that
// projection can currently be relied upon.
type UserProjectionRepository interface {
	// IsReady reports whether the projection is ready to be queried.
	IsReady(ctx context.Context) (bool, error)

	// GetStatus returns the projection state and the number of projected users.
	GetStatus(ctx context.Context) (domain.UserProjectionStatus, error)

	// CountTenantMembers returns the number of projected users assigned
	// directly to each tenant, keyed by tenant ID.
	CountTenantMembers(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error)

	// CountTenantMembersRecursive is like CountTenantMembers but also counts
	// users assigned to the tenants' descendant tenants.
	CountTenantMembersRecursive(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error)

	// ReplaceAllFromKratos writes users into the projection and marks it ready.
	ReplaceAllFromKratos(ctx context.Context, users []domain.User) error

	// MarkFailed records that a sync failed, storing syncErr's message
	// (empty when syncErr is nil).
	MarkFailed(ctx context.Context, syncErr error) error
}
// userProjectionRepository is the GORM-backed implementation of
// UserProjectionRepository.
type userProjectionRepository struct {
	// db is the database handle all projection queries run against.
	db *gorm.DB
}
// NewUserProjectionRepository returns a UserProjectionRepository whose
// queries run against db.
func NewUserProjectionRepository(db *gorm.DB) UserProjectionRepository {
	repo := &userProjectionRepository{db: db}
	return repo
}
// IsReady reports the Ready flag computed by GetStatus, returning any
// GetStatus error unchanged (with false).
func (r *userProjectionRepository) IsReady(ctx context.Context) (bool, error) {
	st, statusErr := r.GetStatus(ctx)
	if statusErr != nil {
		return false, statusErr
	}
	return st.Ready, nil
}
// GetStatus returns the current status of the Kratos user projection.
//
// ProjectedUsers is the number of users rows visible through GORM's default
// scope, so soft-deleted users are not counted. Ready is true only when the
// stored status is domain.UserProjectionStatusReady and a sync time has been
// recorded. If no state row exists, the projection is reported as failed and
// not ready, and no error is returned.
func (r *userProjectionRepository) GetStatus(ctx context.Context) (domain.UserProjectionStatus, error) {
	// WithContext starts a new session, which GORM allows to be reused
	// across independent queries.
	db := r.db.WithContext(ctx)

	var projectedUsers int64
	if err := db.Model(&domain.User{}).Count(&projectedUsers).Error; err != nil {
		return domain.UserProjectionStatus{}, err
	}

	var state domain.UserProjectionState
	switch err := db.First(&state, "name = ?", domain.UserProjectionNameKratos).Error; {
	case errors.Is(err, gorm.ErrRecordNotFound):
		// In this file the state row is only written by ReplaceAllFromKratos
		// and MarkFailed, so its absence means no sync has been recorded.
		return domain.UserProjectionStatus{
			Name:           domain.UserProjectionNameKratos,
			Status:         domain.UserProjectionStatusFailed,
			Ready:          false,
			ProjectedUsers: projectedUsers,
		}, nil
	case err != nil:
		return domain.UserProjectionStatus{}, err
	}

	ready := state.Status == domain.UserProjectionStatusReady && state.LastSyncedAt != nil
	return domain.UserProjectionStatus{
		Name:           state.Name,
		Status:         state.Status,
		Ready:          ready,
		LastSyncedAt:   state.LastSyncedAt,
		LastError:      state.LastError,
		UpdatedAt:      &state.UpdatedAt,
		ProjectedUsers: projectedUsers,
	}, nil
}
// CountTenantMembers returns the number of distinct non-deleted users whose
// tenant_id equals each tenant's ID. Every tenant.ID passed in is a key of the
// returned map, with 0 for tenants that have no members.
func (r *userProjectionRepository) CountTenantMembers(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
	const queryTemplate = `
WITH requested(tenant_id) AS (
	VALUES %s
)
SELECT requested.tenant_id, COUNT(DISTINCT users.id) AS count
FROM requested
LEFT JOIN users
	ON users.deleted_at IS NULL
	AND users.tenant_id::text = requested.tenant_id
GROUP BY requested.tenant_id
`
	return r.countTenantMembersWithQuery(ctx, tenants, queryTemplate)
}

// CountTenantMembersRecursive returns, for each tenant, the number of distinct
// non-deleted users assigned to that tenant or to any descendant reachable
// through tenants.parent_id via non-deleted tenants. Every tenant.ID passed in
// is a key of the returned map, with 0 for tenants that have no members.
func (r *userProjectionRepository) CountTenantMembersRecursive(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
	// The recursive term is combined with UNION rather than UNION ALL, so
	// (root, tenant) pairs that were already produced are discarded. A cycle
	// in parent_id therefore ends the recursion instead of looping until the
	// statement is cancelled.
	const queryTemplate = `
WITH RECURSIVE requested(tenant_id) AS (
	VALUES %s
),
descendants(root_tenant_id, tenant_id) AS (
	SELECT requested.tenant_id, requested.tenant_id
	FROM requested
	UNION
	SELECT descendants.root_tenant_id, child.id::text
	FROM descendants
	JOIN tenants child
		ON child.parent_id::text = descendants.tenant_id
		AND child.deleted_at IS NULL
)
SELECT requested.tenant_id, COUNT(DISTINCT users.id) AS count
FROM requested
LEFT JOIN descendants
	ON descendants.root_tenant_id = requested.tenant_id
LEFT JOIN users
	ON users.deleted_at IS NULL
	AND users.tenant_id::text = descendants.tenant_id
GROUP BY requested.tenant_id
`
	return r.countTenantMembersWithQuery(ctx, tenants, queryTemplate)
}

// countTenantMembersWithQuery runs a tenant member-count query and returns
// the counts keyed by the caller's original tenant IDs.
//
// queryTemplate must contain exactly one %s verb. It is replaced with a
// comma-separated list of single-column VALUES rows "(?)", one per distinct
// trimmed tenant ID. The query must return (tenant_id, count) rows whose
// tenant_id is one of the bound IDs.
//
// IDs are trimmed before binding. Each result is then copied to every original
// ID that trimmed to the same value, so callers can look counts up with the
// exact tenant.ID they passed in.
func (r *userProjectionRepository) countTenantMembersWithQuery(ctx context.Context, tenants []domain.Tenant, queryTemplate string) (map[string]int64, error) {
	counts := make(map[string]int64, len(tenants))
	if len(tenants) == 0 {
		return counts, nil
	}

	// originalIDs maps each trimmed (bound) ID back to the caller's IDs.
	originalIDs := make(map[string][]string, len(tenants))
	placeholders := make([]string, 0, len(tenants))
	args := make([]any, 0, len(tenants))
	for _, tenant := range tenants {
		counts[tenant.ID] = 0
		id := strings.TrimSpace(tenant.ID)
		if _, seen := originalIDs[id]; !seen {
			placeholders = append(placeholders, "(?)")
			args = append(args, id)
		}
		originalIDs[id] = append(originalIDs[id], tenant.ID)
	}

	type tenantCountRow struct {
		TenantID string
		Count    int64
	}
	var rows []tenantCountRow
	query := fmt.Sprintf(queryTemplate, strings.Join(placeholders, ","))
	if err := r.db.WithContext(ctx).Raw(query, args...).Scan(&rows).Error; err != nil {
		return nil, fmt.Errorf("counting tenant members: %w", err)
	}

	for _, row := range rows {
		for _, id := range originalIDs[row.TenantID] {
			counts[id] = row.Count
		}
	}
	return counts, nil
}
// ReplaceAllFromKratos upserts users (conflicting on id, updating all columns)
// and, in the same transaction, marks the Kratos projection state ready with
// the current time as last_synced_at and an empty last_error.
//
// Before the upsert, each user's DeletedAt is cleared so the upsert writes a
// NULL deleted_at, and zero CreatedAt/UpdatedAt values are set to the current
// time. These changes are made in place on the caller's slice. Rows that share
// a non-empty email with an incoming user but have a different id are
// hard-deleted first, so the upsert cannot violate the unique email constraint.
//
// NOTE(review): despite the name, local users that are absent from users are
// not removed. Confirm that this is intended.
func (r *userProjectionRepository) ReplaceAllFromKratos(ctx context.Context, users []domain.User) error {
	// PostgreSQL limits a statement to 65535 bind parameters. A single
	// multi-row INSERT of every user exceeds that once rows*columns passes the
	// limit, so the upsert is issued in chunks of this many rows.
	const upsertBatchSize = 500

	now := time.Now()
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		for i := range users {
			u := &users[i]
			u.DeletedAt = gorm.DeletedAt{}
			if u.CreatedAt.IsZero() {
				u.CreatedAt = now
			}
			if u.UpdatedAt.IsZero() {
				u.UpdatedAt = now
			}
		}

		for i := range users {
			u := &users[i]
			if u.Email == "" {
				continue
			}
			// A failed statement aborts the whole PostgreSQL transaction, so
			// this error must surface here. Ignoring it would only make the
			// upsert fail later with an unrelated "transaction is aborted" error.
			if err := tx.Unscoped().Where("email = ? AND id != ?", u.Email, u.ID).Delete(&domain.User{}).Error; err != nil {
				return fmt.Errorf("removing users whose email conflicts with user %v: %w", u.ID, err)
			}
		}

		upsert := clause.OnConflict{
			Columns:   []clause.Column{{Name: "id"}},
			UpdateAll: true,
		}
		for start := 0; start < len(users); start += upsertBatchSize {
			end := start + upsertBatchSize
			if end > len(users) {
				end = len(users)
			}
			// batch shares the backing array of users, so any values GORM
			// writes back during Create land in the caller's slice, as before.
			batch := users[start:end]
			if err := tx.Clauses(upsert).Create(&batch).Error; err != nil {
				return err
			}
		}

		return upsertUserProjectionState(tx, domain.UserProjectionStatusReady, &now, "")
	})
}
// MarkFailed records that the latest Kratos sync failed. It upserts the
// projection state row inside a transaction with status failed, last_error set
// to syncErr's message (empty when syncErr is nil), and a nil last_synced_at.
func (r *userProjectionRepository) MarkFailed(ctx context.Context, syncErr error) error {
	var message string
	if syncErr != nil {
		message = syncErr.Error()
	}

	markFailed := func(tx *gorm.DB) error {
		return upsertUserProjectionState(tx, domain.UserProjectionStatusFailed, nil, message)
	}
	return r.db.WithContext(ctx).Transaction(markFailed)
}
// upsertUserProjectionState writes the state row named
// domain.UserProjectionNameKratos using tx. The row is inserted if absent;
// otherwise status, last_synced_at, last_error and updated_at are overwritten.
// Because last_synced_at is always overwritten, a nil syncedAt clears any
// previously recorded sync time.
func upsertUserProjectionState(tx *gorm.DB, status string, syncedAt *time.Time, lastError string) error {
	onNameConflict := clause.OnConflict{
		Columns: []clause.Column{{Name: "name"}},
		DoUpdates: clause.AssignmentColumns([]string{
			"status", "last_synced_at", "last_error", "updated_at",
		}),
	}

	state := &domain.UserProjectionState{
		Name:         domain.UserProjectionNameKratos,
		Status:       status,
		LastSyncedAt: syncedAt,
		LastError:    lastError,
		UpdatedAt:    time.Now(),
	}
	return tx.Clauses(onNameConflict).Create(state).Error
}