1
0
forked from baron/baron-sso
Files
baron-sso/backend/internal/repository/user_repository.go
chan 31d107ff2e feat(user): support fixed UUID registration and enhance bulk import results
- Added support for fixed UUIDs during bulk registration (Search-first + ExternalID mapping)
- Implemented idempotency and visibility restoration for soft-deleted users
- Enhanced bulk upload UI to show 'New/Updated/Unchanged' status and modified fields
- Added logic to reclaim identifiers (login_id) from colliding records
- Added frontend E2E and backend unit tests for UUID integrity and conflict handling
- Fixed i18n, formatting, and mock tests to satisfy code-check
- Applied 'go fix' for 'omitzero' tags and general Go standards
2026-06-01 15:34:08 +09:00

307 lines
9.7 KiB
Go

package repository
import (
	"context"
	"errors"
	"fmt"
	"strings"

	"baron-sso-backend/internal/domain"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)
// UserRepository is the persistence boundary for users and the login
// identifiers attached to them.
type UserRepository interface {
	// DB returns the underlying GORM handle.
	DB() *gorm.DB

	// Create inserts a new user.
	Create(ctx context.Context, user *domain.User) error
	// Update writes user by ID, inserting it if no row with that ID exists.
	Update(ctx context.Context, user *domain.User) error
	// Delete removes the user with the given ID.
	Delete(ctx context.Context, id string) error

	// FindByID returns the user with the given ID.
	FindByID(ctx context.Context, id string) (*domain.User, error)
	// FindByIDs returns every user whose ID is in ids.
	FindByIDs(ctx context.Context, ids []string) ([]domain.User, error)
	// FindByEmail returns the user with the given email.
	FindByEmail(ctx context.Context, email string) (*domain.User, error)
	// FindByTenantIDs returns every user belonging to one of tenantIDs.
	FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error)
	// FindByCompanyCodes returns every user whose tenant slug matches one of
	// codes.
	FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error)
	// ListByTenant returns every user belonging to tenantID.
	ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error)
	// List returns one page of users, optionally filtered, plus the total
	// number of matches.
	List(ctx context.Context, offset, limit int, search string, tenantSlug string) ([]domain.User, int64, error)

	// CountByTenant counts the users belonging to tenantID.
	CountByTenant(ctx context.Context, tenantID string) (int64, error)
	// CountByTenantIDs counts users per tenant ID.
	CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error)
	// CountByCompanyCodes counts users per (lower-cased) tenant slug.
	CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error)

	// A user may own several login identifiers.

	// UpdateUserLoginIDs replaces all login IDs of userID with loginIDs.
	UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error
	// GetUserLoginIDs returns the login IDs owned by userID.
	GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error)
	// IsLoginIDTaken reports whether loginID is already assigned.
	IsLoginIDTaken(ctx context.Context, loginID string) (bool, error)
	// FindTenantIDByLoginID returns the tenant ID recorded for loginID.
	FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error)
}
// userRepository is the GORM-backed implementation of UserRepository.
type userRepository struct {
	// db is the base handle; methods derive per-call sessions from it with
	// WithContext.
	db *gorm.DB
}
// NewUserRepository returns a UserRepository that stores users through db.
func NewUserRepository(db *gorm.DB) UserRepository {
	repo := &userRepository{db: db}
	return repo
}
// DB returns the *gorm.DB the repository was constructed with. No context is
// attached to it.
func (r *userRepository) DB() *gorm.DB {
	handle := r.db
	return handle
}
// Create inserts user as a new row and returns any database error.
func (r *userRepository) Create(ctx context.Context, user *domain.User) error {
	session := r.db.WithContext(ctx)
	result := session.Create(user)
	return result.Error
}
// Update writes user by primary key inside a single transaction.
//
// Before the write, releaseEmailTx frees user.Email. An archived record
// holding the same email aborts the update. Any other holder (soft-deleted
// rows included) is hard-deleted together with its login IDs.
//
// The write is an INSERT ... ON CONFLICT (id) DO UPDATE. It overwrites the
// listed profile columns and sets deleted_at to NULL, so a soft-deleted row
// with the same ID becomes active again.
//
// NOTE(review): updated_at is taken from user.UpdatedAt as it is when the
// assignment map is built. If callers do not set it, an update does not
// advance updated_at. Confirm against callers.
func (r *userRepository) Update(ctx context.Context, user *domain.User) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		if err := r.releaseEmailTx(tx, user); err != nil {
			return err
		}
		return tx.Unscoped().Clauses(clause.OnConflict{
			Columns: []clause.Column{{Name: "id"}},
			DoUpdates: clause.Assignments(map[string]any{
				"email":            user.Email,
				"name":             user.Name,
				"phone":            user.Phone,
				"role":             user.Role,
				"status":           user.Status,
				"department":       user.Department,
				"grade":            user.Grade,
				"position":         user.Position,
				"job_title":        user.JobTitle,
				"metadata":         user.Metadata,
				"tenant_id":        user.TenantID,
				"affiliation_type": user.AffiliationType,
				"updated_at":       user.UpdatedAt,
				"deleted_at":       nil, // Ensure it's active
			}),
		}).Create(user).Error
	})
}

// releaseEmailTx makes user.Email available to user.ID within tx.
//
// It looks up a record holding the same email, soft-deleted rows included.
// It does nothing if no such record exists or if the record is user itself.
// It returns an error if the holder is archived. Otherwise it HARD-deletes the
// holder and its login IDs, so neither collides with the upcoming write.
//
// A blank email is never reclaimed. Matching on "" would select an unrelated
// account that simply has no email and permanently delete it.
func (r *userRepository) releaseEmailTx(tx *gorm.DB, user *domain.User) error {
	if strings.TrimSpace(user.Email) == "" {
		return nil
	}

	var existing domain.User
	err := tx.Unscoped().Where("email = ?", user.Email).First(&existing).Error
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil
	case err != nil:
		// Do not swallow lookup failures. On PostgreSQL a failed statement
		// aborts the transaction, and the upsert would fail with a less
		// useful error.
		return fmt.Errorf("checking email collision for %q: %w", user.Email, err)
	case existing.ID == user.ID:
		return nil
	}

	// Archived users keep their email reserved.
	if strings.EqualFold(strings.TrimSpace(existing.Status), domain.UserStatusArchived) {
		return fmt.Errorf("email is reserved by archived user: %s", user.Email)
	}

	// Hard delete (Unscoped) so unique constraints no longer see the old
	// login IDs or email.
	if err := tx.Unscoped().Where("user_id = ?", existing.ID).Delete(&domain.UserLoginID{}).Error; err != nil {
		return err
	}
	return tx.Unscoped().Delete(&domain.User{}, "id = ?", existing.ID).Error
}
func (r *userRepository) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
var user domain.User
if err := r.db.WithContext(ctx).Preload("Tenant").Where("email = ?", email).First(&user).Error; err != nil {
return nil, err
}
return &user, nil
}
func (r *userRepository) FindByID(ctx context.Context, id string) (*domain.User, error) {
var user domain.User
if err := r.db.WithContext(ctx).Preload("Tenant").Where("id = ?", id).First(&user).Error; err != nil {
return nil, err
}
return &user, nil
}
// FindByIDs returns the users whose IDs appear in ids, in database order.
// An empty ids returns (nil, nil) without querying the database.
func (r *userRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
	if len(ids) == 0 {
		return nil, nil
	}
	var found []domain.User
	err := r.db.WithContext(ctx).Find(&found, "id IN ?", ids).Error
	if err != nil {
		return nil, err
	}
	return found, nil
}
// ListByTenant returns every user whose tenant_id equals tenantID, with no
// pagination. On error it returns nil.
func (r *userRepository) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
	var members []domain.User
	err := r.db.WithContext(ctx).Find(&members, "tenant_id = ?", tenantID).Error
	if err != nil {
		return nil, err
	}
	return members, nil
}
// CountByTenant returns the number of users whose tenant_id equals tenantID.
func (r *userRepository) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
	var n int64
	scoped := r.db.WithContext(ctx).
		Model(&domain.User{}).
		Where("tenant_id = ?", tenantID)
	if err := scoped.Count(&n).Error; err != nil {
		return n, err
	}
	return n, nil
}
// CountByTenantIDs returns a map from each requested tenant ID to its number
// of users. Every ID in tenantIDs is present in the result, with 0 for tenants
// that have no users. An empty input returns an empty, non-nil map without
// querying the database.
func (r *userRepository) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
	type tenantCount struct {
		TenantID *string
		Count    int64
	}

	counts := make(map[string]int64)
	if len(tenantIDs) == 0 {
		return counts, nil
	}

	var rows []tenantCount
	err := r.db.WithContext(ctx).
		Model(&domain.User{}).
		Select("tenant_id, count(*) as count").
		Where("tenant_id IN ?", tenantIDs).
		Group("tenant_id").
		Find(&rows).Error
	if err != nil {
		return nil, err
	}

	// Seed every requested ID with zero, then overwrite with real counts.
	for _, id := range tenantIDs {
		counts[id] = 0
	}
	for _, row := range rows {
		if row.TenantID == nil || *row.TenantID == "" {
			continue
		}
		counts[*row.TenantID] = row.Count
	}
	return counts, nil
}
// CountByCompanyCodes counts non-deleted users per tenant slug. Matching is
// case-insensitive, and surrounding whitespace in codes is ignored.
//
// The returned map is keyed by the lower-cased, trimmed code. Every requested
// code is present, with 0 when no user matches. An empty input returns an
// empty, non-nil map without querying the database.
func (r *userRepository) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
	counts := make(map[string]int64)
	if len(codes) == 0 {
		return counts, nil
	}

	type slugCount struct {
		TenantSlug string
		Count      int64
	}
	var rows []slugCount
	normalized := lowerStrings(codes)

	err := r.db.WithContext(ctx).Table("users").
		Select("LOWER(tenants.slug) AS tenant_slug, count(DISTINCT users.id) AS count").
		Joins("JOIN tenants ON users.tenant_id = tenants.id").
		Where("users.deleted_at IS NULL AND LOWER(tenants.slug) IN ?", normalized).
		Group("LOWER(tenants.slug)").
		Scan(&rows).Error
	if err != nil {
		return nil, err
	}

	// Zero-fill every requested code first, then apply the actual counts.
	for _, code := range normalized {
		counts[code] = 0
	}
	for _, row := range rows {
		counts[strings.ToLower(row.TenantSlug)] = row.Count
	}
	return counts, nil
}
// lowerStrings returns a new slice with each element of arr trimmed of
// surrounding whitespace and lower-cased. arr is not modified, and the result
// is never nil, even for a nil input.
func lowerStrings(arr []string) []string {
	out := make([]string, 0, len(arr))
	for _, s := range arr {
		out = append(out, strings.ToLower(strings.TrimSpace(s)))
	}
	return out
}
func (r *userRepository) List(ctx context.Context, offset, limit int, search string, tenantSlug string) ([]domain.User, int64, error) {
var users []domain.User
var total int64
db := r.db.WithContext(ctx).Model(&domain.User{})
if tenantSlug != "" {
db = db.Joins("LEFT JOIN tenants ON users.tenant_id = tenants.id").
Where("tenants.slug = ?", tenantSlug)
}
if search != "" {
searchTerm := "%" + search + "%"
db = db.Where("(users.email LIKE ? OR users.name LIKE ? OR users.metadata::text LIKE ?)",
searchTerm, searchTerm, searchTerm)
}
if err := db.Count(&total).Error; err != nil {
return nil, 0, err
}
if err := db.Offset(offset).Limit(limit).Preload("Tenant").Find(&users).Error; err != nil {
return nil, 0, err
}
return users, total, nil
}
// Delete removes the user with the given id through GORM's Delete without
// Unscoped. That is a soft delete when domain.User has a DeletedAt field
// (presumed here; confirm against the model). The user's login IDs are not
// touched.
func (r *userRepository) Delete(ctx context.Context, id string) error {
	result := r.db.WithContext(ctx).
		Where("id = ?", id).
		Delete(&domain.User{})
	return result.Error
}
// UpdateUserLoginIDs atomically replaces the login IDs of userID with loginIDs.
//
// Existing rows for the user are hard-deleted (Unscoped), so soft-deleted
// leftovers cannot trip unique constraints. Then loginIDs is inserted as given.
// The rows' own user/tenant fields are not set here; the caller fills them in.
func (r *userRepository) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
	replace := func(tx *gorm.DB) error {
		purge := tx.Unscoped().Where("user_id = ?", userID).Delete(&domain.UserLoginID{})
		if purge.Error != nil {
			return purge.Error
		}
		if len(loginIDs) == 0 {
			return nil
		}
		return tx.Create(&loginIDs).Error
	}
	return r.db.WithContext(ctx).Transaction(replace)
}
// GetUserLoginIDs returns the login ID rows whose user_id equals userID.
// On error it returns nil.
func (r *userRepository) GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error) {
	var owned []domain.UserLoginID
	err := r.db.WithContext(ctx).Find(&owned, "user_id = ?", userID).Error
	if err != nil {
		return nil, err
	}
	return owned, nil
}
// IsLoginIDTaken reports whether at least one login ID row has login_id equal
// to loginID.
//
// NOTE(review): the query runs in GORM's default scope. If UserLoginID is
// soft-deletable, soft-deleted rows are not counted even though they may still
// occupy a unique index.
func (r *userRepository) IsLoginIDTaken(ctx context.Context, loginID string) (bool, error) {
	var matches int64
	err := r.db.WithContext(ctx).
		Model(&domain.UserLoginID{}).
		Where("login_id = ?", loginID).
		Count(&matches).Error
	if err != nil {
		return false, err
	}
	return matches > 0, nil
}
func (r *userRepository) FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error) {
var record domain.UserLoginID
if err := r.db.WithContext(ctx).Where("login_id = ?", loginID).First(&record).Error; err != nil {
return "", err
}
return record.TenantID, nil
}
// FindByTenantIDs returns every user whose tenant_id is in tenantIDs. The
// result slice is returned alongside any query error.
func (r *userRepository) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
	var users []domain.User
	result := r.db.WithContext(ctx).Find(&users, "tenant_id IN ?", tenantIDs)
	return users, result.Error
}
// FindByCompanyCodes returns users whose tenant slug matches one of codes.
// Matching is case-insensitive and ignores surrounding whitespace in codes.
// Each user has its Tenant association preloaded.
func (r *userRepository) FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error) {
	slugs := lowerStrings(codes)
	var users []domain.User
	query := r.db.WithContext(ctx).
		Preload("Tenant").
		Joins("JOIN tenants ON users.tenant_id = tenants.id").
		Where("LOWER(tenants.slug) IN ?", slugs)
	err := query.Find(&users).Error
	return users, err
}