package repository

import (
	"context"
	"errors"
	"fmt"
	"strings"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"

	"baron-sso-backend/internal/domain"
)

// UserRepository persists domain.User records and their associated
// domain.UserLoginID identifiers.
type UserRepository interface {
	// Create inserts user as a new row.
	Create(ctx context.Context, user *domain.User) error
	// Update upserts user by ID inside a transaction and clears deleted_at,
	// restoring a soft-deleted user. If another row (soft-deleted or not)
	// already holds the same non-empty email, that row and its login IDs are
	// hard-deleted first, unless that row is archived, in which case an error
	// is returned and nothing is written.
	Update(ctx context.Context, user *domain.User) error
	// FindByEmail returns the user with the given email, with Tenant preloaded.
	// It returns gorm.ErrRecordNotFound when no row matches.
	FindByEmail(ctx context.Context, email string) (*domain.User, error)
	// FindByID returns the user with the given ID, with Tenant preloaded.
	// It returns gorm.ErrRecordNotFound when no row matches.
	FindByID(ctx context.Context, id string) (*domain.User, error)
	// FindByIDs returns the users whose IDs appear in ids (Tenant is not
	// preloaded). An empty ids yields an empty result without a query.
	FindByIDs(ctx context.Context, ids []string) ([]domain.User, error)
	// ListByTenant returns every user whose tenant_id equals tenantID.
	ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error)
	// List returns one page (offset/limit) of users with Tenant preloaded,
	// plus the total number of matching users. A non-empty search matches
	// substrings of email, name, or the text form of metadata via SQL LIKE;
	// a non-empty tenantSlug restricts results to that exact tenant slug.
	List(ctx context.Context, offset, limit int, search string, tenantSlug string) ([]domain.User, int64, error)
	// CountByTenant returns the number of users whose tenant_id equals tenantID.
	CountByTenant(ctx context.Context, tenantID string) (int64, error)
	// CountByTenantIDs returns a user count for every requested tenant ID;
	// tenants without users map to 0.
	CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error)
	// CountByCompanyCodes returns a user count per company code, matching the
	// trimmed codes case-insensitively against tenant slugs. Keys are the
	// lowercased, trimmed codes; every requested code is present (0 if unused).
	CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error)
	// FindByTenantIDs returns the users belonging to any of tenantIDs.
	FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error)
	// FindByCompanyCodes returns the users whose tenant slug matches one of
	// codes case-insensitively (codes are trimmed), with Tenant preloaded.
	FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error)
	// Delete soft-deletes the user with the given ID.
	Delete(ctx context.Context, id string) error
	// DB exposes the underlying *gorm.DB handle.
	DB() *gorm.DB

	// Multiple identifiers support

	// UpdateUserLoginIDs atomically replaces all login IDs stored for userID
	// with loginIDs. Existing rows are hard-deleted so that soft-deleted rows
	// cannot trip unique constraints on re-insert.
	UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error
	// GetUserLoginIDs returns the login IDs stored for userID.
	GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error)
	// IsLoginIDTaken reports whether any stored login ID equals loginID.
	IsLoginIDTaken(ctx context.Context, loginID string) (bool, error)
	// FindTenantIDByLoginID returns the tenant ID recorded on the login ID row
	// matching loginID, or gorm.ErrRecordNotFound when there is none.
	FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error)
}

// userRepository is the GORM implementation of UserRepository.
type userRepository struct {
	db *gorm.DB
}

// NewUserRepository returns a UserRepository backed by db.
func NewUserRepository(db *gorm.DB) UserRepository {
	return &userRepository{db: db}
}

// DB returns the underlying *gorm.DB.
func (r *userRepository) DB() *gorm.DB {
	return r.db
}

// Create inserts user as a new row.
func (r *userRepository) Create(ctx context.Context, user *domain.User) error {
	return r.db.WithContext(ctx).Create(user).Error
}

// Update frees user's email from any other row and then upserts user by ID,
// all inside one transaction.
func (r *userRepository) Update(ctx context.Context, user *domain.User) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// 1. Check for email collision (including soft-deleted rows).
		// An empty email is never treated as a collision: otherwise an unrelated
		// user that also has no email would be hard-deleted below.
		if user.Email != "" {
			var existing domain.User
			err := tx.Unscoped().Where("email = ?", user.Email).First(&existing).Error
			switch {
			case errors.Is(err, gorm.ErrRecordNotFound):
				// No row holds this email; nothing to free.
			case err != nil:
				// Do not swallow real lookup failures: on PostgreSQL a failed
				// statement also aborts the transaction, so the upsert would
				// fail anyway with a less informative error.
				return fmt.Errorf("checking email collision for %q: %w", user.Email, err)
			case existing.ID != user.ID:
				// Archived users keep their email reserved.
				if strings.EqualFold(strings.TrimSpace(existing.Status), domain.UserStatusArchived) {
					return fmt.Errorf("email is reserved by archived user: %s", user.Email)
				}
				// HARD DELETE the old record and its login IDs to free the email
				// and identifiers for the unique constraints.
				// NOTE(review): this also removes an active (non-archived) user
				// that owns the email; confirm that is intended.
				if err := tx.Unscoped().Where("user_id = ?", existing.ID).Delete(&domain.UserLoginID{}).Error; err != nil {
					return err
				}
				if err := tx.Unscoped().Delete(&domain.User{}, "id = ?", existing.ID).Error; err != nil {
					return err
				}
			}
		}

		// 2. Upsert on the target ID.
		updates := clause.Assignments(map[string]any{
			"email":            user.Email,
			"name":             user.Name,
			"phone":            user.Phone,
			"role":             user.Role,
			"status":           user.Status,
			"department":       user.Department,
			"grade":            user.Grade,
			"position":         user.Position,
			"job_title":        user.JobTitle,
			"metadata":         user.Metadata,
			"tenant_id":        user.TenantID,
			"affiliation_type": user.AffiliationType,
			"deleted_at":       nil, // ensure the row is active again
		})
		// updated_at takes the value actually inserted (excluded.updated_at)
		// rather than user.UpdatedAt captured here: GORM fills a zero UpdatedAt
		// with the current time during Create, which happens after this map is
		// built, so the captured value would write the zero time on conflict.
		// NOTE(review): a non-zero UpdatedAt from a previously loaded user is
		// still written as-is and therefore does not advance.
		updates = append(updates, clause.AssignmentColumns([]string{"updated_at"})...)

		return tx.Unscoped().Clauses(clause.OnConflict{
			Columns:   []clause.Column{{Name: "id"}},
			DoUpdates: updates,
		}).Create(user).Error
	})
}

// FindByEmail returns the user with the given email, with Tenant preloaded.
func (r *userRepository) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
	var user domain.User
	if err := r.db.WithContext(ctx).Preload("Tenant").Where("email = ?", email).First(&user).Error; err != nil {
		return nil, err
	}
	return &user, nil
}

// FindByID returns the user with the given ID, with Tenant preloaded.
func (r *userRepository) FindByID(ctx context.Context, id string) (*domain.User, error) {
	var user domain.User
	if err := r.db.WithContext(ctx).Preload("Tenant").Where("id = ?", id).First(&user).Error; err != nil {
		return nil, err
	}
	return &user, nil
}

// FindByIDs returns the users whose IDs appear in ids.
func (r *userRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
	var users []domain.User
	if len(ids) == 0 {
		return users, nil
	}
	if err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&users).Error; err != nil {
		return nil, err
	}
	return users, nil
}

// ListByTenant returns every user whose tenant_id equals tenantID.
func (r *userRepository) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
	var users []domain.User
	if err := r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Find(&users).Error; err != nil {
		return nil, err
	}
	return users, nil
}

// CountByTenant returns the number of users whose tenant_id equals tenantID.
func (r *userRepository) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
	var count int64
	err := r.db.WithContext(ctx).Model(&domain.User{}).Where("tenant_id = ?", tenantID).Count(&count).Error
	return count, err
}

// CountByTenantIDs returns a user count for every requested tenant ID,
// filling in 0 for tenants that have no users.
func (r *userRepository) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
	type result struct {
		TenantID *string // nullable column
		Count    int64
	}

	counts := make(map[string]int64, len(tenantIDs))
	if len(tenantIDs) == 0 {
		return counts, nil
	}

	var results []result
	if err := r.db.WithContext(ctx).Model(&domain.User{}).
		Select("tenant_id, count(*) as count").
		Where("tenant_id IN ?", tenantIDs).
		Group("tenant_id").
		Find(&results).Error; err != nil {
		return nil, err
	}

	for _, res := range results {
		if res.TenantID != nil && *res.TenantID != "" {
			counts[*res.TenantID] = res.Count
		}
	}
	// Ensure all requested tenant IDs are in the map, even if count is 0.
	for _, id := range tenantIDs {
		if _, ok := counts[id]; !ok {
			counts[id] = 0
		}
	}
	return counts, nil
}

// CountByCompanyCodes returns a user count per company code. Codes are
// trimmed and matched case-insensitively against tenants.slug; map keys are
// the lowercased, trimmed codes.
func (r *userRepository) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
	if len(codes) == 0 {
		return make(map[string]int64), nil
	}

	type result struct {
		TenantSlug string
		Count      int64
	}

	var results []result
	// Table("users") bypasses the model's soft-delete scope, hence the explicit
	// deleted_at filter.
	if err := r.db.WithContext(ctx).Table("users").
		Select("LOWER(tenants.slug) AS tenant_slug, count(DISTINCT users.id) AS count").
		Joins("JOIN tenants ON users.tenant_id = tenants.id").
		Where("users.deleted_at IS NULL AND LOWER(tenants.slug) IN ?", lowerStrings(codes)).
		Group("LOWER(tenants.slug)").
		Scan(&results).Error; err != nil {
		return nil, err
	}

	counts := make(map[string]int64, len(codes))
	for _, res := range results {
		counts[strings.ToLower(res.TenantSlug)] = res.Count
	}
	// Ensure all requested codes are present in results (even if count is 0).
	for _, code := range codes {
		lower := strings.ToLower(strings.TrimSpace(code))
		if _, ok := counts[lower]; !ok {
			counts[lower] = 0
		}
	}
	return counts, nil
}

// lowerStrings returns a new slice holding each element of arr trimmed of
// surrounding whitespace and lowercased.
func lowerStrings(arr []string) []string {
	res := make([]string, len(arr))
	for i, s := range arr {
		res[i] = strings.ToLower(strings.TrimSpace(s))
	}
	return res
}

// List returns one page of users plus the total count of matching users.
func (r *userRepository) List(ctx context.Context, offset, limit int, search string, tenantSlug string) ([]domain.User, int64, error) {
	var (
		users []domain.User
		total int64
	)

	query := r.db.WithContext(ctx).Model(&domain.User{})
	if tenantSlug != "" {
		query = query.Joins("LEFT JOIN tenants ON users.tenant_id = tenants.id").
			Where("tenants.slug = ?", tenantSlug)
	}
	if search != "" {
		// Escape LIKE metacharacters so the search text matches literally
		// (e.g. "a_b" must not match "axb"). Backslash is PostgreSQL's default
		// LIKE escape character; the ::text cast already assumes PostgreSQL.
		escaped := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(search)
		pattern := "%" + escaped + "%"
		query = query.Where("(users.email LIKE ? OR users.name LIKE ? OR users.metadata::text LIKE ?)", pattern, pattern, pattern)
	}

	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}
	if err := query.Offset(offset).Limit(limit).Preload("Tenant").Find(&users).Error; err != nil {
		return nil, 0, err
	}
	return users, total, nil
}

// Delete soft-deletes the user with the given ID.
func (r *userRepository) Delete(ctx context.Context, id string) error {
	return r.db.WithContext(ctx).Delete(&domain.User{}, "id = ?", id).Error
}

// UpdateUserLoginIDs replaces all login IDs of userID with loginIDs in one
// transaction. The entries are inserted as given, so callers must populate
// their user ID themselves.
func (r *userRepository) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// Unscoped = permanent delete, so soft-deleted rows cannot cause
		// unique constraint violations when the new IDs are inserted.
		if err := tx.Unscoped().Where("user_id = ?", userID).Delete(&domain.UserLoginID{}).Error; err != nil {
			return err
		}
		if len(loginIDs) == 0 {
			return nil
		}
		return tx.Create(&loginIDs).Error
	})
}

// GetUserLoginIDs returns the login IDs stored for userID.
func (r *userRepository) GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error) {
	var results []domain.UserLoginID
	if err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&results).Error; err != nil {
		return nil, err
	}
	return results, nil
}

// IsLoginIDTaken reports whether any stored login ID equals loginID.
func (r *userRepository) IsLoginIDTaken(ctx context.Context, loginID string) (bool, error) {
	var count int64
	if err := r.db.WithContext(ctx).Model(&domain.UserLoginID{}).Where("login_id = ?", loginID).Count(&count).Error; err != nil {
		return false, err
	}
	return count > 0, nil
}

// FindTenantIDByLoginID returns the tenant ID recorded on the login ID row
// matching loginID.
func (r *userRepository) FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error) {
	var record domain.UserLoginID
	if err := r.db.WithContext(ctx).Where("login_id = ?", loginID).First(&record).Error; err != nil {
		return "", err
	}
	return record.TenantID, nil
}

// FindByTenantIDs returns the users belonging to any of tenantIDs. An empty
// tenantIDs returns an empty (non-nil) slice without querying the database.
func (r *userRepository) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
	if len(tenantIDs) == 0 {
		return make([]domain.User, 0), nil
	}
	var users []domain.User
	err := r.db.WithContext(ctx).Where("tenant_id IN ?", tenantIDs).Find(&users).Error
	return users, err
}

// FindByCompanyCodes returns the users whose tenant slug matches one of codes
// case-insensitively, with Tenant preloaded. An empty codes returns an empty
// (non-nil) slice without querying the database.
func (r *userRepository) FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error) {
	if len(codes) == 0 {
		return make([]domain.User, 0), nil
	}
	var users []domain.User
	err := r.db.WithContext(ctx).
		Joins("JOIN tenants ON users.tenant_id = tenants.id").
		Where("LOWER(tenants.slug) IN ?", lowerStrings(codes)).
		Preload("Tenant").
		Find(&users).Error
	return users, err
}