forked from baron/baron-sso
172 lines
4.9 KiB
Go
172 lines
4.9 KiB
Go
package repository
|
|
|
|
import (
|
|
"baron-sso-backend/internal/domain"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
)
|
|
|
|
type UserProjectionRepository interface {
|
|
IsReady(ctx context.Context) (bool, error)
|
|
GetStatus(ctx context.Context) (domain.UserProjectionStatus, error)
|
|
CountTenantMembers(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error)
|
|
ReplaceAllFromKratos(ctx context.Context, users []domain.User) error
|
|
MarkFailed(ctx context.Context, syncErr error) error
|
|
}
|
|
|
|
type userProjectionRepository struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
func NewUserProjectionRepository(db *gorm.DB) UserProjectionRepository {
|
|
return &userProjectionRepository{db: db}
|
|
}
|
|
|
|
func (r *userProjectionRepository) IsReady(ctx context.Context) (bool, error) {
|
|
status, err := r.GetStatus(ctx)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return status.Ready, nil
|
|
}
|
|
|
|
func (r *userProjectionRepository) GetStatus(ctx context.Context) (domain.UserProjectionStatus, error) {
|
|
var projectedUsers int64
|
|
if err := r.db.WithContext(ctx).Model(&domain.User{}).Count(&projectedUsers).Error; err != nil {
|
|
return domain.UserProjectionStatus{}, err
|
|
}
|
|
|
|
var state domain.UserProjectionState
|
|
err := r.db.WithContext(ctx).First(&state, "name = ?", domain.UserProjectionNameKratos).Error
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return domain.UserProjectionStatus{
|
|
Name: domain.UserProjectionNameKratos,
|
|
Status: domain.UserProjectionStatusFailed,
|
|
Ready: false,
|
|
ProjectedUsers: projectedUsers,
|
|
}, nil
|
|
}
|
|
if err != nil {
|
|
return domain.UserProjectionStatus{}, err
|
|
}
|
|
return domain.UserProjectionStatus{
|
|
Name: state.Name,
|
|
Status: state.Status,
|
|
Ready: state.Status == domain.UserProjectionStatusReady && state.LastSyncedAt != nil,
|
|
LastSyncedAt: state.LastSyncedAt,
|
|
LastError: state.LastError,
|
|
UpdatedAt: &state.UpdatedAt,
|
|
ProjectedUsers: projectedUsers,
|
|
}, nil
|
|
}
|
|
|
|
func (r *userProjectionRepository) CountTenantMembers(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
|
|
counts := make(map[string]int64, len(tenants))
|
|
for _, tenant := range tenants {
|
|
counts[tenant.ID] = 0
|
|
}
|
|
if len(tenants) == 0 {
|
|
return counts, nil
|
|
}
|
|
|
|
valuePlaceholders := make([]string, 0, len(tenants))
|
|
args := make([]interface{}, 0, len(tenants)*2)
|
|
for _, tenant := range tenants {
|
|
valuePlaceholders = append(valuePlaceholders, "(?, ?)")
|
|
args = append(args, strings.TrimSpace(tenant.ID), strings.TrimSpace(tenant.Slug))
|
|
}
|
|
|
|
query := fmt.Sprintf(`
|
|
WITH requested(tenant_id, slug) AS (
|
|
VALUES %s
|
|
)
|
|
SELECT requested.tenant_id, COUNT(DISTINCT users.id) AS count
|
|
FROM requested
|
|
LEFT JOIN users ON users.deleted_at IS NULL AND (
|
|
users.tenant_id::text = requested.tenant_id
|
|
)
|
|
GROUP BY requested.tenant_id
|
|
`, strings.Join(valuePlaceholders, ","))
|
|
|
|
type result struct {
|
|
TenantID string
|
|
Count int64
|
|
}
|
|
var rows []result
|
|
if err := r.db.WithContext(ctx).Raw(query, args...).Scan(&rows).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
for _, row := range rows {
|
|
counts[row.TenantID] = row.Count
|
|
}
|
|
return counts, nil
|
|
}
|
|
|
|
func (r *userProjectionRepository) ReplaceAllFromKratos(ctx context.Context, users []domain.User) error {
|
|
now := time.Now()
|
|
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
ids := make([]string, 0, len(users))
|
|
for i := range users {
|
|
users[i].DeletedAt = gorm.DeletedAt{}
|
|
if users[i].CreatedAt.IsZero() {
|
|
users[i].CreatedAt = now
|
|
}
|
|
if users[i].UpdatedAt.IsZero() {
|
|
users[i].UpdatedAt = now
|
|
}
|
|
ids = append(ids, users[i].ID)
|
|
}
|
|
|
|
if len(users) > 0 {
|
|
if err := tx.Clauses(clause.OnConflict{
|
|
Columns: []clause.Column{{Name: "id"}},
|
|
UpdateAll: true,
|
|
}).Create(&users).Error; err != nil {
|
|
return err
|
|
}
|
|
if err := tx.Where("id NOT IN ?", ids).Delete(&domain.User{}).Error; err != nil {
|
|
return err
|
|
}
|
|
} else if err := tx.Where("1 = 1").Delete(&domain.User{}).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
return upsertUserProjectionState(tx, domain.UserProjectionStatusReady, &now, "")
|
|
})
|
|
}
|
|
|
|
func (r *userProjectionRepository) MarkFailed(ctx context.Context, syncErr error) error {
|
|
message := ""
|
|
if syncErr != nil {
|
|
message = syncErr.Error()
|
|
}
|
|
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
return upsertUserProjectionState(tx, domain.UserProjectionStatusFailed, nil, message)
|
|
})
|
|
}
|
|
|
|
func upsertUserProjectionState(tx *gorm.DB, status string, syncedAt *time.Time, lastError string) error {
|
|
state := domain.UserProjectionState{
|
|
Name: domain.UserProjectionNameKratos,
|
|
Status: status,
|
|
LastSyncedAt: syncedAt,
|
|
LastError: lastError,
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
return tx.Clauses(clause.OnConflict{
|
|
Columns: []clause.Column{{Name: "name"}},
|
|
DoUpdates: clause.AssignmentColumns([]string{
|
|
"status",
|
|
"last_synced_at",
|
|
"last_error",
|
|
"updated_at",
|
|
}),
|
|
}).Create(&state).Error
|
|
}
|