첫 커밋: 로컬 프로젝트 업로드
This commit is contained in:
335
baron-sso/backend/internal/repository/user_repository.go
Normal file
335
baron-sso/backend/internal/repository/user_repository.go
Normal file
@@ -0,0 +1,335 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/pagination"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type UserRepository interface {
|
||||
Create(ctx context.Context, user *domain.User) error
|
||||
Update(ctx context.Context, user *domain.User) error
|
||||
FindByEmail(ctx context.Context, email string) (*domain.User, error)
|
||||
FindByID(ctx context.Context, id string) (*domain.User, error)
|
||||
FindByIDs(ctx context.Context, ids []string) ([]domain.User, error)
|
||||
ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error)
|
||||
List(ctx context.Context, offset, limit int, search string, tenantIDs []string, cursor string) ([]domain.User, int64, string, error)
|
||||
CountByTenant(ctx context.Context, tenantID string) (int64, error)
|
||||
CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error)
|
||||
CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error)
|
||||
FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error)
|
||||
FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error)
|
||||
Delete(ctx context.Context, id string) error
|
||||
DB() *gorm.DB
|
||||
|
||||
// Multiple identifiers support
|
||||
UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error
|
||||
GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error)
|
||||
IsLoginIDTaken(ctx context.Context, loginID string) (bool, error)
|
||||
FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error)
|
||||
}
|
||||
|
||||
type userRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewUserRepository(db *gorm.DB) UserRepository {
|
||||
return &userRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *userRepository) DB() *gorm.DB {
|
||||
return r.db
|
||||
}
|
||||
|
||||
func (r *userRepository) withTenantMembershipFilter(db *gorm.DB, tenantIDs []string) *gorm.DB {
|
||||
if len(tenantIDs) == 0 {
|
||||
return db
|
||||
}
|
||||
clauses := []string{"tenant_id IN ?"}
|
||||
args := []any{tenantIDs}
|
||||
for _, tenantID := range tenantIDs {
|
||||
tenantID = strings.TrimSpace(tenantID)
|
||||
if tenantID == "" {
|
||||
continue
|
||||
}
|
||||
payload, err := json.Marshal(map[string]any{
|
||||
"additionalAppointments": []map[string]string{
|
||||
{"tenantId": tenantID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
clauses = append(clauses, "metadata @> ?::jsonb")
|
||||
args = append(args, string(payload))
|
||||
}
|
||||
return db.Where("("+strings.Join(clauses, " OR ")+")", args...)
|
||||
}
|
||||
|
||||
func (r *userRepository) Create(ctx context.Context, user *domain.User) error {
|
||||
return r.db.WithContext(ctx).Create(user).Error
|
||||
}
|
||||
|
||||
func (r *userRepository) Update(ctx context.Context, user *domain.User) error {
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// 1. Check for email collision (including soft-deleted)
|
||||
var existing domain.User
|
||||
if err := tx.Unscoped().Where("email = ?", user.Email).First(&existing).Error; err == nil {
|
||||
// If email exists but ID is different, we MUST clear the old one to avoid unique constraint violation
|
||||
if existing.ID != user.ID {
|
||||
// [Restored] Check if the existing user is archived
|
||||
if strings.EqualFold(strings.TrimSpace(existing.Status), domain.UserStatusArchived) {
|
||||
return fmt.Errorf("email is reserved by archived user: %s", user.Email)
|
||||
}
|
||||
|
||||
// HARD DELETE the old record and its associated login IDs to free up the email and identifiers
|
||||
if err := tx.Unscoped().Where("user_id = ?", existing.ID).Delete(&domain.UserLoginID{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Unscoped().Delete(&domain.User{}, "id = ?", existing.ID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Perform Upsert on the new/target ID
|
||||
return tx.Unscoped().Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "id"}},
|
||||
DoUpdates: clause.Assignments(map[string]any{
|
||||
"email": user.Email,
|
||||
"name": user.Name,
|
||||
"phone": user.Phone,
|
||||
"role": user.Role,
|
||||
"status": user.Status,
|
||||
"department": user.Department,
|
||||
"grade": user.Grade,
|
||||
"position": user.Position,
|
||||
"job_title": user.JobTitle,
|
||||
"metadata": user.Metadata,
|
||||
"tenant_id": user.TenantID,
|
||||
"affiliation_type": user.AffiliationType,
|
||||
"updated_at": user.UpdatedAt,
|
||||
"deleted_at": nil, // Ensure it's active
|
||||
}),
|
||||
}).Create(user).Error
|
||||
})
|
||||
}
|
||||
|
||||
func (r *userRepository) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
|
||||
var user domain.User
|
||||
if err := r.db.WithContext(ctx).Preload("Tenant").Where("email = ?", email).First(&user).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) FindByID(ctx context.Context, id string) (*domain.User, error) {
|
||||
var user domain.User
|
||||
if err := r.db.WithContext(ctx).Preload("Tenant").Where("id = ?", id).First(&user).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
|
||||
var users []domain.User
|
||||
if len(ids) == 0 {
|
||||
return users, nil
|
||||
}
|
||||
if err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
|
||||
var users []domain.User
|
||||
if err := r.withTenantMembershipFilter(r.db.WithContext(ctx), []string{tenantID}).Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
|
||||
var count int64
|
||||
err := r.withTenantMembershipFilter(r.db.WithContext(ctx).Model(&domain.User{}), []string{tenantID}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *userRepository) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
|
||||
counts := make(map[string]int64)
|
||||
|
||||
if len(tenantIDs) == 0 {
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
for _, tenantID := range tenantIDs {
|
||||
var count int64
|
||||
if err := r.withTenantMembershipFilter(r.db.WithContext(ctx).Model(&domain.User{}), []string{tenantID}).Count(&count).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts[tenantID] = count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
|
||||
if len(codes) == 0 {
|
||||
return make(map[string]int64), nil
|
||||
}
|
||||
|
||||
type result struct {
|
||||
TenantSlug string
|
||||
Count int64
|
||||
}
|
||||
var results []result
|
||||
|
||||
lowerCodes := lowerStrings(codes)
|
||||
|
||||
if err := r.db.WithContext(ctx).Table("users").
|
||||
Select("LOWER(tenants.slug) AS tenant_slug, count(DISTINCT users.id) AS count").
|
||||
Joins("JOIN tenants ON users.tenant_id = tenants.id").
|
||||
Where("users.deleted_at IS NULL AND LOWER(tenants.slug) IN ?", lowerCodes).
|
||||
Group("LOWER(tenants.slug)").
|
||||
Scan(&results).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
counts := make(map[string]int64)
|
||||
for _, res := range results {
|
||||
counts[strings.ToLower(res.TenantSlug)] = res.Count
|
||||
}
|
||||
|
||||
// Ensure all requested codes are present in results (even if count is 0)
|
||||
for _, code := range codes {
|
||||
lower := strings.ToLower(strings.TrimSpace(code))
|
||||
if _, ok := counts[lower]; !ok {
|
||||
counts[lower] = 0
|
||||
}
|
||||
}
|
||||
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func lowerStrings(arr []string) []string {
|
||||
res := make([]string, len(arr))
|
||||
for i, s := range arr {
|
||||
res[i] = strings.ToLower(strings.TrimSpace(s))
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (r *userRepository) List(ctx context.Context, offset, limit int, search string, tenantIDs []string, cursorRaw string) ([]domain.User, int64, string, error) {
|
||||
var users []domain.User
|
||||
var total int64
|
||||
db := r.db.WithContext(ctx).Model(&domain.User{})
|
||||
|
||||
if len(tenantIDs) > 0 {
|
||||
db = r.withTenantMembershipFilter(db, tenantIDs)
|
||||
}
|
||||
|
||||
if search != "" {
|
||||
searchTerm := "%" + search + "%"
|
||||
db = db.Where("(users.email LIKE ? OR users.name LIKE ? OR users.metadata::text LIKE ?)",
|
||||
searchTerm, searchTerm, searchTerm)
|
||||
}
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
|
||||
if cursorRaw != "" {
|
||||
cursor, err := pagination.Decode(cursorRaw)
|
||||
if err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
db = pagination.ApplyCreatedAtIDCursor(db, cursor, "created_at", "id")
|
||||
} else {
|
||||
db = db.Offset(offset)
|
||||
}
|
||||
|
||||
if err := db.Order("created_at desc, id desc").Limit(limit + 1).Preload("Tenant").Find(&users).Error; err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
|
||||
var items []domain.User
|
||||
var nextCursor string
|
||||
if len(users) > limit {
|
||||
items = users[:limit]
|
||||
last := items[limit-1]
|
||||
nextCursor = pagination.Encode(last.CreatedAt, last.ID)
|
||||
} else {
|
||||
items = users
|
||||
}
|
||||
|
||||
return items, total, nextCursor, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&domain.User{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// [FIX] Use Unscoped to permanently delete existing login IDs for this user
|
||||
// This prevents unique constraint violations with soft-deleted records
|
||||
if err := tx.Unscoped().Where("user_id = ?", userID).Delete(&domain.UserLoginID{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Insert new login IDs if any
|
||||
if len(loginIDs) > 0 {
|
||||
if err := tx.Create(&loginIDs).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *userRepository) GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error) {
|
||||
var results []domain.UserLoginID
|
||||
if err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&results).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) IsLoginIDTaken(ctx context.Context, loginID string) (bool, error) {
|
||||
var count int64
|
||||
if err := r.db.WithContext(ctx).Model(&domain.UserLoginID{}).Where("login_id = ?", loginID).Count(&count).Error; err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error) {
|
||||
var record domain.UserLoginID
|
||||
if err := r.db.WithContext(ctx).Where("login_id = ?", loginID).First(&record).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
return record.TenantID, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
|
||||
var users []domain.User
|
||||
err := r.withTenantMembershipFilter(r.db.WithContext(ctx), tenantIDs).Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
func (r *userRepository) FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error) {
|
||||
var users []domain.User
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("JOIN tenants ON users.tenant_id = tenants.id").
|
||||
Where("LOWER(tenants.slug) IN ?", lowerStrings(codes)).
|
||||
Preload("Tenant").
|
||||
Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
Reference in New Issue
Block a user