package repository import ( "baron-sso-backend/internal/domain" "context" "strings" "gorm.io/gorm" "gorm.io/gorm/clause" ) type UserRepository interface { Create(ctx context.Context, user *domain.User) error Update(ctx context.Context, user *domain.User) error FindByEmail(ctx context.Context, email string) (*domain.User, error) FindByID(ctx context.Context, id string) (*domain.User, error) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) List(ctx context.Context, offset, limit int, search string, companyCode string) ([]domain.User, int64, error) CountByTenant(ctx context.Context, tenantID string) (int64, error) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) Delete(ctx context.Context, id string) error } type userRepository struct { db *gorm.DB } func NewUserRepository(db *gorm.DB) UserRepository { return &userRepository{db: db} } func (r *userRepository) Create(ctx context.Context, user *domain.User) error { return r.db.WithContext(ctx).Create(user).Error } func (r *userRepository) Update(ctx context.Context, user *domain.User) error { // Use Upsert logic: if email exists, update all fields return r.db.WithContext(ctx).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "email"}}, UpdateAll: true, }).Save(user).Error } func (r *userRepository) FindByEmail(ctx context.Context, email string) (*domain.User, error) { var user domain.User if err := r.db.WithContext(ctx).Preload("Tenant").Where("email = ?", email).First(&user).Error; err != nil { return nil, err } return &user, nil } func (r *userRepository) FindByID(ctx context.Context, id string) (*domain.User, error) { var user domain.User if err := r.db.WithContext(ctx).Preload("Tenant").Where("id = ?", id).First(&user).Error; err != nil { return nil, err } return &user, nil } func (r *userRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) { var users []domain.User if len(ids) == 0 { return users, nil } if err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&users).Error; err != nil { return nil, err } return users, nil } func (r *userRepository) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) { var users []domain.User if err := r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Find(&users).Error; err != nil { return nil, err } return users, nil } func (r *userRepository) CountByTenant(ctx context.Context, tenantID string) (int64, error) { var count int64 err := r.db.WithContext(ctx).Model(&domain.User{}).Where("tenant_id = ?", tenantID).Count(&count).Error return count, err } func (r *userRepository) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) { type result struct { TenantID *string Count int64 } var results []result counts := make(map[string]int64) if len(tenantIDs) == 0 { return counts, nil } if err := r.db.WithContext(ctx).Model(&domain.User{}). Select("tenant_id, count(*) as count"). Where("tenant_id IN ?", tenantIDs). Group("tenant_id"). Find(&results).Error; err != nil { return nil, err } for _, res := range results { if res.TenantID != nil && *res.TenantID != "" { counts[*res.TenantID] = res.Count } } // Ensure all requested tenant IDs are in the map, even if count is 0 for _, id := range tenantIDs { if _, ok := counts[id]; !ok { counts[id] = 0 } } return counts, nil } func (r *userRepository) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) { if len(codes) == 0 { return make(map[string]int64), nil } type result struct { CompanyCode string Count int64 } var results []result // Search by company_code directly. Normalize inputs using LOWER for robust matching. err := r.db.WithContext(ctx).Model(&domain.User{}). Select("LOWER(company_code) as company_code, count(*) as count"). Where("LOWER(company_code) IN ?", lowerStrings(codes)). Group("LOWER(company_code)"). Scan(&results).Error if err != nil { return nil, err } counts := make(map[string]int64) for _, res := range results { counts[res.CompanyCode] = res.Count } // Ensure all requested codes are present in results (even if count is 0) for _, code := range codes { lower := strings.ToLower(strings.TrimSpace(code)) if _, ok := counts[lower]; !ok { counts[lower] = 0 } } return counts, nil } func lowerStrings(arr []string) []string { res := make([]string, len(arr)) for i, s := range arr { res[i] = strings.ToLower(strings.TrimSpace(s)) } return res } func (r *userRepository) List(ctx context.Context, offset, limit int, search string, companyCode string) ([]domain.User, int64, error) { var users []domain.User var total int64 db := r.db.WithContext(ctx).Model(&domain.User{}) if companyCode != "" { db = db.Where("company_code = ?", companyCode) } if search != "" { searchTerm := "%" + search + "%" // Search in basic fields and metadata (PostgreSQL JSONB) db = db.Where("(email LIKE ? OR name LIKE ? OR company_code LIKE ? OR metadata::text LIKE ?)", searchTerm, searchTerm, searchTerm, searchTerm) } if err := db.Count(&total).Error; err != nil { return nil, 0, err } if err := db.Offset(offset).Limit(limit).Find(&users).Error; err != nil { return nil, 0, err } return users, total, nil } func (r *userRepository) Delete(ctx context.Context, id string) error { return r.db.WithContext(ctx).Delete(&domain.User{}, "id = ?", id).Error }