package repository

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"strings"

	"baron-sso-backend/internal/domain"
	"baron-sso-backend/internal/pagination"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// UserRepository persists domain.User records and the login identifiers
// (domain.UserLoginID) attached to them.
type UserRepository interface {
	// Create inserts a new user.
	Create(ctx context.Context, user *domain.User) error
	// Update upserts user by ID, undeleting it if it was soft-deleted. If a
	// different user (soft-deleted or not) already holds the same non-empty
	// email, that user and its login IDs are hard-deleted first, unless that
	// user is archived, in which case an error is returned.
	Update(ctx context.Context, user *domain.User) error
	// FindByEmail returns the user with the given email, with Tenant preloaded.
	FindByEmail(ctx context.Context, email string) (*domain.User, error)
	// FindByID returns the user with the given ID, with Tenant preloaded.
	FindByID(ctx context.Context, id string) (*domain.User, error)
	// FindByIDs returns the users whose IDs are in ids.
	FindByIDs(ctx context.Context, ids []string) ([]domain.User, error)
	// ListByTenant returns the users that belong to tenantID, either as their
	// primary tenant or through an additional appointment.
	ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error)
	// List returns one page of users, the total number of matching users, and
	// a cursor for the next page ("" when there is none).
	List(ctx context.Context, offset, limit int, search string, tenantIDs []string, cursor string) ([]domain.User, int64, string, error)
	// CountByTenant counts the users that belong to tenantID.
	CountByTenant(ctx context.Context, tenantID string) (int64, error)
	// CountByTenantIDs counts the users of each tenant in tenantIDs.
	CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error)
	// CountByCompanyCodes counts users per tenant slug (company code).
	CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error)
	// FindByTenantIDs returns the users that belong to any of tenantIDs.
	FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error)
	// FindByCompanyCodes returns the users whose primary tenant slug is in codes.
	FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error)
	// Delete deletes the user with the given ID.
	Delete(ctx context.Context, id string) error
	// DB exposes the underlying *gorm.DB.
	DB() *gorm.DB

	// Multiple identifiers support

	// UpdateUserLoginIDs replaces all login IDs of userID with loginIDs.
	UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error
	// GetUserLoginIDs returns the login IDs of userID.
	GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error)
	// IsLoginIDTaken reports whether any user already has loginID.
	IsLoginIDTaken(ctx context.Context, loginID string) (bool, error)
	// FindTenantIDByLoginID returns the tenant ID recorded for loginID.
	FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error)
}

type userRepository struct {
	db *gorm.DB
}

// NewUserRepository returns a UserRepository backed by db.
func NewUserRepository(db *gorm.DB) UserRepository {
	return &userRepository{db: db}
}

// DB exposes the underlying *gorm.DB.
func (r *userRepository) DB() *gorm.DB {
	return r.db
}

// withTenantMembershipFilter restricts db to users that belong to any of
// tenantIDs. A user belongs to a tenant when tenant_id is in the list, or when
// its metadata JSONB contains an "additionalAppointments" entry with a matching
// "tenantId" (PostgreSQL @> containment). Blank IDs are ignored for the
// metadata match. An empty tenantIDs leaves db unfiltered.
func (r *userRepository) withTenantMembershipFilter(db *gorm.DB, tenantIDs []string) *gorm.DB {
	if len(tenantIDs) == 0 {
		return db
	}

	clauses := []string{"tenant_id IN ?"}
	args := []any{tenantIDs}
	for _, tenantID := range tenantIDs {
		tenantID = strings.TrimSpace(tenantID)
		if tenantID == "" {
			continue
		}
		payload, err := json.Marshal(map[string]any{
			"additionalAppointments": []map[string]string{
				{"tenantId": tenantID},
			},
		})
		if err != nil {
			// Marshalling plain strings should not fail; skip this ID rather than abort.
			continue
		}
		clauses = append(clauses, "metadata @> ?::jsonb")
		args = append(args, string(payload))
	}
	return db.Where("("+strings.Join(clauses, " OR ")+")", args...)
}

// Create inserts user.
func (r *userRepository) Create(ctx context.Context, user *domain.User) error {
	return r.db.WithContext(ctx).Create(user).Error
}

// Update upserts user by ID inside a transaction.
//
// Before the upsert, a different user row (including soft-deleted ones)
// holding the same email is hard-deleted together with its login IDs, so the
// email and identifiers can be reused. If that row is archived, an error is
// returned instead. Users with an empty email skip this check, so that unrelated
// rows without an email are never deleted.
func (r *userRepository) Update(ctx context.Context, user *domain.User) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// 1. Check for email collision (including soft-deleted).
		if user.Email != "" {
			var existing domain.User
			err := tx.Unscoped().Where("email = ?", user.Email).First(&existing).Error
			switch {
			case errors.Is(err, gorm.ErrRecordNotFound):
				// Email is free; nothing to clear.
			case err != nil:
				// Surface real DB failures instead of proceeding on a possibly aborted transaction.
				return fmt.Errorf("checking email collision: %w", err)
			case existing.ID != user.ID:
				// Archived users keep their email reserved.
				if strings.EqualFold(strings.TrimSpace(existing.Status), domain.UserStatusArchived) {
					return fmt.Errorf("email is reserved by archived user: %s", user.Email)
				}
				// HARD DELETE the old record and its login IDs to free up the email and identifiers.
				if err := tx.Unscoped().Where("user_id = ?", existing.ID).Delete(&domain.UserLoginID{}).Error; err != nil {
					return err
				}
				if err := tx.Unscoped().Delete(&domain.User{}, "id = ?", existing.ID).Error; err != nil {
					return err
				}
			}
		}

		// 2. Perform upsert on the new/target ID. On conflict the listed columns
		// are overwritten and deleted_at is cleared so the row becomes active.
		// NOTE(review): updated_at is written from user.UpdatedAt as-is; GORM
		// only auto-fills it on insert when zero, so callers presumably must set
		// it to bump the timestamp on update — confirm.
		return tx.Unscoped().Clauses(clause.OnConflict{
			Columns: []clause.Column{{Name: "id"}},
			DoUpdates: clause.Assignments(map[string]any{
				"email":            user.Email,
				"name":             user.Name,
				"phone":            user.Phone,
				"role":             user.Role,
				"status":           user.Status,
				"department":       user.Department,
				"grade":            user.Grade,
				"position":         user.Position,
				"job_title":        user.JobTitle,
				"metadata":         user.Metadata,
				"tenant_id":        user.TenantID,
				"affiliation_type": user.AffiliationType,
				"updated_at":       user.UpdatedAt,
				"deleted_at":       nil, // Ensure it's active
			}),
		}).Create(user).Error
	})
}

// FindByEmail returns the user with the given email, with Tenant preloaded.
// It returns gorm.ErrRecordNotFound when no such user exists.
func (r *userRepository) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
	var user domain.User
	if err := r.db.WithContext(ctx).Preload("Tenant").Where("email = ?", email).First(&user).Error; err != nil {
		return nil, err
	}
	return &user, nil
}

// FindByID returns the user with the given ID, with Tenant preloaded.
// It returns gorm.ErrRecordNotFound when no such user exists.
func (r *userRepository) FindByID(ctx context.Context, id string) (*domain.User, error) {
	var user domain.User
	if err := r.db.WithContext(ctx).Preload("Tenant").Where("id = ?", id).First(&user).Error; err != nil {
		return nil, err
	}
	return &user, nil
}

// FindByIDs returns the users whose IDs are in ids. An empty ids returns a
// nil slice without querying.
func (r *userRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
	var users []domain.User
	if len(ids) == 0 {
		return users, nil
	}
	if err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&users).Error; err != nil {
		return nil, err
	}
	return users, nil
}

// ListByTenant returns the users that belong to tenantID, either as their
// primary tenant or via an additional appointment in metadata.
func (r *userRepository) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
	var users []domain.User
	if err := r.withTenantMembershipFilter(r.db.WithContext(ctx), []string{tenantID}).Find(&users).Error; err != nil {
		return nil, err
	}
	return users, nil
}

// CountByTenant counts the users that belong to tenantID, using the same
// membership rule as ListByTenant.
func (r *userRepository) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
	var count int64
	err := r.withTenantMembershipFilter(r.db.WithContext(ctx).Model(&domain.User{}), []string{tenantID}).Count(&count).Error
	return count, err
}

// CountByTenantIDs counts the members of each tenant in tenantIDs, issuing one
// count query per tenant. The result is keyed by the tenant ID as given.
func (r *userRepository) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
	counts := make(map[string]int64, len(tenantIDs))
	for _, tenantID := range tenantIDs {
		var count int64
		if err := r.withTenantMembershipFilter(r.db.WithContext(ctx).Model(&domain.User{}), []string{tenantID}).Count(&count).Error; err != nil {
			return nil, err
		}
		counts[tenantID] = count
	}
	return counts, nil
}

// CountByCompanyCodes counts non-deleted users whose primary tenant's slug
// matches one of codes, case-insensitively and ignoring surrounding whitespace.
// Unlike CountByTenantIDs, additional appointments in metadata are not counted.
// The map is keyed by the lower-cased, trimmed code and has an entry (possibly
// 0) for every requested code.
func (r *userRepository) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
	if len(codes) == 0 {
		return make(map[string]int64), nil
	}

	type result struct {
		TenantSlug string
		Count      int64
	}
	var results []result

	lowerCodes := lowerStrings(codes)
	// Table() bypasses GORM's soft-delete scope, hence the explicit deleted_at filter.
	if err := r.db.WithContext(ctx).Table("users").
		Select("LOWER(tenants.slug) AS tenant_slug, count(DISTINCT users.id) AS count").
		Joins("JOIN tenants ON users.tenant_id = tenants.id").
		Where("users.deleted_at IS NULL AND LOWER(tenants.slug) IN ?", lowerCodes).
		Group("LOWER(tenants.slug)").
		Scan(&results).Error; err != nil {
		return nil, err
	}

	counts := make(map[string]int64, len(lowerCodes))
	for _, res := range results {
		counts[strings.ToLower(res.TenantSlug)] = res.Count
	}
	// Ensure all requested codes are present in results (even if count is 0).
	for _, code := range lowerCodes {
		if _, ok := counts[code]; !ok {
			counts[code] = 0
		}
	}
	return counts, nil
}

// lowerStrings returns a new slice with each element trimmed and lower-cased.
func lowerStrings(arr []string) []string {
	res := make([]string, len(arr))
	for i, s := range arr {
		res[i] = strings.ToLower(strings.TrimSpace(s))
	}
	return res
}

// List returns one page of users ordered by created_at desc, id desc, with
// Tenant preloaded.
//
// Users are filtered by tenant membership when tenantIDs is non-empty. When
// search is non-empty, they are also filtered by a LIKE match on email, name,
// or metadata text. total counts every matching user regardless of paging.
// Paging uses cursorRaw (a keyset cursor) when it is non-empty; otherwise it
// uses offset. The returned cursor is "" when no further page exists. A
// non-positive limit yields an empty page.
func (r *userRepository) List(ctx context.Context, offset, limit int, search string, tenantIDs []string, cursorRaw string) ([]domain.User, int64, string, error) {
	db := r.db.WithContext(ctx).Model(&domain.User{})
	if len(tenantIDs) > 0 {
		db = r.withTenantMembershipFilter(db, tenantIDs)
	}
	if search != "" {
		searchTerm := "%" + search + "%"
		db = db.Where("(users.email LIKE ? OR users.name LIKE ? OR users.metadata::text LIKE ?)", searchTerm, searchTerm, searchTerm)
	}

	var total int64
	if err := db.Count(&total).Error; err != nil {
		return nil, 0, "", err
	}

	if cursorRaw != "" {
		cursor, err := pagination.Decode(cursorRaw)
		if err != nil {
			return nil, 0, "", err
		}
		db = pagination.ApplyCreatedAtIDCursor(db, cursor, "created_at", "id")
	} else {
		db = db.Offset(offset)
	}

	// Without this guard, items[limit-1] below would index out of range.
	if limit <= 0 {
		return []domain.User{}, total, "", nil
	}

	// Fetch one extra row to learn whether another page exists.
	var users []domain.User
	if err := db.Order("created_at desc, id desc").Limit(limit + 1).Preload("Tenant").Find(&users).Error; err != nil {
		return nil, 0, "", err
	}

	if len(users) <= limit {
		return users, total, "", nil
	}
	items := users[:limit]
	last := items[limit-1]
	return items, total, pagination.Encode(last.CreatedAt, last.ID), nil
}

// Delete deletes the user with the given ID. Without Unscoped, GORM performs
// a soft delete when the model has a DeletedAt field (the deleted_at column
// used elsewhere in this file suggests it does — confirm). The user's login
// IDs are not removed here.
func (r *userRepository) Delete(ctx context.Context, id string) error {
	return r.db.WithContext(ctx).Delete(&domain.User{}, "id = ?", id).Error
}

// UpdateUserLoginIDs replaces all login IDs of userID in one transaction. It
// hard-deletes the existing rows and then inserts loginIDs as given.
// NOTE(review): the UserID of each entry in loginIDs is not checked against
// userID; callers are presumably expected to set it.
func (r *userRepository) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// Unscoped: permanently delete, so soft-deleted rows cannot trigger
		// unique constraint violations on re-insert.
		if err := tx.Unscoped().Where("user_id = ?", userID).Delete(&domain.UserLoginID{}).Error; err != nil {
			return err
		}
		if len(loginIDs) > 0 {
			if err := tx.Create(&loginIDs).Error; err != nil {
				return err
			}
		}
		return nil
	})
}

// GetUserLoginIDs returns the login IDs recorded for userID.
func (r *userRepository) GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error) {
	var results []domain.UserLoginID
	if err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&results).Error; err != nil {
		return nil, err
	}
	return results, nil
}

// IsLoginIDTaken reports whether at least one login ID row equals loginID.
func (r *userRepository) IsLoginIDTaken(ctx context.Context, loginID string) (bool, error) {
	var count int64
	if err := r.db.WithContext(ctx).Model(&domain.UserLoginID{}).Where("login_id = ?", loginID).Count(&count).Error; err != nil {
		return false, err
	}
	return count > 0, nil
}

// FindTenantIDByLoginID returns the TenantID stored on the first login ID row
// equal to loginID, or gorm.ErrRecordNotFound when there is none.
func (r *userRepository) FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error) {
	var record domain.UserLoginID
	if err := r.db.WithContext(ctx).Where("login_id = ?", loginID).First(&record).Error; err != nil {
		return "", err
	}
	return record.TenantID, nil
}

// FindByTenantIDs returns the users that belong to any of tenantIDs (primary
// tenant or additional appointment). An empty tenantIDs returns no users. It
// deliberately does not fall through to the unfiltered query, which would
// return every user across all tenants.
func (r *userRepository) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
	if len(tenantIDs) == 0 {
		return []domain.User{}, nil
	}
	var users []domain.User
	err := r.withTenantMembershipFilter(r.db.WithContext(ctx), tenantIDs).Find(&users).Error
	return users, err
}

// FindByCompanyCodes returns the users whose primary tenant's slug matches one
// of codes (case-insensitive, trimmed), with Tenant preloaded.
func (r *userRepository) FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error) {
	var users []domain.User
	err := r.db.WithContext(ctx).
		Joins("JOIN tenants ON users.tenant_id = tenants.id").
		Where("LOWER(tenants.slug) IN ?", lowerStrings(codes)).
		Preload("Tenant").
		Find(&users).Error
	return users, err
}