forked from baron/baron-sso
288 lines
9.4 KiB
Go
288 lines
9.4 KiB
Go
package repository
|
|
|
|
import (
|
|
"baron-sso-backend/internal/domain"
|
|
"context"
|
|
"strings"
|
|
|
|
"github.com/lib/pq"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
)
|
|
|
|
type UserRepository interface {
|
|
Create(ctx context.Context, user *domain.User) error
|
|
Update(ctx context.Context, user *domain.User) error
|
|
FindByEmail(ctx context.Context, email string) (*domain.User, error)
|
|
FindByID(ctx context.Context, id string) (*domain.User, error)
|
|
FindByIDs(ctx context.Context, ids []string) ([]domain.User, error)
|
|
ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error)
|
|
List(ctx context.Context, offset, limit int, search string, companyCode string) ([]domain.User, int64, error)
|
|
CountByTenant(ctx context.Context, tenantID string) (int64, error)
|
|
CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error)
|
|
CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error)
|
|
FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error)
|
|
FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error)
|
|
Delete(ctx context.Context, id string) error
|
|
|
|
// Multiple identifiers support
|
|
UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error
|
|
GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error)
|
|
IsLoginIDTaken(ctx context.Context, loginID string) (bool, error)
|
|
FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error)
|
|
}
|
|
|
|
type userRepository struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
func NewUserRepository(db *gorm.DB) UserRepository {
|
|
return &userRepository{db: db}
|
|
}
|
|
|
|
func (r *userRepository) Create(ctx context.Context, user *domain.User) error {
|
|
return r.db.WithContext(ctx).Create(user).Error
|
|
}
|
|
|
|
func (r *userRepository) Update(ctx context.Context, user *domain.User) error {
|
|
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
// 1. Resolve email conflicts: If another user in the local DB has this email but a different ID,
|
|
// we must remove the old local record because Kratos is the source of truth for ID <-> Email mapping.
|
|
var existing domain.User
|
|
if err := tx.Unscoped().Where("email = ?", user.Email).First(&existing).Error; err == nil {
|
|
if existing.ID != user.ID {
|
|
// Delete associated login IDs first to prevent FK constraint violation
|
|
if err := tx.Unscoped().Where("user_id = ?", existing.ID).Delete(&domain.UserLoginID{}).Error; err != nil {
|
|
return err
|
|
}
|
|
// Different ID holds this email locally. Hard delete the old record to avoid constraint violation.
|
|
if err := tx.Unscoped().Delete(&domain.User{}, "id = ?", existing.ID).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
// 2. Perform Upsert based on ID.
|
|
// In GORM v2, true upsert requires Create() with OnConflict on the primary key.
|
|
return tx.Clauses(clause.OnConflict{
|
|
Columns: []clause.Column{{Name: "id"}},
|
|
UpdateAll: true,
|
|
}).Create(user).Error
|
|
})
|
|
}
|
|
|
|
func (r *userRepository) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
|
|
var user domain.User
|
|
if err := r.db.WithContext(ctx).Preload("Tenant").Where("email = ?", email).First(&user).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
func (r *userRepository) FindByID(ctx context.Context, id string) (*domain.User, error) {
|
|
var user domain.User
|
|
if err := r.db.WithContext(ctx).Preload("Tenant").Where("id = ?", id).First(&user).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
func (r *userRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
|
|
var users []domain.User
|
|
if len(ids) == 0 {
|
|
return users, nil
|
|
}
|
|
if err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&users).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return users, nil
|
|
}
|
|
|
|
func (r *userRepository) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
|
|
var users []domain.User
|
|
if err := r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Find(&users).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return users, nil
|
|
}
|
|
|
|
func (r *userRepository) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
|
|
var count int64
|
|
err := r.db.WithContext(ctx).Model(&domain.User{}).Where("tenant_id = ?", tenantID).Count(&count).Error
|
|
return count, err
|
|
}
|
|
|
|
func (r *userRepository) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
|
|
type result struct {
|
|
TenantID *string
|
|
Count int64
|
|
}
|
|
var results []result
|
|
counts := make(map[string]int64)
|
|
|
|
if len(tenantIDs) == 0 {
|
|
return counts, nil
|
|
}
|
|
|
|
if err := r.db.WithContext(ctx).Model(&domain.User{}).
|
|
Select("tenant_id, count(*) as count").
|
|
Where("tenant_id IN ?", tenantIDs).
|
|
Group("tenant_id").
|
|
Find(&results).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, res := range results {
|
|
if res.TenantID != nil && *res.TenantID != "" {
|
|
counts[*res.TenantID] = res.Count
|
|
}
|
|
}
|
|
// Ensure all requested tenant IDs are in the map, even if count is 0
|
|
for _, id := range tenantIDs {
|
|
if _, ok := counts[id]; !ok {
|
|
counts[id] = 0
|
|
}
|
|
}
|
|
return counts, nil
|
|
}
|
|
|
|
func (r *userRepository) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
|
|
if len(codes) == 0 {
|
|
return make(map[string]int64), nil
|
|
}
|
|
|
|
type result struct {
|
|
CompanyCode string
|
|
Count int64
|
|
}
|
|
var results []result
|
|
|
|
lowerCodes := lowerStrings(codes)
|
|
|
|
// Combine singular company_code and array company_codes using a subquery
|
|
// to ensure we count each user accurately per company code they belong to.
|
|
query := `
|
|
SELECT LOWER(comp_code) as company_code, count(DISTINCT id) as count
|
|
FROM (
|
|
SELECT id, company_code as comp_code FROM users WHERE deleted_at IS NULL AND LOWER(company_code) = ANY($1)
|
|
UNION ALL
|
|
SELECT id, unnest(company_codes) as comp_code FROM users WHERE deleted_at IS NULL AND company_codes IS NOT NULL
|
|
) as combined
|
|
WHERE LOWER(comp_code) = ANY($1)
|
|
GROUP BY LOWER(comp_code)
|
|
`
|
|
err := r.db.WithContext(ctx).Raw(query, pq.Array(lowerCodes)).Scan(&results).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
counts := make(map[string]int64)
|
|
for _, res := range results {
|
|
counts[strings.ToLower(res.CompanyCode)] = res.Count
|
|
}
|
|
|
|
// Ensure all requested codes are present in results (even if count is 0)
|
|
for _, code := range codes {
|
|
lower := strings.ToLower(strings.TrimSpace(code))
|
|
if _, ok := counts[lower]; !ok {
|
|
counts[lower] = 0
|
|
}
|
|
}
|
|
|
|
return counts, nil
|
|
}
|
|
|
|
// lowerStrings returns a new slice holding each element of arr with
// surrounding whitespace trimmed and then lower-cased. The result is never
// nil, even for a nil or empty arr, and arr itself is not modified.
func lowerStrings(arr []string) []string {
	out := make([]string, 0, len(arr))
	for _, raw := range arr {
		trimmed := strings.TrimSpace(raw)
		out = append(out, strings.ToLower(trimmed))
	}
	return out
}
|
|
|
|
func (r *userRepository) List(ctx context.Context, offset, limit int, search string, companyCode string) ([]domain.User, int64, error) {
|
|
var users []domain.User
|
|
var total int64
|
|
db := r.db.WithContext(ctx).Model(&domain.User{})
|
|
|
|
if companyCode != "" {
|
|
// [Matrix Fix] Match users either by their primary company code OR by being in the company_codes array OR by tenant slug
|
|
db = db.Joins("LEFT JOIN tenants ON users.tenant_id = tenants.id").
|
|
Where("users.company_code = ? OR ? = ANY(users.company_codes) OR tenants.slug = ?", companyCode, companyCode, companyCode)
|
|
}
|
|
|
|
if search != "" {
|
|
searchTerm := "%" + search + "%"
|
|
db = db.Where("(users.email LIKE ? OR users.name LIKE ? OR users.company_code LIKE ? OR ? = ANY(users.company_codes) OR users.metadata::text LIKE ?)",
|
|
searchTerm, searchTerm, searchTerm, search, searchTerm)
|
|
}
|
|
|
|
if err := db.Count(&total).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
if err := db.Offset(offset).Limit(limit).Preload("Tenant").Find(&users).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
return users, total, nil
|
|
}
|
|
|
|
func (r *userRepository) Delete(ctx context.Context, id string) error {
|
|
return r.db.WithContext(ctx).Delete(&domain.User{}, "id = ?", id).Error
|
|
}
|
|
|
|
func (r *userRepository) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
|
|
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
// Delete existing login IDs for this user
|
|
if err := tx.Where("user_id = ?", userID).Delete(&domain.UserLoginID{}).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
// Insert new login IDs if any
|
|
if len(loginIDs) > 0 {
|
|
if err := tx.Create(&loginIDs).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (r *userRepository) GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error) {
|
|
var results []domain.UserLoginID
|
|
if err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&results).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
func (r *userRepository) IsLoginIDTaken(ctx context.Context, loginID string) (bool, error) {
|
|
var count int64
|
|
if err := r.db.WithContext(ctx).Model(&domain.UserLoginID{}).Where("login_id = ?", loginID).Count(&count).Error; err != nil {
|
|
return false, err
|
|
}
|
|
return count > 0, nil
|
|
}
|
|
|
|
func (r *userRepository) FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error) {
|
|
var record domain.UserLoginID
|
|
if err := r.db.WithContext(ctx).Where("login_id = ?", loginID).First(&record).Error; err != nil {
|
|
return "", err
|
|
}
|
|
return record.TenantID, nil
|
|
}
|
|
|
|
func (r *userRepository) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
|
|
var users []domain.User
|
|
err := r.db.WithContext(ctx).Where("tenant_id IN ?", tenantIDs).Find(&users).Error
|
|
return users, err
|
|
}
|
|
|
|
func (r *userRepository) FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error) {
|
|
var users []domain.User
|
|
err := r.db.WithContext(ctx).Where("company_code IN ?", codes).Find(&users).Error
|
|
return users, err
|
|
}
|