package repository

import (
	"context"
	"errors"
	"strings"

	"baron-sso-backend/internal/domain"

	"github.com/lib/pq"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// UserRepository persists domain.User records and their login identifiers.
type UserRepository interface {
	// Create inserts a new user.
	Create(ctx context.Context, user *domain.User) error
	// Update upserts user by ID. Before the upsert it hard-deletes any other
	// local user (and that user's login IDs) holding the same non-empty email
	// under a different ID.
	Update(ctx context.Context, user *domain.User) error
	// FindByEmail returns the user with the given email, with Tenant preloaded.
	FindByEmail(ctx context.Context, email string) (*domain.User, error)
	// FindByID returns the user with the given ID, with Tenant preloaded.
	FindByID(ctx context.Context, id string) (*domain.User, error)
	// FindByIDs returns the users whose IDs are in ids. An empty ids returns
	// without querying the database.
	FindByIDs(ctx context.Context, ids []string) ([]domain.User, error)
	// ListByTenant returns every user whose tenant_id equals tenantID.
	ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error)
	// List returns one page of users together with the total number of
	// matching users, optionally filtered by companyCode and search.
	List(ctx context.Context, offset, limit int, search string, companyCode string) ([]domain.User, int64, error)
	// CountByTenant returns the number of users whose tenant_id equals tenantID.
	CountByTenant(ctx context.Context, tenantID string) (int64, error)
	// CountByTenantIDs returns a user count per tenant ID. Every requested ID
	// is present in the result, with 0 when it has no users.
	CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error)
	// CountByCompanyCodes returns a distinct-user count per company code,
	// matching case-insensitively against both company_code and the
	// company_codes array. Result keys are the requested codes, trimmed and
	// lower-cased; every requested code is present, with 0 when unmatched.
	CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error)
	// FindByTenantIDs returns the users whose tenant_id is in tenantIDs.
	FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error)
	// FindByCompanyCodes returns the users whose company_code is exactly one
	// of codes.
	FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error)
	// Delete deletes the user with the given ID.
	Delete(ctx context.Context, id string) error

	// Multiple identifiers support

	// UpdateUserLoginIDs replaces all login IDs of userID with loginIDs in a
	// single transaction.
	UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error
	// GetUserLoginIDs returns the login ID records stored for userID.
	GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error)
	// IsLoginIDTaken reports whether a login ID record with value loginID exists.
	IsLoginIDTaken(ctx context.Context, loginID string) (bool, error)
	// FindTenantIDByLoginID returns the TenantID stored on the first login ID
	// record whose value is loginID.
	FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error)
}

// userRepository is the GORM-backed implementation of UserRepository.
type userRepository struct {
	db *gorm.DB
}

// NewUserRepository returns a UserRepository that stores users through db.
func NewUserRepository(db *gorm.DB) UserRepository {
	return &userRepository{db: db}
}

// Create inserts user as a new row.
func (r *userRepository) Create(ctx context.Context, user *domain.User) error {
	return r.db.WithContext(ctx).Create(user).Error
}

// Update upserts user by primary key inside one transaction.
//
// Kratos is the source of truth for the ID <-> email mapping. If another local
// row (soft-deleted rows included) holds the same email under a different ID,
// that row and its login IDs are hard-deleted first so the upsert cannot hit
// an email constraint violation.
func (r *userRepository) Update(ctx context.Context, user *domain.User) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// 1. Resolve email conflicts with other local records.
		if err := evictEmailConflict(tx, user); err != nil {
			return err
		}

		// 2. Perform Upsert based on ID.
		// In GORM v2, true upsert requires Create() with OnConflict on the primary key.
		return tx.Clauses(clause.OnConflict{
			Columns:   []clause.Column{{Name: "id"}},
			UpdateAll: true,
		}).Create(user).Error
	})
}

// evictEmailConflict hard-deletes the local user that owns user.Email under a
// different ID, together with that user's login IDs.
//
// It does nothing when no such user exists, when the email already belongs to
// user.ID, or when the email is empty. Empty emails are excluded in SQL so
// that users without an email never evict each other. Lookup errors other than
// "record not found" are returned instead of being ignored.
func evictEmailConflict(tx *gorm.DB, user *domain.User) error {
	var existing domain.User
	err := tx.Unscoped().
		Where("email = ? AND email <> ''", user.Email).
		First(&existing).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil
	}
	if err != nil {
		return err
	}
	if existing.ID == user.ID {
		return nil
	}

	// Delete associated login IDs first to prevent FK constraint violation.
	if err := tx.Unscoped().Where("user_id = ?", existing.ID).Delete(&domain.UserLoginID{}).Error; err != nil {
		return err
	}
	// Hard delete (Unscoped) so the old row cannot keep occupying the email.
	return tx.Unscoped().Delete(&domain.User{}, "id = ?", existing.ID).Error
}

// FindByEmail returns the user whose email equals email, with Tenant
// preloaded. A missing user surfaces as gorm.ErrRecordNotFound.
func (r *userRepository) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
	var user domain.User
	if err := r.db.WithContext(ctx).Preload("Tenant").Where("email = ?", email).First(&user).Error; err != nil {
		return nil, err
	}
	return &user, nil
}

// FindByID returns the user whose id equals id, with Tenant preloaded.
// A missing user surfaces as gorm.ErrRecordNotFound.
func (r *userRepository) FindByID(ctx context.Context, id string) (*domain.User, error) {
	var user domain.User
	if err := r.db.WithContext(ctx).Preload("Tenant").Where("id = ?", id).First(&user).Error; err != nil {
		return nil, err
	}
	return &user, nil
}

// FindByIDs returns the users whose id is in ids. An empty ids returns a nil
// slice without querying.
func (r *userRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
	var users []domain.User
	if len(ids) == 0 {
		return users, nil
	}
	if err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&users).Error; err != nil {
		return nil, err
	}
	return users, nil
}

// ListByTenant returns all users whose tenant_id equals tenantID.
func (r *userRepository) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
	var users []domain.User
	if err := r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Find(&users).Error; err != nil {
		return nil, err
	}
	return users, nil
}

// CountByTenant returns the number of users whose tenant_id equals tenantID.
func (r *userRepository) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
	var count int64
	err := r.db.WithContext(ctx).Model(&domain.User{}).Where("tenant_id = ?", tenantID).Count(&count).Error
	return count, err
}

// CountByTenantIDs returns the number of users per tenant ID, using a single
// GROUP BY query. Every ID in tenantIDs appears in the result, with 0 when the
// tenant has no users.
func (r *userRepository) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
	counts := make(map[string]int64, len(tenantIDs))
	if len(tenantIDs) == 0 {
		return counts, nil
	}

	// TenantID is a pointer so a NULL tenant_id can be scanned.
	type result struct {
		TenantID *string
		Count    int64
	}
	var results []result
	if err := r.db.WithContext(ctx).Model(&domain.User{}).
		Select("tenant_id, count(*) as count").
		Where("tenant_id IN ?", tenantIDs).
		Group("tenant_id").
		Find(&results).Error; err != nil {
		return nil, err
	}

	for _, res := range results {
		if res.TenantID != nil && *res.TenantID != "" {
			counts[*res.TenantID] = res.Count
		}
	}

	// Ensure all requested tenant IDs are in the map, even if count is 0.
	for _, id := range tenantIDs {
		if _, ok := counts[id]; !ok {
			counts[id] = 0
		}
	}
	return counts, nil
}

// CountByCompanyCodes returns the number of distinct, non-deleted users per
// company code. A user is counted for a code when either its company_code or
// any element of its company_codes array equals the code, compared
// case-insensitively. Keys are the requested codes trimmed and lower-cased;
// every requested code is present, with 0 when nothing matched.
func (r *userRepository) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
	if len(codes) == 0 {
		return make(map[string]int64), nil
	}

	type result struct {
		CompanyCode string
		Count       int64
	}
	var results []result

	lowerCodes := lowerStrings(codes)

	// Combine singular company_code and array company_codes in a subquery.
	// count(DISTINCT id) ensures a user listed under the same code in both
	// columns is counted only once per code.
	query := `
		SELECT LOWER(comp_code) as company_code, count(DISTINCT id) as count
		FROM (
			SELECT id, company_code as comp_code
			FROM users
			WHERE deleted_at IS NULL AND LOWER(company_code) = ANY($1)
			UNION ALL
			SELECT id, unnest(company_codes) as comp_code
			FROM users
			WHERE deleted_at IS NULL AND company_codes IS NOT NULL
		) as combined
		WHERE LOWER(comp_code) = ANY($1)
		GROUP BY LOWER(comp_code)
	`

	if err := r.db.WithContext(ctx).Raw(query, pq.Array(lowerCodes)).Scan(&results).Error; err != nil {
		return nil, err
	}

	counts := make(map[string]int64, len(lowerCodes))
	for _, res := range results {
		counts[strings.ToLower(res.CompanyCode)] = res.Count
	}

	// Ensure all requested codes are present in results (even if count is 0).
	// lowerCodes already holds each code trimmed and lower-cased.
	for _, code := range lowerCodes {
		if _, ok := counts[code]; !ok {
			counts[code] = 0
		}
	}
	return counts, nil
}

// lowerStrings returns a new slice holding each element of arr with
// surrounding whitespace trimmed and converted to lower case.
func lowerStrings(arr []string) []string {
	res := make([]string, len(arr))
	for i, s := range arr {
		res[i] = strings.ToLower(strings.TrimSpace(s))
	}
	return res
}

// List returns one page of users (offset/limit) with Tenant preloaded, plus
// the total number of users matching the filters.
//
// A non-empty companyCode keeps a user when it equals users.company_code, is an
// element of users.company_codes, or equals the slug of the user's tenant.
// A non-empty search is matched with LIKE '%search%' against email, name,
// company_code and the text form of metadata, or exactly against an element
// of company_codes. Pages are ordered by users.id so consecutive offsets
// neither overlap nor skip rows.
func (r *userRepository) List(ctx context.Context, offset, limit int, search string, companyCode string) ([]domain.User, int64, error) {
	query := r.db.WithContext(ctx).Model(&domain.User{})

	if companyCode != "" {
		// [Matrix Fix] Match users either by their primary company code OR by being in the company_codes array OR by tenant slug
		query = query.Joins("LEFT JOIN tenants ON users.tenant_id = tenants.id").
			Where("users.company_code = ? OR ? = ANY(users.company_codes) OR tenants.slug = ?",
				companyCode, companyCode, companyCode)
	}

	if search != "" {
		pattern := "%" + search + "%"
		query = query.Where(
			"(users.email LIKE ? OR users.name LIKE ? OR users.company_code LIKE ? OR ? = ANY(users.company_codes) OR users.metadata::text LIKE ?)",
			pattern, pattern, pattern, search, pattern)
	}

	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	// Without ORDER BY, PostgreSQL may return rows in any order, making
	// OFFSET/LIMIT pages inconsistent. Order by the primary key for stability.
	var users []domain.User
	if err := query.Order("users.id").Offset(offset).Limit(limit).Preload("Tenant").Find(&users).Error; err != nil {
		return nil, 0, err
	}
	return users, total, nil
}

// Delete deletes the user with the given ID through GORM's Delete (a soft
// delete if domain.User carries a gorm.DeletedAt field — the deleted_at
// filters in this file suggest so, but confirm against the model).
func (r *userRepository) Delete(ctx context.Context, id string) error {
	return r.db.WithContext(ctx).Delete(&domain.User{}, "id = ?", id).Error
}

// UpdateUserLoginIDs replaces every login ID of userID with loginIDs inside a
// single transaction. An empty loginIDs simply removes the existing ones.
//
// NOTE(review): unlike Update, this delete is not Unscoped; if
// domain.UserLoginID uses soft deletes, old rows are kept as soft-deleted —
// confirm that is intended, especially if login_id has a unique index.
func (r *userRepository) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// Delete existing login IDs for this user.
		if err := tx.Where("user_id = ?", userID).Delete(&domain.UserLoginID{}).Error; err != nil {
			return err
		}
		// Insert new login IDs if any.
		if len(loginIDs) > 0 {
			if err := tx.Create(&loginIDs).Error; err != nil {
				return err
			}
		}
		return nil
	})
}

// GetUserLoginIDs returns the login ID records whose user_id equals userID.
func (r *userRepository) GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error) {
	var results []domain.UserLoginID
	if err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&results).Error; err != nil {
		return nil, err
	}
	return results, nil
}

// IsLoginIDTaken reports whether at least one login ID record has
// login_id equal to loginID.
func (r *userRepository) IsLoginIDTaken(ctx context.Context, loginID string) (bool, error) {
	var count int64
	if err := r.db.WithContext(ctx).Model(&domain.UserLoginID{}).Where("login_id = ?", loginID).Count(&count).Error; err != nil {
		return false, err
	}
	return count > 0, nil
}

// FindTenantIDByLoginID returns the TenantID of the first login ID record whose
// login_id equals loginID. A missing record surfaces as gorm.ErrRecordNotFound.
func (r *userRepository) FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error) {
	var record domain.UserLoginID
	if err := r.db.WithContext(ctx).Where("login_id = ?", loginID).First(&record).Error; err != nil {
		return "", err
	}
	return record.TenantID, nil
}

// FindByTenantIDs returns the users whose tenant_id is in tenantIDs. An empty
// tenantIDs returns an empty (non-nil) slice without querying, matching what
// an empty query result would have produced.
func (r *userRepository) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
	if len(tenantIDs) == 0 {
		return []domain.User{}, nil
	}
	var users []domain.User
	err := r.db.WithContext(ctx).Where("tenant_id IN ?", tenantIDs).Find(&users).Error
	return users, err
}

// FindByCompanyCodes returns the users whose company_code exactly equals one
// of codes. An empty codes returns an empty (non-nil) slice without querying.
//
// NOTE(review): unlike CountByCompanyCodes, this matches only the primary
// company_code column and is case-sensitive, so counts and lists for the same
// codes can disagree — confirm that is intended.
func (r *userRepository) FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error) {
	if len(codes) == 0 {
		return []domain.User{}, nil
	}
	var users []domain.User
	err := r.db.WithContext(ctx).Where("company_code IN ?", codes).Find(&users).Error
	return users, err
}