forked from baron/baron-sso
chore: consolidate local integration changes
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/pagination"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -46,6 +47,31 @@ func (r *userRepository) DB() *gorm.DB {
|
||||
return r.db
|
||||
}
|
||||
|
||||
func (r *userRepository) withTenantMembershipFilter(db *gorm.DB, tenantIDs []string) *gorm.DB {
|
||||
if len(tenantIDs) == 0 {
|
||||
return db
|
||||
}
|
||||
clauses := []string{"tenant_id IN ?"}
|
||||
args := []any{tenantIDs}
|
||||
for _, tenantID := range tenantIDs {
|
||||
tenantID = strings.TrimSpace(tenantID)
|
||||
if tenantID == "" {
|
||||
continue
|
||||
}
|
||||
payload, err := json.Marshal(map[string]any{
|
||||
"additionalAppointments": []map[string]string{
|
||||
{"tenantId": tenantID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
clauses = append(clauses, "metadata @> ?::jsonb")
|
||||
args = append(args, string(payload))
|
||||
}
|
||||
return db.Where("("+strings.Join(clauses, " OR ")+")", args...)
|
||||
}
|
||||
|
||||
func (r *userRepository) Create(ctx context.Context, user *domain.User) error {
|
||||
return r.db.WithContext(ctx).Create(user).Error
|
||||
}
|
||||
@@ -124,7 +150,7 @@ func (r *userRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.
|
||||
|
||||
func (r *userRepository) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
|
||||
var users []domain.User
|
||||
if err := r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Find(&users).Error; err != nil {
|
||||
if err := r.withTenantMembershipFilter(r.db.WithContext(ctx), []string{tenantID}).Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return users, nil
|
||||
@@ -132,40 +158,23 @@ func (r *userRepository) ListByTenant(ctx context.Context, tenantID string) ([]d
|
||||
|
||||
func (r *userRepository) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&domain.User{}).Where("tenant_id = ?", tenantID).Count(&count).Error
|
||||
err := r.withTenantMembershipFilter(r.db.WithContext(ctx).Model(&domain.User{}), []string{tenantID}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *userRepository) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
|
||||
type result struct {
|
||||
TenantID *string
|
||||
Count int64
|
||||
}
|
||||
var results []result
|
||||
counts := make(map[string]int64)
|
||||
|
||||
if len(tenantIDs) == 0 {
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Model(&domain.User{}).
|
||||
Select("tenant_id, count(*) as count").
|
||||
Where("tenant_id IN ?", tenantIDs).
|
||||
Group("tenant_id").
|
||||
Find(&results).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, res := range results {
|
||||
if res.TenantID != nil && *res.TenantID != "" {
|
||||
counts[*res.TenantID] = res.Count
|
||||
}
|
||||
}
|
||||
// Ensure all requested tenant IDs are in the map, even if count is 0
|
||||
for _, id := range tenantIDs {
|
||||
if _, ok := counts[id]; !ok {
|
||||
counts[id] = 0
|
||||
for _, tenantID := range tenantIDs {
|
||||
var count int64
|
||||
if err := r.withTenantMembershipFilter(r.db.WithContext(ctx).Model(&domain.User{}), []string{tenantID}).Count(&count).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts[tenantID] = count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
@@ -222,7 +231,7 @@ func (r *userRepository) List(ctx context.Context, offset, limit int, search str
|
||||
db := r.db.WithContext(ctx).Model(&domain.User{})
|
||||
|
||||
if len(tenantIDs) > 0 {
|
||||
db = db.Where("tenant_id IN ?", tenantIDs)
|
||||
db = r.withTenantMembershipFilter(db, tenantIDs)
|
||||
}
|
||||
|
||||
if search != "" {
|
||||
@@ -311,7 +320,7 @@ func (r *userRepository) FindTenantIDByLoginID(ctx context.Context, loginID stri
|
||||
|
||||
func (r *userRepository) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
|
||||
var users []domain.User
|
||||
err := r.db.WithContext(ctx).Where("tenant_id IN ?", tenantIDs).Find(&users).Error
|
||||
err := r.withTenantMembershipFilter(r.db.WithContext(ctx), tenantIDs).Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user