forked from baron/baron-sso
197 lines
5.5 KiB
Go
197 lines
5.5 KiB
Go
package repository
|
|
|
|
import (
|
|
"baron-sso-backend/internal/domain"
|
|
"context"
|
|
"strings"
|
|
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
type UserRepository interface {
|
|
Create(ctx context.Context, user *domain.User) error
|
|
Update(ctx context.Context, user *domain.User) error
|
|
FindByEmail(ctx context.Context, email string) (*domain.User, error)
|
|
FindByID(ctx context.Context, id string) (*domain.User, error)
|
|
FindByIDs(ctx context.Context, ids []string) ([]domain.User, error)
|
|
ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error)
|
|
List(ctx context.Context, offset, limit int, search string, companyCode string) ([]domain.User, int64, error)
|
|
CountByTenant(ctx context.Context, tenantID string) (int64, error)
|
|
CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error)
|
|
CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error)
|
|
Delete(ctx context.Context, id string) error
|
|
}
|
|
|
|
type userRepository struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
func NewUserRepository(db *gorm.DB) UserRepository {
|
|
return &userRepository{db: db}
|
|
}
|
|
|
|
func (r *userRepository) Create(ctx context.Context, user *domain.User) error {
|
|
return r.db.WithContext(ctx).Create(user).Error
|
|
}
|
|
|
|
func (r *userRepository) Update(ctx context.Context, user *domain.User) error {
|
|
return r.db.WithContext(ctx).Save(user).Error
|
|
}
|
|
|
|
func (r *userRepository) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
|
|
var user domain.User
|
|
if err := r.db.WithContext(ctx).Preload("Tenant").Where("email = ?", email).First(&user).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
func (r *userRepository) FindByID(ctx context.Context, id string) (*domain.User, error) {
|
|
var user domain.User
|
|
if err := r.db.WithContext(ctx).Preload("Tenant").Where("id = ?", id).First(&user).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
func (r *userRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
|
|
var users []domain.User
|
|
if len(ids) == 0 {
|
|
return users, nil
|
|
}
|
|
if err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&users).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return users, nil
|
|
}
|
|
|
|
func (r *userRepository) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
|
|
var users []domain.User
|
|
if err := r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Find(&users).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return users, nil
|
|
}
|
|
|
|
func (r *userRepository) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
|
|
var count int64
|
|
err := r.db.WithContext(ctx).Model(&domain.User{}).Where("tenant_id = ?", tenantID).Count(&count).Error
|
|
return count, err
|
|
}
|
|
|
|
func (r *userRepository) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
|
|
type result struct {
|
|
TenantID string
|
|
Count int64
|
|
}
|
|
var results []result
|
|
if len(tenantIDs) == 0 {
|
|
return make(map[string]int64), nil
|
|
}
|
|
if err := r.db.WithContext(ctx).Model(&domain.User{}).
|
|
Select("tenant_id, count(*) as count").
|
|
Where("tenant_id IN ?", tenantIDs).
|
|
Group("tenant_id").
|
|
Find(&results).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
counts := make(map[string]int64)
|
|
for _, res := range results {
|
|
if res.TenantID != "" {
|
|
counts[res.TenantID] = res.Count
|
|
}
|
|
}
|
|
// Ensure all requested tenant IDs are in the map, even if count is 0
|
|
for _, id := range tenantIDs {
|
|
if _, ok := counts[id]; !ok {
|
|
counts[id] = 0
|
|
}
|
|
}
|
|
return counts, nil
|
|
}
|
|
|
|
func (r *userRepository) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
|
|
if len(codes) == 0 {
|
|
return make(map[string]int64), nil
|
|
}
|
|
|
|
// 1. Resolve IDs for these codes to support dual counting (slug or ID)
|
|
var tenants []domain.Tenant
|
|
_ = r.db.WithContext(ctx).Where("slug IN ?", codes).Find(&tenants).Error
|
|
|
|
idToSlug := make(map[string]string)
|
|
slugToNormalized := make(map[string]string)
|
|
|
|
for _, code := range codes {
|
|
slugToNormalized[strings.ToLower(strings.TrimSpace(code))] = code
|
|
}
|
|
for _, t := range tenants {
|
|
idToSlug[t.ID] = t.Slug
|
|
}
|
|
|
|
type result struct {
|
|
CompanyCode string
|
|
TenantID string
|
|
Count int64
|
|
}
|
|
var results []result
|
|
|
|
// Use a more comprehensive aggregation
|
|
err := r.db.WithContext(ctx).Model(&domain.User{}).
|
|
Select("company_code, tenant_id, count(*) as count").
|
|
Where("company_code IN ? OR tenant_id IN (SELECT id FROM tenants WHERE slug IN ?)", codes, codes).
|
|
Group("company_code, tenant_id").
|
|
Scan(&results).Error
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
counts := make(map[string]int64)
|
|
for _, res := range results {
|
|
var slug string
|
|
if res.CompanyCode != "" {
|
|
slug = res.CompanyCode
|
|
} else if res.TenantID != "" {
|
|
slug = idToSlug[res.TenantID]
|
|
}
|
|
|
|
if slug != "" {
|
|
normalizedSlug := strings.ToLower(strings.TrimSpace(slug))
|
|
counts[normalizedSlug] += res.Count
|
|
}
|
|
}
|
|
|
|
return counts, nil
|
|
}
|
|
|
|
func (r *userRepository) List(ctx context.Context, offset, limit int, search string, companyCode string) ([]domain.User, int64, error) {
|
|
var users []domain.User
|
|
var total int64
|
|
db := r.db.WithContext(ctx).Model(&domain.User{})
|
|
|
|
if companyCode != "" {
|
|
db = db.Where("company_code = ?", companyCode)
|
|
}
|
|
|
|
if search != "" {
|
|
searchTerm := "%" + search + "%"
|
|
db = db.Where("(email LIKE ? OR name LIKE ? OR company_code LIKE ?)", searchTerm, searchTerm, searchTerm)
|
|
}
|
|
|
|
if err := db.Count(&total).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
if err := db.Offset(offset).Limit(limit).Find(&users).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
return users, total, nil
|
|
}
|
|
|
|
func (r *userRepository) Delete(ctx context.Context, id string) error {
|
|
return r.db.WithContext(ctx).Delete(&domain.User{}, "id = ?", id).Error
|
|
}
|