첫 커밋: 로컬 프로젝트 업로드
This commit is contained in:
449
baron-sso/backend/internal/repository/clickhouse_repo.go
Normal file
449
baron-sso/backend/internal/repository/clickhouse_repo.go
Normal file
@@ -0,0 +1,449 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ClickHouse/clickhouse-go/v2"
|
||||
"github.com/ClickHouse/clickhouse-go/v2/lib/driver"
|
||||
)
|
||||
|
||||
type ClickHouseRepository struct {
|
||||
conn driver.Conn
|
||||
}
|
||||
|
||||
func NewClickHouseRepository(host string, port int, user, password, db string) (*ClickHouseRepository, error) {
|
||||
// 1. Connect to 'default' database first to ensure target DB exists
|
||||
tmpConn, err := clickhouse.Open(&clickhouse.Options{
|
||||
Addr: []string{fmt.Sprintf("%s:%d", host, port)},
|
||||
Auth: clickhouse.Auth{
|
||||
Database: "default",
|
||||
Username: user,
|
||||
Password: password,
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
_ = tmpConn.Exec(context.Background(), fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", db))
|
||||
_ = tmpConn.Close()
|
||||
}
|
||||
|
||||
// 2. Now connect to the target database
|
||||
conn, err := clickhouse.Open(&clickhouse.Options{
|
||||
Addr: []string{fmt.Sprintf("%s:%d", host, port)},
|
||||
Auth: clickhouse.Auth{
|
||||
Database: db,
|
||||
Username: user,
|
||||
Password: password,
|
||||
},
|
||||
Debug: false,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open clickhouse connection: %w", err)
|
||||
}
|
||||
|
||||
if err := conn.Ping(context.Background()); err != nil {
|
||||
return nil, fmt.Errorf("failed to ping clickhouse: %w", err)
|
||||
}
|
||||
|
||||
// Ensure Table Exists
|
||||
// Note: In production, use migrations.
|
||||
query := `
|
||||
CREATE TABLE IF NOT EXISTS audit_logs (
|
||||
event_id String,
|
||||
timestamp DateTime DEFAULT now(),
|
||||
user_id String,
|
||||
tenant_id String,
|
||||
event_type String,
|
||||
status String,
|
||||
ip_address String,
|
||||
user_agent String,
|
||||
device_id String,
|
||||
details String
|
||||
) ENGINE = MergeTree()
|
||||
ORDER BY timestamp
|
||||
`
|
||||
if err := conn.Exec(context.Background(), query); err != nil {
|
||||
return nil, fmt.Errorf("failed to create table: %w", err)
|
||||
}
|
||||
|
||||
alterQuery := `
|
||||
ALTER TABLE audit_logs
|
||||
ADD COLUMN IF NOT EXISTS tenant_id String,
|
||||
ADD COLUMN IF NOT EXISTS event_id String
|
||||
`
|
||||
if err := conn.Exec(context.Background(), alterQuery); err != nil {
|
||||
return nil, fmt.Errorf("failed to alter table: %w", err)
|
||||
}
|
||||
|
||||
if err := ensureRPUsageTables(context.Background(), conn); err != nil {
|
||||
return nil, fmt.Errorf("failed to create rp usage tables: %w", err)
|
||||
}
|
||||
|
||||
return &ClickHouseRepository{conn: conn}, nil
|
||||
}
|
||||
|
||||
func ensureRPUsageTables(ctx context.Context, conn driver.Conn) error {
|
||||
factQuery := `
|
||||
CREATE TABLE IF NOT EXISTS rp_usage_events (
|
||||
event_id String,
|
||||
occurred_at DateTime64(3) DEFAULT now64(3),
|
||||
event_type String,
|
||||
subject String,
|
||||
tenant_id String,
|
||||
tenant_type String,
|
||||
client_id String,
|
||||
client_name String,
|
||||
session_id String,
|
||||
scopes Array(String),
|
||||
source String,
|
||||
correlation_id String,
|
||||
payload String
|
||||
) ENGINE = MergeTree()
|
||||
ORDER BY (occurred_at, event_id)
|
||||
`
|
||||
if err := conn.Exec(ctx, factQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
aggregateQuery := `
|
||||
CREATE TABLE IF NOT EXISTS rp_usage_daily_aggregate (
|
||||
event_date Date,
|
||||
tenant_id String,
|
||||
tenant_type String,
|
||||
client_id String,
|
||||
client_name String,
|
||||
event_type String,
|
||||
events_count AggregateFunction(count),
|
||||
unique_subjects AggregateFunction(uniqExact, String)
|
||||
) ENGINE = AggregatingMergeTree()
|
||||
ORDER BY (event_date, tenant_id, client_id, event_type)
|
||||
`
|
||||
if err := conn.Exec(ctx, aggregateQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
viewQuery := `
|
||||
CREATE MATERIALIZED VIEW IF NOT EXISTS rp_usage_daily_aggregate_mv
|
||||
TO rp_usage_daily_aggregate
|
||||
AS
|
||||
SELECT
|
||||
toDate(occurred_at) AS event_date,
|
||||
tenant_id,
|
||||
tenant_type,
|
||||
client_id,
|
||||
any(client_name) AS client_name,
|
||||
event_type,
|
||||
countState() AS events_count,
|
||||
uniqExactState(subject) AS unique_subjects
|
||||
FROM rp_usage_events
|
||||
WHERE tenant_type IN ('COMPANY', 'ORGANIZATION')
|
||||
GROUP BY event_date, tenant_id, tenant_type, client_id, event_type
|
||||
`
|
||||
return conn.Exec(ctx, viewQuery)
|
||||
}
|
||||
|
||||
func (r *ClickHouseRepository) Create(log *domain.AuditLog) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if log.Timestamp.IsZero() {
|
||||
log.Timestamp = time.Now()
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO audit_logs (event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
return r.conn.Exec(ctx, query,
|
||||
log.EventID,
|
||||
log.Timestamp,
|
||||
log.UserID,
|
||||
log.TenantID,
|
||||
log.EventType,
|
||||
log.Status,
|
||||
log.IPAddress,
|
||||
log.UserAgent,
|
||||
log.DeviceID,
|
||||
log.Details,
|
||||
)
|
||||
}
|
||||
|
||||
func (r *ClickHouseRepository) CreateRPUsageEvent(ctx context.Context, event domain.RPUsageEvent) error {
|
||||
if r == nil || r.conn == nil {
|
||||
return fmt.Errorf("clickhouse connection is nil")
|
||||
}
|
||||
if event.OccurredAt.IsZero() {
|
||||
event.OccurredAt = time.Now()
|
||||
}
|
||||
payloadBytes, err := json.Marshal(event.Payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal rp usage payload: %w", err)
|
||||
}
|
||||
query := `
|
||||
INSERT INTO rp_usage_events (
|
||||
event_id, occurred_at, event_type, subject, tenant_id, tenant_type,
|
||||
client_id, client_name, session_id, scopes, source, correlation_id, payload
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
return r.conn.Exec(ctx, query,
|
||||
event.ID,
|
||||
event.OccurredAt,
|
||||
event.EventType,
|
||||
event.Subject,
|
||||
event.TenantID,
|
||||
event.TenantType,
|
||||
event.ClientID,
|
||||
event.ClientName,
|
||||
event.SessionID,
|
||||
[]string(event.Scopes),
|
||||
event.Source,
|
||||
event.CorrelationID,
|
||||
string(payloadBytes),
|
||||
)
|
||||
}
|
||||
|
||||
func (r *ClickHouseRepository) FindRPUsage(ctx context.Context, rpQuery domain.RPUsageQuery) ([]domain.RPUsageDailyMetric, error) {
|
||||
if r == nil || r.conn == nil {
|
||||
return nil, fmt.Errorf("clickhouse connection is nil")
|
||||
}
|
||||
days := rpQuery.Days
|
||||
if days <= 0 || days > 90 {
|
||||
days = 14
|
||||
}
|
||||
periodExpr := "event_date"
|
||||
switch rpQuery.Period {
|
||||
case "week":
|
||||
periodExpr = "toMonday(event_date)"
|
||||
case "month":
|
||||
periodExpr = "toStartOfMonth(event_date)"
|
||||
case "day", "":
|
||||
periodExpr = "event_date"
|
||||
default:
|
||||
periodExpr = "event_date"
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
date,
|
||||
tenant_id,
|
||||
tenant_type,
|
||||
client_id,
|
||||
any(client_name) AS client_name,
|
||||
sumIf(events, event_type = ?) AS login_requests,
|
||||
sumIf(events, event_type != ?) AS other_requests,
|
||||
max(unique_subjects) AS unique_subjects
|
||||
FROM (
|
||||
SELECT
|
||||
toString(%s) AS date,
|
||||
tenant_id,
|
||||
tenant_type,
|
||||
client_id,
|
||||
any(client_name) AS client_name,
|
||||
event_type,
|
||||
countMerge(events_count) AS events,
|
||||
uniqExactMerge(unique_subjects) AS unique_subjects
|
||||
FROM rp_usage_daily_aggregate
|
||||
WHERE event_date >= today() - ?
|
||||
AND tenant_type IN ('COMPANY', 'ORGANIZATION')
|
||||
`, periodExpr)
|
||||
args := []any{domain.RPUsageEventTypeAuthorizationGranted, domain.RPUsageEventTypeAuthorizationGranted, days - 1}
|
||||
if rpQuery.TenantID != "" {
|
||||
query += " AND tenant_id = ?\n"
|
||||
args = append(args, rpQuery.TenantID)
|
||||
}
|
||||
query += fmt.Sprintf(`
|
||||
GROUP BY %s, tenant_id, tenant_type, client_id, event_type
|
||||
)
|
||||
GROUP BY date, tenant_id, tenant_type, client_id
|
||||
ORDER BY date ASC, tenant_id ASC, client_id ASC
|
||||
`, periodExpr)
|
||||
rows, err := r.conn.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query rp usage daily aggregate: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
metrics := make([]domain.RPUsageDailyMetric, 0)
|
||||
for rows.Next() {
|
||||
var metric domain.RPUsageDailyMetric
|
||||
if err := rows.Scan(
|
||||
&metric.Date,
|
||||
&metric.TenantID,
|
||||
&metric.TenantType,
|
||||
&metric.ClientID,
|
||||
&metric.ClientName,
|
||||
&metric.LoginRequests,
|
||||
&metric.OtherRequests,
|
||||
&metric.UniqueSubjects,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan rp usage daily aggregate: %w", err)
|
||||
}
|
||||
if metric.ClientName == "" {
|
||||
metric.ClientName = metric.ClientID
|
||||
}
|
||||
metrics = append(metrics, metric)
|
||||
}
|
||||
return metrics, nil
|
||||
}
|
||||
|
||||
func (r *ClickHouseRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor, tenantID string) ([]domain.AuditLog, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
query := `
|
||||
SELECT event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details
|
||||
FROM audit_logs
|
||||
WHERE 1=1
|
||||
`
|
||||
args := make([]any, 0, 5)
|
||||
|
||||
if tenantID != "" {
|
||||
query += " AND (tenant_id = ? OR (tenant_id = '' AND JSONExtractString(details, 'tenant_id') = ?))"
|
||||
args = append(args, tenantID, tenantID)
|
||||
}
|
||||
|
||||
if cursor != nil {
|
||||
query += `
|
||||
AND ((timestamp < ?) OR (timestamp = ? AND event_id < ?))
|
||||
`
|
||||
args = append(args, cursor.Timestamp, cursor.Timestamp, cursor.EventID)
|
||||
}
|
||||
query += `
|
||||
ORDER BY timestamp DESC, event_id DESC
|
||||
LIMIT ?
|
||||
`
|
||||
args = append(args, limit)
|
||||
|
||||
rows, err := r.conn.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query audit logs: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var logs []domain.AuditLog
|
||||
for rows.Next() {
|
||||
var log domain.AuditLog
|
||||
if err := rows.Scan(
|
||||
&log.EventID,
|
||||
&log.Timestamp,
|
||||
&log.UserID,
|
||||
&log.TenantID,
|
||||
&log.EventType,
|
||||
&log.Status,
|
||||
&log.IPAddress,
|
||||
&log.UserAgent,
|
||||
&log.DeviceID,
|
||||
&log.Details,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan audit log: %w", err)
|
||||
}
|
||||
logs = append(logs, log)
|
||||
}
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
func (r *ClickHouseRepository) FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]domain.AuditLog, error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
query := `
|
||||
SELECT event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details
|
||||
FROM audit_logs
|
||||
WHERE user_id = ? AND event_type IN (?)
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
`
|
||||
rows, err := r.conn.Query(ctx, query, userID, eventTypes, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query audit logs by user/events: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var logs []domain.AuditLog
|
||||
for rows.Next() {
|
||||
var log domain.AuditLog
|
||||
if err := rows.Scan(
|
||||
&log.EventID,
|
||||
&log.Timestamp,
|
||||
&log.UserID,
|
||||
&log.TenantID,
|
||||
&log.EventType,
|
||||
&log.Status,
|
||||
&log.IPAddress,
|
||||
&log.UserAgent,
|
||||
&log.DeviceID,
|
||||
&log.Details,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan audit log: %w", err)
|
||||
}
|
||||
logs = append(logs, log)
|
||||
}
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
func (r *ClickHouseRepository) Ping(ctx context.Context) error {
|
||||
if r.conn == nil {
|
||||
return fmt.Errorf("clickhouse connection is nil")
|
||||
}
|
||||
return r.conn.Ping(ctx)
|
||||
}
|
||||
|
||||
func (r *ClickHouseRepository) CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
|
||||
query := `
|
||||
SELECT count()
|
||||
FROM audit_logs
|
||||
WHERE status = 'failure' AND timestamp >= ?
|
||||
`
|
||||
args := []any{since}
|
||||
if tenantID != "" {
|
||||
query += " AND JSONExtractString(details, 'tenant_id') = ?"
|
||||
args = append(args, tenantID)
|
||||
}
|
||||
|
||||
var count int64
|
||||
err := r.conn.QueryRow(ctx, query, args...).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to count failures: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *ClickHouseRepository) CountEventsSince(ctx context.Context, since time.Time) (int64, error) {
|
||||
sinceUTC := since.UTC().Format("2006-01-02 15:04:05")
|
||||
query := fmt.Sprintf(`
|
||||
SELECT count()
|
||||
FROM audit_logs
|
||||
WHERE timestamp >= toDateTime('%s')
|
||||
`, sinceUTC)
|
||||
var count int64
|
||||
err := r.conn.QueryRow(ctx, query).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to count audit events: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *ClickHouseRepository) CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
|
||||
// We use uniqExact(session_id) to count unique sessions that had success events recently.
|
||||
query := `
|
||||
SELECT uniqExact(session_id)
|
||||
FROM audit_logs
|
||||
WHERE status = 'success' AND timestamp >= ? AND session_id != ''
|
||||
`
|
||||
args := []any{since}
|
||||
if tenantID != "" {
|
||||
query += " AND JSONExtractString(details, 'tenant_id') = ?"
|
||||
args = append(args, tenantID)
|
||||
}
|
||||
|
||||
var count int64
|
||||
err := r.conn.QueryRow(ctx, query, args...).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to count active sessions: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ClientConsentRepository interface {
|
||||
Upsert(ctx context.Context, consent *domain.ClientConsent) error
|
||||
Delete(ctx context.Context, subject, clientID string) error
|
||||
DeleteByClient(ctx context.Context, clientID string) error
|
||||
List(ctx context.Context, clientID string, limit, offset int) ([]domain.ClientConsentWithTenantInfo, int64, error)
|
||||
ListByTenant(ctx context.Context, clientID, tenantID string, limit, offset int) ([]domain.ClientConsentWithTenantInfo, int64, error)
|
||||
ListBySubject(ctx context.Context, subject string) ([]domain.ClientConsent, error)
|
||||
ListSubjectsByClient(ctx context.Context, clientID string) ([]string, error)
|
||||
Find(ctx context.Context, clientID, subject string) (*domain.ClientConsent, error)
|
||||
}
|
||||
|
||||
type clientConsentRepo struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewClientConsentRepository(db *gorm.DB) ClientConsentRepository {
|
||||
return &clientConsentRepo{db: db}
|
||||
}
|
||||
|
||||
func (r *clientConsentRepo) Find(ctx context.Context, clientID, subject string) (*domain.ClientConsent, error) {
|
||||
var consent domain.ClientConsent
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("client_id = ? AND subject = ?", clientID, subject).
|
||||
First(&consent).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &consent, nil
|
||||
}
|
||||
|
||||
func (r *clientConsentRepo) Upsert(ctx context.Context, consent *domain.ClientConsent) error {
|
||||
return r.db.WithContext(ctx).Unscoped().
|
||||
Where("client_id = ? AND subject = ?", consent.ClientID, consent.Subject).
|
||||
Assign(map[string]any{
|
||||
"granted_scopes": consent.GrantedScopes,
|
||||
"updated_at": gorm.Expr("NOW()"),
|
||||
"deleted_at": nil,
|
||||
}).
|
||||
FirstOrCreate(consent).Error
|
||||
}
|
||||
|
||||
func (r *clientConsentRepo) Delete(ctx context.Context, subject, clientID string) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Where("subject = ? AND client_id = ?", subject, clientID).
|
||||
Delete(&domain.ClientConsent{}).Error
|
||||
}
|
||||
|
||||
func (r *clientConsentRepo) DeleteByClient(ctx context.Context, clientID string) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Where("client_id = ?", clientID).
|
||||
Delete(&domain.ClientConsent{}).Error
|
||||
}
|
||||
|
||||
func (r *clientConsentRepo) List(ctx context.Context, clientID string, limit, offset int) ([]domain.ClientConsentWithTenantInfo, int64, error) {
|
||||
var consents []domain.ClientConsentWithTenantInfo
|
||||
var total int64
|
||||
|
||||
// Base query for counting
|
||||
countQuery := r.db.WithContext(ctx).Unscoped().Model(&domain.ClientConsent{}).Where("client_id = ?", clientID)
|
||||
if err := countQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Query for fetching data
|
||||
query := r.db.WithContext(ctx).Unscoped().
|
||||
Model(&domain.ClientConsent{}).
|
||||
Select("client_consents.*, users.tenant_id, tenants.name as tenant_name").
|
||||
Joins("LEFT JOIN users ON users.id::text = client_consents.subject").
|
||||
Joins("LEFT JOIN tenants ON tenants.id = users.tenant_id").
|
||||
Where("client_consents.client_id = ?", clientID)
|
||||
|
||||
err := query.Limit(limit).Offset(offset).Order("client_consents.updated_at DESC").Scan(&consents).Error
|
||||
return consents, total, err
|
||||
}
|
||||
|
||||
func (r *clientConsentRepo) ListByTenant(ctx context.Context, clientID, tenantID string, limit, offset int) ([]domain.ClientConsentWithTenantInfo, int64, error) {
|
||||
var consents []domain.ClientConsentWithTenantInfo
|
||||
var total int64
|
||||
|
||||
// Base query for counting
|
||||
countQuery := r.db.WithContext(ctx).Unscoped().
|
||||
Model(&domain.ClientConsent{}).
|
||||
Joins("JOIN users ON users.id::text = client_consents.subject").
|
||||
Where("client_consents.client_id = ? AND users.tenant_id = ?", clientID, tenantID)
|
||||
|
||||
if err := countQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Query for fetching data
|
||||
query := r.db.WithContext(ctx).Unscoped().
|
||||
Model(&domain.ClientConsent{}).
|
||||
Select("client_consents.*, users.tenant_id, tenants.name as tenant_name").
|
||||
Joins("JOIN users ON users.id::text = client_consents.subject").
|
||||
Joins("JOIN tenants ON tenants.id = users.tenant_id").
|
||||
Where("client_consents.client_id = ? AND users.tenant_id = ?", clientID, tenantID)
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Order("client_consents.updated_at DESC").
|
||||
Scan(&consents).Error
|
||||
|
||||
return consents, total, err
|
||||
}
|
||||
|
||||
func (r *clientConsentRepo) ListBySubject(ctx context.Context, subject string) ([]domain.ClientConsent, error) {
|
||||
var consents []domain.ClientConsent
|
||||
err := r.db.WithContext(ctx).Unscoped().
|
||||
Where("subject = ?", subject).
|
||||
Order("updated_at DESC").
|
||||
Find(&consents).Error
|
||||
return consents, err
|
||||
}
|
||||
|
||||
func (r *clientConsentRepo) ListSubjectsByClient(ctx context.Context, clientID string) ([]string, error) {
|
||||
var subjects []string
|
||||
err := r.db.WithContext(ctx).Unscoped().
|
||||
Model(&domain.ClientConsent{}).
|
||||
Distinct("subject").
|
||||
Where("client_id = ?", clientID).
|
||||
Order("subject ASC").
|
||||
Pluck("subject", &subjects).Error
|
||||
return subjects, err
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestClientConsentRepository_Find_IgnoresSoftDeletedConsent(t *testing.T) {
|
||||
repo := NewClientConsentRepository(testDB)
|
||||
ctx := context.Background()
|
||||
|
||||
consent := &domain.ClientConsent{
|
||||
ClientID: "client-soft-delete",
|
||||
Subject: "user-soft-delete",
|
||||
GrantedScopes: pq.StringArray{"openid", "profile"},
|
||||
}
|
||||
|
||||
err := repo.Upsert(ctx, consent)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = repo.Delete(ctx, consent.Subject, consent.ClientID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
found, err := repo.Find(ctx, consent.ClientID, consent.Subject)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, found)
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type clientSecretRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewClientSecretRepository(db *gorm.DB) domain.ClientSecretRepository {
|
||||
return &clientSecretRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *clientSecretRepository) Upsert(ctx context.Context, clientID, secret string) error {
|
||||
cs := domain.ClientSecret{
|
||||
ClientID: clientID,
|
||||
ClientSecret: secret,
|
||||
}
|
||||
return r.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "client_id"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"client_secret", "updated_at"}),
|
||||
}).Create(&cs).Error
|
||||
}
|
||||
|
||||
func (r *clientSecretRepository) GetByID(ctx context.Context, clientID string) (string, error) {
|
||||
var cs domain.ClientSecret
|
||||
if err := r.db.WithContext(ctx).Where("client_id = ?", clientID).First(&cs).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
return cs.ClientSecret, nil
|
||||
}
|
||||
|
||||
func (r *clientSecretRepository) Delete(ctx context.Context, clientID string) error {
|
||||
return r.db.WithContext(ctx).Where("client_id = ?", clientID).Delete(&domain.ClientSecret{}).Error
|
||||
}
|
||||
@@ -0,0 +1,380 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type DataIntegrityChecker interface {
|
||||
CheckDataIntegrity(ctx context.Context) (domain.DataIntegrityReport, error)
|
||||
ListOrphanUserLoginIDs(ctx context.Context) ([]domain.OrphanUserLoginID, error)
|
||||
DeleteOrphanUserLoginIDs(ctx context.Context, ids []string) (domain.DeleteOrphanUserLoginIDsResult, error)
|
||||
}
|
||||
|
||||
type dataIntegrityChecker struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewDataIntegrityChecker(db *gorm.DB) DataIntegrityChecker {
|
||||
return &dataIntegrityChecker{db: db}
|
||||
}
|
||||
|
||||
func (c *dataIntegrityChecker) CheckDataIntegrity(ctx context.Context) (domain.DataIntegrityReport, error) {
|
||||
return CheckDataIntegrity(ctx, c.db)
|
||||
}
|
||||
|
||||
func (c *dataIntegrityChecker) ListOrphanUserLoginIDs(ctx context.Context) ([]domain.OrphanUserLoginID, error) {
|
||||
return ListOrphanUserLoginIDs(ctx, c.db, nil)
|
||||
}
|
||||
|
||||
func (c *dataIntegrityChecker) DeleteOrphanUserLoginIDs(ctx context.Context, ids []string) (domain.DeleteOrphanUserLoginIDsResult, error) {
|
||||
return DeleteOrphanUserLoginIDs(ctx, c.db, ids)
|
||||
}
|
||||
|
||||
func CheckDataIntegrity(ctx context.Context, db *gorm.DB) (domain.DataIntegrityReport, error) {
|
||||
tenantChecks := []domain.DataIntegrityCheck{
|
||||
{
|
||||
Key: "duplicate_tenant_slugs",
|
||||
Label: "중복 테넌트 slug",
|
||||
Description: "삭제되지 않은 tenant의 slug를 대소문자 무시 기준으로 검사합니다.",
|
||||
Severity: "error",
|
||||
Count: 0,
|
||||
},
|
||||
{
|
||||
Key: "orphan_tenant_parents",
|
||||
Label: "유령 상위 테넌트 참조",
|
||||
Description: "tenant.parent_id가 없거나 삭제된 tenant를 참조하는지 검사합니다.",
|
||||
Severity: "error",
|
||||
Count: 0,
|
||||
},
|
||||
}
|
||||
userChecks := []domain.DataIntegrityCheck{
|
||||
{
|
||||
Key: "orphan_user_tenant_memberships",
|
||||
Label: "유령 테넌트 사용자 소속",
|
||||
Description: "users.tenant_id가 없거나 삭제된 tenant를 참조하는지 검사합니다.",
|
||||
Severity: "error",
|
||||
Count: 0,
|
||||
},
|
||||
{
|
||||
Key: "orphan_user_login_id_tenants",
|
||||
Label: "유령 테넌트 로그인 ID",
|
||||
Description: "user_login_ids.tenant_id가 없거나 삭제된 tenant를 참조하는지 검사합니다.",
|
||||
Severity: "error",
|
||||
Count: 0,
|
||||
},
|
||||
{
|
||||
Key: "orphan_user_login_id_users",
|
||||
Label: "유령 사용자 로그인 ID",
|
||||
Description: "user_login_ids.user_id가 없거나 삭제된 user를 참조하는지 검사합니다.",
|
||||
Severity: "error",
|
||||
Count: 0,
|
||||
},
|
||||
}
|
||||
|
||||
counts := []struct {
|
||||
target *int64
|
||||
query string
|
||||
}{
|
||||
{
|
||||
target: &tenantChecks[0].Count,
|
||||
query: `
|
||||
SELECT COUNT(*)
|
||||
FROM (
|
||||
SELECT LOWER(TRIM(slug)) AS normalized_slug
|
||||
FROM tenants
|
||||
WHERE deleted_at IS NULL
|
||||
AND status <> 'deleted'
|
||||
AND TRIM(slug) <> ''
|
||||
GROUP BY LOWER(TRIM(slug))
|
||||
HAVING COUNT(*) > 1
|
||||
) AS duplicate_slugs
|
||||
`,
|
||||
},
|
||||
{
|
||||
target: &tenantChecks[1].Count,
|
||||
query: `
|
||||
SELECT COUNT(*)
|
||||
FROM tenants AS child
|
||||
WHERE child.deleted_at IS NULL
|
||||
AND child.parent_id IS NOT NULL
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM tenants AS parent
|
||||
WHERE parent.id = child.parent_id
|
||||
AND parent.deleted_at IS NULL
|
||||
)
|
||||
`,
|
||||
},
|
||||
{
|
||||
target: &userChecks[0].Count,
|
||||
query: `
|
||||
SELECT COUNT(*)
|
||||
FROM users AS u
|
||||
WHERE u.deleted_at IS NULL
|
||||
AND u.tenant_id IS NOT NULL
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM tenants AS t
|
||||
WHERE t.id = u.tenant_id
|
||||
AND t.deleted_at IS NULL
|
||||
)
|
||||
`,
|
||||
},
|
||||
{
|
||||
target: &userChecks[1].Count,
|
||||
query: `
|
||||
SELECT COUNT(*)
|
||||
FROM user_login_ids AS uli
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM tenants AS t
|
||||
WHERE t.id = uli.tenant_id
|
||||
AND t.deleted_at IS NULL
|
||||
)
|
||||
`,
|
||||
},
|
||||
{
|
||||
target: &userChecks[2].Count,
|
||||
query: `
|
||||
SELECT COUNT(*)
|
||||
FROM user_login_ids AS uli
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM users AS u
|
||||
WHERE u.id = uli.user_id
|
||||
AND u.deleted_at IS NULL
|
||||
)
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, item := range counts {
|
||||
if err := db.WithContext(ctx).Raw(item.query).Scan(item.target).Error; err != nil {
|
||||
return domain.DataIntegrityReport{}, err
|
||||
}
|
||||
}
|
||||
|
||||
tenantChecks = applyIntegrityStatuses(tenantChecks)
|
||||
userChecks = applyIntegrityStatuses(userChecks)
|
||||
sections := []domain.DataIntegritySection{
|
||||
{
|
||||
Key: "tenant_integrity",
|
||||
Label: "테넌트 정합성",
|
||||
Status: summarizeIntegrityStatus(tenantChecks),
|
||||
Checks: tenantChecks,
|
||||
},
|
||||
{
|
||||
Key: "user_integrity",
|
||||
Label: "사용자 정합성",
|
||||
Status: summarizeIntegrityStatus(userChecks),
|
||||
Checks: userChecks,
|
||||
},
|
||||
}
|
||||
|
||||
summary := domain.DataIntegritySummary{}
|
||||
for _, section := range sections {
|
||||
for _, check := range section.Checks {
|
||||
summary.TotalChecks++
|
||||
switch check.Status {
|
||||
case domain.DataIntegrityStatusFail:
|
||||
summary.Failures += check.Count
|
||||
case domain.DataIntegrityStatusWarning:
|
||||
summary.Warnings++
|
||||
default:
|
||||
summary.Passed++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return domain.DataIntegrityReport{
|
||||
Status: summarizeSectionStatus(sections),
|
||||
CheckedAt: time.Now().UTC(),
|
||||
Summary: summary,
|
||||
Sections: sections,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func applyIntegrityStatuses(checks []domain.DataIntegrityCheck) []domain.DataIntegrityCheck {
|
||||
for i := range checks {
|
||||
if checks[i].Count > 0 {
|
||||
checks[i].Status = domain.DataIntegrityStatusFail
|
||||
} else {
|
||||
checks[i].Status = domain.DataIntegrityStatusPass
|
||||
}
|
||||
}
|
||||
return checks
|
||||
}
|
||||
|
||||
func summarizeIntegrityStatus(checks []domain.DataIntegrityCheck) domain.DataIntegrityStatus {
|
||||
status := domain.DataIntegrityStatusPass
|
||||
for _, check := range checks {
|
||||
if check.Status == domain.DataIntegrityStatusFail {
|
||||
return domain.DataIntegrityStatusFail
|
||||
}
|
||||
if check.Status == domain.DataIntegrityStatusWarning {
|
||||
status = domain.DataIntegrityStatusWarning
|
||||
}
|
||||
}
|
||||
return status
|
||||
}
|
||||
|
||||
func summarizeSectionStatus(sections []domain.DataIntegritySection) domain.DataIntegrityStatus {
|
||||
status := domain.DataIntegrityStatusPass
|
||||
for _, section := range sections {
|
||||
if section.Status == domain.DataIntegrityStatusFail {
|
||||
return domain.DataIntegrityStatusFail
|
||||
}
|
||||
if section.Status == domain.DataIntegrityStatusWarning {
|
||||
status = domain.DataIntegrityStatusWarning
|
||||
}
|
||||
}
|
||||
return status
|
||||
}
|
||||
|
||||
func ListOrphanUserLoginIDs(ctx context.Context, db *gorm.DB, ids []string) ([]domain.OrphanUserLoginID, error) {
|
||||
type orphanRow struct {
|
||||
ID string
|
||||
UserID string
|
||||
UserEmail string
|
||||
UserDeletedAt *time.Time
|
||||
TenantID string
|
||||
TenantSlug string
|
||||
TenantDeletedAt *time.Time
|
||||
FieldKey string
|
||||
LoginID string
|
||||
MissingUser bool
|
||||
DeletedUser bool
|
||||
MissingTenant bool
|
||||
DeletedTenant bool
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
uli.id,
|
||||
uli.user_id,
|
||||
COALESCE(u.email, '') AS user_email,
|
||||
u.deleted_at AS user_deleted_at,
|
||||
uli.tenant_id,
|
||||
COALESCE(t.slug, '') AS tenant_slug,
|
||||
t.deleted_at AS tenant_deleted_at,
|
||||
uli.field_key,
|
||||
uli.login_id,
|
||||
(u.id IS NULL) AS missing_user,
|
||||
(u.id IS NOT NULL AND u.deleted_at IS NOT NULL) AS deleted_user,
|
||||
(t.id IS NULL) AS missing_tenant,
|
||||
(t.id IS NOT NULL AND t.deleted_at IS NOT NULL) AS deleted_tenant
|
||||
FROM user_login_ids AS uli
|
||||
LEFT JOIN users AS u ON u.id = uli.user_id
|
||||
LEFT JOIN tenants AS t ON t.id = uli.tenant_id
|
||||
WHERE (
|
||||
u.id IS NULL
|
||||
OR u.deleted_at IS NOT NULL
|
||||
OR t.id IS NULL
|
||||
OR t.deleted_at IS NOT NULL
|
||||
)
|
||||
`
|
||||
args := []any{}
|
||||
if len(ids) > 0 {
|
||||
query += " AND uli.id IN ?\n"
|
||||
args = append(args, ids)
|
||||
}
|
||||
query += "ORDER BY uli.login_id, uli.id"
|
||||
|
||||
var rows []orphanRow
|
||||
if err := db.WithContext(ctx).Raw(query, args...).Scan(&rows).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
items := make([]domain.OrphanUserLoginID, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
reasons := make([]string, 0, 4)
|
||||
if row.MissingUser {
|
||||
reasons = append(reasons, "missing_user")
|
||||
}
|
||||
if row.DeletedUser {
|
||||
reasons = append(reasons, "deleted_user")
|
||||
}
|
||||
if row.MissingTenant {
|
||||
reasons = append(reasons, "missing_tenant")
|
||||
}
|
||||
if row.DeletedTenant {
|
||||
reasons = append(reasons, "deleted_tenant")
|
||||
}
|
||||
items = append(items, domain.OrphanUserLoginID{
|
||||
ID: row.ID,
|
||||
UserID: row.UserID,
|
||||
UserEmail: row.UserEmail,
|
||||
UserDeletedAt: row.UserDeletedAt,
|
||||
TenantID: row.TenantID,
|
||||
TenantSlug: row.TenantSlug,
|
||||
TenantDeletedAt: row.TenantDeletedAt,
|
||||
FieldKey: row.FieldKey,
|
||||
LoginID: row.LoginID,
|
||||
Reasons: reasons,
|
||||
})
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func DeleteOrphanUserLoginIDs(ctx context.Context, db *gorm.DB, ids []string) (domain.DeleteOrphanUserLoginIDsResult, error) {
|
||||
ids = normalizeIDList(ids)
|
||||
result := domain.DeleteOrphanUserLoginIDsResult{
|
||||
Deleted: []domain.OrphanUserLoginID{},
|
||||
SkippedIDs: []string{},
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
err := db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
items, err := ListOrphanUserLoginIDs(ctx, tx, ids)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
deletableIDs := make([]string, 0, len(items))
|
||||
deletableSet := make(map[string]bool, len(items))
|
||||
for _, item := range items {
|
||||
deletableIDs = append(deletableIDs, item.ID)
|
||||
deletableSet[item.ID] = true
|
||||
}
|
||||
for _, id := range ids {
|
||||
if !deletableSet[id] {
|
||||
result.SkippedIDs = append(result.SkippedIDs, id)
|
||||
}
|
||||
}
|
||||
if len(deletableIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
deleteResult := tx.Exec("DELETE FROM user_login_ids WHERE id IN ?", deletableIDs)
|
||||
if deleteResult.Error != nil {
|
||||
return deleteResult.Error
|
||||
}
|
||||
result.Deleted = items
|
||||
result.DeletedCount = deleteResult.RowsAffected
|
||||
return nil
|
||||
})
|
||||
return result, err
|
||||
}
|
||||
|
||||
func normalizeIDList(ids []string) []string {
|
||||
normalized := make([]string, 0, len(ids))
|
||||
seen := map[string]bool{}
|
||||
for _, id := range ids {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" || seen[id] {
|
||||
continue
|
||||
}
|
||||
seen[id] = true
|
||||
normalized = append(normalized, id)
|
||||
}
|
||||
slices.Sort(normalized)
|
||||
return normalized
|
||||
}
|
||||
@@ -0,0 +1,312 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestCheckDataIntegrityDetectsTenantAndUserProblems(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
suffix := uuid.NewString()
|
||||
|
||||
parent := domain.Tenant{
|
||||
ID: uuid.NewString(),
|
||||
Name: "Deleted Parent " + suffix,
|
||||
Slug: "deleted-parent-" + suffix,
|
||||
Type: domain.TenantTypeCompany,
|
||||
Status: domain.TenantStatusActive,
|
||||
}
|
||||
child := domain.Tenant{
|
||||
ID: uuid.NewString(),
|
||||
Name: "Orphan Child " + suffix,
|
||||
Slug: "orphan-child-" + suffix,
|
||||
Type: domain.TenantTypeOrganization,
|
||||
ParentID: &parent.ID,
|
||||
Status: domain.TenantStatusActive,
|
||||
}
|
||||
dupA := domain.Tenant{
|
||||
ID: uuid.NewString(),
|
||||
Name: "Duplicate A " + suffix,
|
||||
Slug: "Dup-" + suffix,
|
||||
Type: domain.TenantTypeCompany,
|
||||
Status: domain.TenantStatusActive,
|
||||
}
|
||||
dupB := domain.Tenant{
|
||||
ID: uuid.NewString(),
|
||||
Name: "Duplicate B " + suffix,
|
||||
Slug: "dup-" + suffix,
|
||||
Type: domain.TenantTypeCompany,
|
||||
Status: domain.TenantStatusActive,
|
||||
}
|
||||
|
||||
require.NoError(t, testDB.Create(&parent).Error)
|
||||
require.NoError(t, testDB.Create(&child).Error)
|
||||
require.NoError(t, testDB.Create(&dupA).Error)
|
||||
require.NoError(t, testDB.Create(&dupB).Error)
|
||||
require.NoError(t, testDB.Delete(&domain.Tenant{}, "id = ?", parent.ID).Error)
|
||||
|
||||
orphanUser := domain.User{
|
||||
ID: uuid.NewString(),
|
||||
Email: "orphan-" + suffix + "@example.com",
|
||||
Name: "Orphan User",
|
||||
Role: domain.RoleUser,
|
||||
TenantID: &parent.ID,
|
||||
Status: domain.UserStatusActive,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
deletedLoginUser := domain.User{
|
||||
ID: uuid.NewString(),
|
||||
Email: "deleted-login-user-" + suffix + "@example.com",
|
||||
Name: "Deleted Login User",
|
||||
Role: domain.RoleUser,
|
||||
TenantID: &child.ID,
|
||||
Status: domain.UserStatusActive,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
require.NoError(t, testDB.Create(&orphanUser).Error)
|
||||
require.NoError(t, testDB.Create(&deletedLoginUser).Error)
|
||||
require.NoError(t, testDB.Create(&domain.UserLoginID{
|
||||
ID: uuid.NewString(),
|
||||
UserID: orphanUser.ID,
|
||||
TenantID: parent.ID,
|
||||
FieldKey: "emp_id",
|
||||
LoginID: "EMP-" + suffix,
|
||||
}).Error)
|
||||
require.NoError(t, testDB.Create(&domain.UserLoginID{
|
||||
ID: uuid.NewString(),
|
||||
UserID: deletedLoginUser.ID,
|
||||
TenantID: child.ID,
|
||||
FieldKey: "emp_id",
|
||||
LoginID: "MISSING-" + suffix,
|
||||
}).Error)
|
||||
require.NoError(t, testDB.Delete(&domain.User{}, "id = ?", deletedLoginUser.ID).Error)
|
||||
|
||||
report, err := CheckDataIntegrity(ctx, testDB)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, domain.DataIntegrityStatusFail, report.Status)
|
||||
require.Equal(t, int64(5), report.Summary.Failures) // Reverted back to 5 due to successful soft delete simulation
|
||||
|
||||
requireIntegrityCheck(t, report, "tenant_integrity", "duplicate_tenant_slugs", domain.DataIntegrityStatusFail, 1)
|
||||
requireIntegrityCheck(t, report, "tenant_integrity", "orphan_tenant_parents", domain.DataIntegrityStatusFail, 1)
|
||||
requireIntegrityCheck(t, report, "user_integrity", "orphan_user_tenant_memberships", domain.DataIntegrityStatusFail, 1)
|
||||
requireIntegrityCheck(t, report, "user_integrity", "orphan_user_login_id_tenants", domain.DataIntegrityStatusFail, 1)
|
||||
requireIntegrityCheck(t, report, "user_integrity", "orphan_user_login_id_users", domain.DataIntegrityStatusFail, 1)
|
||||
}
|
||||
|
||||
func TestCheckDataIntegrityDetectsHardOrphanUserLoginIDRows(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
suffix := uuid.NewString()
|
||||
rollback := errors.New("rollback hard orphan fixture")
|
||||
|
||||
err := testDB.Transaction(func(tx *gorm.DB) error {
|
||||
var constraintNames []string
|
||||
if err := tx.Raw(`
|
||||
SELECT conname
|
||||
FROM pg_constraint
|
||||
WHERE conrelid = 'user_login_ids'::regclass
|
||||
AND contype = 'f'
|
||||
`).Scan(&constraintNames).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, constraintName := range constraintNames {
|
||||
statement := fmt.Sprintf("ALTER TABLE user_login_ids DROP CONSTRAINT %s", pq.QuoteIdentifier(constraintName))
|
||||
if err := tx.Exec(statement).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
before, err := CheckDataIntegrity(ctx, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
beforeTenantCount, err := integrityCheckCount(before, "user_integrity", "orphan_user_login_id_tenants")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
beforeUserCount, err := integrityCheckCount(before, "user_integrity", "orphan_user_login_id_users")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Create(&domain.UserLoginID{
|
||||
ID: uuid.NewString(),
|
||||
UserID: uuid.NewString(),
|
||||
TenantID: uuid.NewString(),
|
||||
FieldKey: "emp_id",
|
||||
LoginID: "HARD-ORPHAN-" + suffix,
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
report, err := CheckDataIntegrity(ctx, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := expectIntegrityCheck(report, "user_integrity", "orphan_user_login_id_tenants", domain.DataIntegrityStatusFail, beforeTenantCount+1); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := expectIntegrityCheck(report, "user_integrity", "orphan_user_login_id_users", domain.DataIntegrityStatusFail, beforeUserCount+1); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return rollback
|
||||
})
|
||||
require.ErrorIs(t, err, rollback)
|
||||
}
|
||||
|
||||
func TestListAndDeleteOrphanUserLoginIDsOnlyDeletesRevalidatedTargets(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
suffix := uuid.NewString()
|
||||
|
||||
validTenant := domain.Tenant{
|
||||
ID: uuid.NewString(),
|
||||
Name: "Valid Tenant " + suffix,
|
||||
Slug: "valid-tenant-" + suffix,
|
||||
Type: domain.TenantTypeCompany,
|
||||
Status: domain.TenantStatusActive,
|
||||
}
|
||||
deletedTenant := domain.Tenant{
|
||||
ID: uuid.NewString(),
|
||||
Name: "Deleted Tenant " + suffix,
|
||||
Slug: "deleted-tenant-" + suffix,
|
||||
Type: domain.TenantTypeCompany,
|
||||
Status: domain.TenantStatusActive,
|
||||
}
|
||||
require.NoError(t, testDB.Create(&validTenant).Error)
|
||||
require.NoError(t, testDB.Create(&deletedTenant).Error)
|
||||
|
||||
validUser := domain.User{
|
||||
ID: uuid.NewString(),
|
||||
Email: "valid-login-" + suffix + "@example.com",
|
||||
Name: "Valid Login User",
|
||||
Role: domain.RoleUser,
|
||||
TenantID: &validTenant.ID,
|
||||
Status: domain.UserStatusActive,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
deletedUser := domain.User{
|
||||
ID: uuid.NewString(),
|
||||
Email: "deleted-login-" + suffix + "@example.com",
|
||||
Name: "Deleted Login User",
|
||||
Role: domain.RoleUser,
|
||||
TenantID: &validTenant.ID,
|
||||
Status: domain.UserStatusActive,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
require.NoError(t, testDB.Create(&validUser).Error)
|
||||
require.NoError(t, testDB.Create(&deletedUser).Error)
|
||||
|
||||
validLogin := domain.UserLoginID{
|
||||
ID: uuid.NewString(),
|
||||
UserID: validUser.ID,
|
||||
TenantID: validTenant.ID,
|
||||
FieldKey: "emp_id",
|
||||
LoginID: "VALID-" + suffix,
|
||||
}
|
||||
deletedTenantLogin := domain.UserLoginID{
|
||||
ID: uuid.NewString(),
|
||||
UserID: validUser.ID,
|
||||
TenantID: deletedTenant.ID,
|
||||
FieldKey: "emp_id",
|
||||
LoginID: "DELETED-TENANT-" + suffix,
|
||||
}
|
||||
deletedUserLogin := domain.UserLoginID{
|
||||
ID: uuid.NewString(),
|
||||
UserID: deletedUser.ID,
|
||||
TenantID: validTenant.ID,
|
||||
FieldKey: "emp_id",
|
||||
LoginID: "DELETED-USER-" + suffix,
|
||||
}
|
||||
require.NoError(t, testDB.Create(&validLogin).Error)
|
||||
require.NoError(t, testDB.Create(&deletedTenantLogin).Error)
|
||||
require.NoError(t, testDB.Create(&deletedUserLogin).Error)
|
||||
require.NoError(t, testDB.Delete(&domain.Tenant{}, "id = ?", deletedTenant.ID).Error)
|
||||
require.NoError(t, testDB.Delete(&domain.User{}, "id = ?", deletedUser.ID).Error)
|
||||
|
||||
items, err := ListOrphanUserLoginIDs(ctx, testDB, nil)
|
||||
require.NoError(t, err)
|
||||
orphanReasons := map[string][]string{}
|
||||
for _, item := range items {
|
||||
orphanReasons[item.ID] = item.Reasons
|
||||
}
|
||||
require.Equal(t, []string{"deleted_tenant"}, orphanReasons[deletedTenantLogin.ID])
|
||||
require.Equal(t, []string{"deleted_user"}, orphanReasons[deletedUserLogin.ID])
|
||||
require.NotContains(t, orphanReasons, validLogin.ID)
|
||||
|
||||
result, err := DeleteOrphanUserLoginIDs(ctx, testDB, []string{
|
||||
deletedTenantLogin.ID,
|
||||
validLogin.ID,
|
||||
"00000000-0000-0000-0000-000000000000",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), result.DeletedCount)
|
||||
require.Len(t, result.Deleted, 1)
|
||||
require.Equal(t, deletedTenantLogin.ID, result.Deleted[0].ID)
|
||||
require.ElementsMatch(t, []string{
|
||||
validLogin.ID,
|
||||
"00000000-0000-0000-0000-000000000000",
|
||||
}, result.SkippedIDs)
|
||||
|
||||
var deletedTenantLoginCount int64
|
||||
require.NoError(t, testDB.Model(&domain.UserLoginID{}).Where("id = ?", deletedTenantLogin.ID).Count(&deletedTenantLoginCount).Error)
|
||||
require.Equal(t, int64(0), deletedTenantLoginCount)
|
||||
|
||||
var validLoginCount int64
|
||||
require.NoError(t, testDB.Model(&domain.UserLoginID{}).Where("id = ?", validLogin.ID).Count(&validLoginCount).Error)
|
||||
require.Equal(t, int64(1), validLoginCount)
|
||||
}
|
||||
|
||||
func requireIntegrityCheck(t *testing.T, report domain.DataIntegrityReport, sectionKey, checkKey string, status domain.DataIntegrityStatus, count int64) {
|
||||
t.Helper()
|
||||
require.NoError(t, expectIntegrityCheck(report, sectionKey, checkKey, status, count))
|
||||
}
|
||||
|
||||
func expectIntegrityCheck(report domain.DataIntegrityReport, sectionKey, checkKey string, status domain.DataIntegrityStatus, count int64) error {
|
||||
check, ok := findIntegrityCheck(report, sectionKey, checkKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("integrity check %s/%s not found", sectionKey, checkKey)
|
||||
}
|
||||
if check.Status != status {
|
||||
return fmt.Errorf("integrity check %s/%s status = %s, want %s", sectionKey, checkKey, check.Status, status)
|
||||
}
|
||||
if check.Count != count {
|
||||
return fmt.Errorf("integrity check %s/%s count = %d, want %d", sectionKey, checkKey, check.Count, count)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func integrityCheckCount(report domain.DataIntegrityReport, sectionKey, checkKey string) (int64, error) {
|
||||
check, ok := findIntegrityCheck(report, sectionKey, checkKey)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("integrity check %s/%s not found", sectionKey, checkKey)
|
||||
}
|
||||
return check.Count, nil
|
||||
}
|
||||
|
||||
func findIntegrityCheck(report domain.DataIntegrityReport, sectionKey, checkKey string) (domain.DataIntegrityCheck, bool) {
|
||||
for _, section := range report.Sections {
|
||||
if section.Key != sectionKey {
|
||||
continue
|
||||
}
|
||||
for _, check := range section.Checks {
|
||||
if check.Key == checkKey {
|
||||
return check, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return domain.DataIntegrityCheck{}, false
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
)
|
||||
|
||||
type FederationRepository interface {
|
||||
FindProviderByID(ctx context.Context, providerID string) (*domain.IdentityProviderConfig, error)
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type GormFederationRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewGormFederationRepository(db *gorm.DB) *GormFederationRepository {
|
||||
return &GormFederationRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *GormFederationRepository) FindProviderByID(ctx context.Context, providerID string) (*domain.IdentityProviderConfig, error) {
|
||||
var provider domain.IdentityProviderConfig
|
||||
if err := r.db.WithContext(ctx).First(&provider, "id = ?", providerID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &provider, nil
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type KetoOutboxRepository interface {
|
||||
Create(ctx context.Context, entry *domain.KetoOutbox) error
|
||||
CreateWithTx(tx *gorm.DB, entry *domain.KetoOutbox) error
|
||||
FindPending(ctx context.Context, limit int) ([]domain.KetoOutbox, error)
|
||||
ListCurrentBySubject(ctx context.Context, namespace, subject string) ([]domain.KetoOutbox, error)
|
||||
UpdateStatus(ctx context.Context, id string, status string, retryCount int, lastError string) error
|
||||
MarkProcessed(ctx context.Context, id string) error
|
||||
}
|
||||
|
||||
type ketoOutboxRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewKetoOutboxRepository(db *gorm.DB) KetoOutboxRepository {
|
||||
return &ketoOutboxRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *ketoOutboxRepository) Create(ctx context.Context, entry *domain.KetoOutbox) error {
|
||||
return r.db.WithContext(ctx).Create(entry).Error
|
||||
}
|
||||
|
||||
func (r *ketoOutboxRepository) CreateWithTx(tx *gorm.DB, entry *domain.KetoOutbox) error {
|
||||
return tx.Create(entry).Error
|
||||
}
|
||||
|
||||
func (r *ketoOutboxRepository) FindPending(ctx context.Context, limit int) ([]domain.KetoOutbox, error) {
|
||||
var entries []domain.KetoOutbox
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ?", domain.KetoOutboxStatusPending).
|
||||
Order("created_at asc").
|
||||
Limit(limit).
|
||||
Find(&entries).Error
|
||||
return entries, err
|
||||
}
|
||||
|
||||
func (r *ketoOutboxRepository) ListCurrentBySubject(ctx context.Context, namespace, subject string) ([]domain.KetoOutbox, error) {
|
||||
var entries []domain.KetoOutbox
|
||||
if err := r.db.WithContext(ctx).
|
||||
Where("namespace = ? AND subject = ? AND status <> ?", namespace, subject, domain.KetoOutboxStatusFailed).
|
||||
Order("created_at desc").
|
||||
Order("updated_at desc").
|
||||
Find(&entries).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
current := make([]domain.KetoOutbox, 0, len(entries))
|
||||
seen := make(map[string]struct{}, len(entries))
|
||||
for _, entry := range entries {
|
||||
key := entry.Namespace + "\x00" + entry.Object + "\x00" + entry.Relation + "\x00" + entry.Subject
|
||||
if _, exists := seen[key]; exists {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
if entry.Action == domain.KetoOutboxActionCreate {
|
||||
current = append(current, entry)
|
||||
}
|
||||
}
|
||||
|
||||
return current, nil
|
||||
}
|
||||
|
||||
func (r *ketoOutboxRepository) UpdateStatus(ctx context.Context, id string, status string, retryCount int, lastError string) error {
|
||||
return r.db.WithContext(ctx).Model(&domain.KetoOutbox{}).Where("id = ?", id).Updates(map[string]any{
|
||||
"status": status,
|
||||
"retry_count": retryCount,
|
||||
"last_error": lastError,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *ketoOutboxRepository) MarkProcessed(ctx context.Context, id string) error {
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Model(&domain.KetoOutbox{}).Where("id = ?", id).Updates(map[string]any{
|
||||
"status": domain.KetoOutboxStatusProcessed,
|
||||
"processed_at": &now,
|
||||
"updated_at": now,
|
||||
}).Error
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestKetoOutboxRepository_ListCurrentBySubject(t *testing.T) {
|
||||
repo := NewKetoOutboxRepository(testDB)
|
||||
ctx := context.Background()
|
||||
|
||||
require.NoError(t, testDB.Exec("DELETE FROM keto_outbox").Error)
|
||||
|
||||
entries := []domain.KetoOutbox{
|
||||
{
|
||||
Namespace: "RelyingParty",
|
||||
Object: "client-1",
|
||||
Relation: "admins",
|
||||
Subject: "User:user-1",
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
Status: domain.KetoOutboxStatusProcessed,
|
||||
},
|
||||
{
|
||||
Namespace: "RelyingParty",
|
||||
Object: "client-1",
|
||||
Relation: "admins",
|
||||
Subject: "User:user-1",
|
||||
Action: domain.KetoOutboxActionDelete,
|
||||
Status: domain.KetoOutboxStatusProcessed,
|
||||
},
|
||||
{
|
||||
Namespace: "RelyingParty",
|
||||
Object: "client-2",
|
||||
Relation: "config_editor",
|
||||
Subject: "User:user-1",
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
Status: domain.KetoOutboxStatusProcessed,
|
||||
},
|
||||
{
|
||||
Namespace: "RelyingParty",
|
||||
Object: "client-3",
|
||||
Relation: "audit_viewer",
|
||||
Subject: "User:user-1",
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
Status: domain.KetoOutboxStatusFailed,
|
||||
},
|
||||
{
|
||||
Namespace: "Tenant",
|
||||
Object: "tenant-1",
|
||||
Relation: "members",
|
||||
Subject: "User:user-1",
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
Status: domain.KetoOutboxStatusProcessed,
|
||||
},
|
||||
}
|
||||
for i := range entries {
|
||||
require.NoError(t, repo.Create(ctx, &entries[i]))
|
||||
}
|
||||
|
||||
current, err := repo.ListCurrentBySubject(ctx, "RelyingParty", "User:user-1")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, current, 1)
|
||||
require.Equal(t, "client-2", current[0].Object)
|
||||
require.Equal(t, "config_editor", current[0].Relation)
|
||||
}
|
||||
74
baron-sso/backend/internal/repository/main_test.go
Normal file
74
baron-sso/backend/internal/repository/main_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/testsupport"
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
postgres_module "github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
gorm_postgres "gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var testDB *gorm.DB
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if !testsupport.DockerAvailable() {
|
||||
log.Printf("skipping repository tests: Docker provider is unavailable in this environment")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Start PostgreSQL container
|
||||
dbName := "testdb"
|
||||
dbUser := "user"
|
||||
dbPassword := "password"
|
||||
|
||||
postgresContainer, err := postgres_module.Run(ctx,
|
||||
"postgres:16-alpine",
|
||||
postgres_module.WithDatabase(dbName),
|
||||
postgres_module.WithUsername(dbUser),
|
||||
postgres_module.WithPassword(dbPassword),
|
||||
testcontainers.WithWaitStrategy(
|
||||
wait.ForLog("database system is ready to accept connections").
|
||||
WithOccurrence(2).
|
||||
WithStartupTimeout(30*time.Second)),
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to start container: %s", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := postgresContainer.Terminate(ctx); err != nil {
|
||||
log.Fatalf("failed to terminate container: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
connStr, err := postgresContainer.ConnectionString(ctx, "sslmode=disable")
|
||||
if err != nil {
|
||||
log.Fatalf("failed to get connection string: %s", err)
|
||||
}
|
||||
|
||||
// Connect to test database
|
||||
db, err := gorm.Open(gorm_postgres.Open(connStr), &gorm.Config{})
|
||||
if err != nil {
|
||||
log.Fatalf("failed to connect to database: %s", err)
|
||||
}
|
||||
|
||||
// Auto-migrate
|
||||
err = db.AutoMigrate(&domain.Tenant{}, &domain.TenantDomain{}, &domain.User{}, &domain.UserLoginID{}, &domain.UserProjectionState{}, &domain.ClientConsent{}, &domain.RPUserMetadata{}, &domain.RPUsageEvent{}, &domain.KetoOutbox{}, &domain.WorksmobileOutbox{})
|
||||
if err != nil {
|
||||
log.Fatalf("failed to migrate database: %s", err)
|
||||
}
|
||||
|
||||
testDB = db
|
||||
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
@@ -0,0 +1,160 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ClickHouse/clickhouse-go/v2"
|
||||
"github.com/ClickHouse/clickhouse-go/v2/lib/driver"
|
||||
)
|
||||
|
||||
type OathkeeperClickHouseRepository struct {
|
||||
conn driver.Conn
|
||||
}
|
||||
|
||||
func NewOathkeeperClickHouseRepository(host string, port int, user, password, db string) (*OathkeeperClickHouseRepository, error) {
|
||||
conn, err := clickhouse.Open(&clickhouse.Options{
|
||||
Addr: []string{fmt.Sprintf("%s:%d", host, port)},
|
||||
Auth: clickhouse.Auth{
|
||||
Database: db,
|
||||
Username: user,
|
||||
Password: password,
|
||||
},
|
||||
Debug: false,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open ory clickhouse connection: %w", err)
|
||||
}
|
||||
if err := conn.Ping(context.Background()); err != nil {
|
||||
return nil, fmt.Errorf("failed to ping ory clickhouse: %w", err)
|
||||
}
|
||||
return &OathkeeperClickHouseRepository{conn: conn}, nil
|
||||
}
|
||||
|
||||
func (r *OathkeeperClickHouseRepository) FindPageBySubject(ctx context.Context, subject string, limit int, cursor *domain.AuditCursor) ([]domain.OathkeeperAccessLog, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
query, args := buildOathkeeperQuery(subject, limit, cursor, true)
|
||||
rows, err := r.conn.Query(ctx, query, args...)
|
||||
if err != nil && isMissingColumnError(err, "client_id") {
|
||||
query, args = buildOathkeeperQuery(subject, limit, cursor, false)
|
||||
rows, err = r.conn.Query(ctx, query, args...)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query oathkeeper logs: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
withClientID := strings.Contains(query, "client_id")
|
||||
var logs []domain.OathkeeperAccessLog
|
||||
for rows.Next() {
|
||||
var log domain.OathkeeperAccessLog
|
||||
if withClientID {
|
||||
if err := rows.Scan(
|
||||
&log.Timestamp,
|
||||
&log.RequestID,
|
||||
&log.Method,
|
||||
&log.Path,
|
||||
&log.Status,
|
||||
&log.LatencyMs,
|
||||
&log.ClientID,
|
||||
&log.RP,
|
||||
&log.Action,
|
||||
&log.Target,
|
||||
&log.Subject,
|
||||
&log.ClientIP,
|
||||
&log.UserAgent,
|
||||
&log.Decision,
|
||||
&log.TraceID,
|
||||
&log.SpanID,
|
||||
&log.Raw,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan oathkeeper log: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := rows.Scan(
|
||||
&log.Timestamp,
|
||||
&log.RequestID,
|
||||
&log.Method,
|
||||
&log.Path,
|
||||
&log.Status,
|
||||
&log.LatencyMs,
|
||||
&log.RP,
|
||||
&log.Action,
|
||||
&log.Target,
|
||||
&log.Subject,
|
||||
&log.ClientIP,
|
||||
&log.UserAgent,
|
||||
&log.Decision,
|
||||
&log.TraceID,
|
||||
&log.SpanID,
|
||||
&log.Raw,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan oathkeeper log: %w", err)
|
||||
}
|
||||
}
|
||||
logs = append(logs, log)
|
||||
}
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
func buildOathkeeperQuery(subject string, limit int, cursor *domain.AuditCursor, withClientID bool) (string, []any) {
|
||||
selectCols := "timestamp, request_id, method, path, status, latency_ms, rp, action, target, subject, client_ip, user_agent, decision, trace_id, span_id, raw"
|
||||
if withClientID {
|
||||
selectCols = "timestamp, request_id, method, path, status, latency_ms, client_id, rp, action, target, subject, client_ip, user_agent, decision, trace_id, span_id, raw"
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("SELECT %s FROM oathkeeper_access_logs", selectCols)
|
||||
args := make([]any, 0, 5)
|
||||
if subject != "" {
|
||||
query += `
|
||||
WHERE subject = ?
|
||||
`
|
||||
args = append(args, subject)
|
||||
if cursor != nil {
|
||||
query += `
|
||||
AND ((timestamp < ?) OR (timestamp = ? AND request_id < ?))
|
||||
`
|
||||
args = append(args, cursor.Timestamp, cursor.Timestamp, cursor.EventID)
|
||||
}
|
||||
} else if cursor != nil {
|
||||
query += `
|
||||
WHERE (timestamp < ?) OR (timestamp = ? AND request_id < ?)
|
||||
`
|
||||
args = append(args, cursor.Timestamp, cursor.Timestamp, cursor.EventID)
|
||||
}
|
||||
query += `
|
||||
ORDER BY timestamp DESC, request_id DESC
|
||||
LIMIT ?
|
||||
`
|
||||
args = append(args, limit)
|
||||
return query, args
|
||||
}
|
||||
|
||||
func isMissingColumnError(err error, column string) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
column = strings.ToLower(column)
|
||||
if strings.Contains(msg, "unknown identifier") && strings.Contains(msg, column) {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(msg, "unknown expression identifier") && strings.Contains(msg, column) {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(msg, "missing columns") && strings.Contains(msg, column) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *OathkeeperClickHouseRepository) Ping(ctx context.Context) error {
|
||||
if r == nil || r.conn == nil {
|
||||
return fmt.Errorf("ory clickhouse connection is nil")
|
||||
}
|
||||
return r.conn.Ping(ctx)
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type RelyingPartyRepository interface {
|
||||
Create(ctx context.Context, rp *domain.RelyingParty) error
|
||||
Update(ctx context.Context, rp *domain.RelyingParty) error
|
||||
Delete(ctx context.Context, clientID string) error
|
||||
FindByID(ctx context.Context, clientID string) (*domain.RelyingParty, error)
|
||||
ListByTenantID(ctx context.Context, tenantID string) ([]domain.RelyingParty, error)
|
||||
ListAll(ctx context.Context) ([]domain.RelyingParty, error)
|
||||
}
|
||||
|
||||
func (r *relyingPartyRepository) ListAll(ctx context.Context) ([]domain.RelyingParty, error) {
|
||||
var rps []domain.RelyingParty
|
||||
if err := r.db.WithContext(ctx).Find(&rps).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rps, nil
|
||||
}
|
||||
|
||||
type relyingPartyRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewRelyingPartyRepository(db *gorm.DB) RelyingPartyRepository {
|
||||
return &relyingPartyRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *relyingPartyRepository) Create(ctx context.Context, rp *domain.RelyingParty) error {
|
||||
return r.db.WithContext(ctx).Create(rp).Error
|
||||
}
|
||||
|
||||
func (r *relyingPartyRepository) Update(ctx context.Context, rp *domain.RelyingParty) error {
|
||||
return r.db.WithContext(ctx).Save(rp).Error
|
||||
}
|
||||
|
||||
func (r *relyingPartyRepository) Delete(ctx context.Context, clientID string) error {
|
||||
return r.db.WithContext(ctx).Delete(&domain.RelyingParty{}, "client_id = ?", clientID).Error
|
||||
}
|
||||
|
||||
func (r *relyingPartyRepository) FindByID(ctx context.Context, clientID string) (*domain.RelyingParty, error) {
|
||||
var rp domain.RelyingParty
|
||||
if err := r.db.WithContext(ctx).First(&rp, "client_id = ?", clientID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rp, nil
|
||||
}
|
||||
|
||||
func (r *relyingPartyRepository) ListByTenantID(ctx context.Context, tenantID string) ([]domain.RelyingParty, error) {
|
||||
var rps []domain.RelyingParty
|
||||
if err := r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Find(&rps).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rps, nil
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type RPUsageOutboxRepository interface {
|
||||
Create(ctx context.Context, event *domain.RPUsageEvent) error
|
||||
ListReady(ctx context.Context, limit int) ([]domain.RPUsageEvent, error)
|
||||
MarkProcessing(ctx context.Context, id string) error
|
||||
MarkProcessed(ctx context.Context, id string) error
|
||||
MarkFailed(ctx context.Context, id string, message string, nextAttemptAt time.Time) error
|
||||
}
|
||||
|
||||
type rpUsageOutboxRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewRPUsageOutboxRepository(db *gorm.DB) RPUsageOutboxRepository {
|
||||
return &rpUsageOutboxRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *rpUsageOutboxRepository) Create(ctx context.Context, event *domain.RPUsageEvent) error {
|
||||
if event.Payload == nil {
|
||||
event.Payload = domain.JSONMap{}
|
||||
}
|
||||
if event.Status == "" {
|
||||
event.Status = domain.RPUsageOutboxStatusPending
|
||||
}
|
||||
if event.OccurredAt.IsZero() {
|
||||
event.OccurredAt = time.Now()
|
||||
}
|
||||
return r.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "dedupe_key"}},
|
||||
DoNothing: true,
|
||||
}).Create(event).Error
|
||||
}
|
||||
|
||||
func (r *rpUsageOutboxRepository) ListReady(ctx context.Context, limit int) ([]domain.RPUsageEvent, error) {
|
||||
if limit <= 0 || limit > 100 {
|
||||
limit = 20
|
||||
}
|
||||
var rows []domain.RPUsageEvent
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ? AND (next_attempt_at IS NULL OR next_attempt_at <= ?)", domain.RPUsageOutboxStatusPending, time.Now()).
|
||||
Order("occurred_at asc, created_at asc").
|
||||
Limit(limit).
|
||||
Find(&rows).Error
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (r *rpUsageOutboxRepository) MarkProcessing(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&domain.RPUsageEvent{}).
|
||||
Where("id = ? AND status = ?", id, domain.RPUsageOutboxStatusPending).
|
||||
Updates(map[string]any{
|
||||
"status": domain.RPUsageOutboxStatusProcessing,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *rpUsageOutboxRepository) MarkProcessed(ctx context.Context, id string) error {
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&domain.RPUsageEvent{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"status": domain.RPUsageOutboxStatusProcessed,
|
||||
"last_error": "",
|
||||
"processed_at": &now,
|
||||
"updated_at": now,
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *rpUsageOutboxRepository) MarkFailed(ctx context.Context, id string, message string, nextAttemptAt time.Time) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&domain.RPUsageEvent{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"status": domain.RPUsageOutboxStatusFailed,
|
||||
"retry_count": gorm.Expr("retry_count + 1"),
|
||||
"last_error": message,
|
||||
"next_attempt_at": &nextAttemptAt,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type RPUserMetadataRepository interface {
|
||||
Get(ctx context.Context, clientID, userID string) (*domain.RPUserMetadata, error)
|
||||
Upsert(ctx context.Context, metadata *domain.RPUserMetadata) error
|
||||
}
|
||||
|
||||
type rpUserMetadataRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewRPUserMetadataRepository(db *gorm.DB) RPUserMetadataRepository {
|
||||
return &rpUserMetadataRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *rpUserMetadataRepository) Get(ctx context.Context, clientID, userID string) (*domain.RPUserMetadata, error) {
|
||||
var metadata domain.RPUserMetadata
|
||||
if err := r.db.WithContext(ctx).First(&metadata, "client_id = ? AND user_id = ?", clientID, userID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &metadata, nil
|
||||
}
|
||||
|
||||
func (r *rpUserMetadataRepository) Upsert(ctx context.Context, metadata *domain.RPUserMetadata) error {
|
||||
return r.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{
|
||||
{Name: "client_id"},
|
||||
{Name: "user_id"},
|
||||
},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"metadata", "updated_at"}),
|
||||
}).Create(metadata).Error
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type SharedLinkRepository interface {
|
||||
Create(ctx context.Context, link *domain.SharedLink) error
|
||||
FindByToken(ctx context.Context, token string) (*domain.SharedLink, error)
|
||||
FindByTenantID(ctx context.Context, tenantID string) ([]domain.SharedLink, error)
|
||||
Delete(ctx context.Context, id string) error
|
||||
Update(ctx context.Context, link *domain.SharedLink) error
|
||||
}
|
||||
|
||||
type sharedLinkRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewSharedLinkRepository(db *gorm.DB) SharedLinkRepository {
|
||||
return &sharedLinkRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *sharedLinkRepository) Create(ctx context.Context, link *domain.SharedLink) error {
|
||||
return r.db.WithContext(ctx).Create(link).Error
|
||||
}
|
||||
|
||||
func (r *sharedLinkRepository) FindByToken(ctx context.Context, token string) (*domain.SharedLink, error) {
|
||||
var link domain.SharedLink
|
||||
err := r.db.WithContext(ctx).Where("token = ? AND is_active = ?", token, true).First(&link).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &link, nil
|
||||
}
|
||||
|
||||
func (r *sharedLinkRepository) FindByTenantID(ctx context.Context, tenantID string) ([]domain.SharedLink, error) {
|
||||
var links []domain.SharedLink
|
||||
err := r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Find(&links).Error
|
||||
return links, err
|
||||
}
|
||||
|
||||
func (r *sharedLinkRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&domain.SharedLink{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func (r *sharedLinkRepository) Update(ctx context.Context, link *domain.SharedLink) error {
|
||||
return r.db.WithContext(ctx).Save(link).Error
|
||||
}
|
||||
185
baron-sso/backend/internal/repository/tenant_repository.go
Normal file
185
baron-sso/backend/internal/repository/tenant_repository.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type TenantRepository interface {
|
||||
Create(ctx context.Context, tenant *domain.Tenant) error
|
||||
Update(ctx context.Context, tenant *domain.Tenant) error
|
||||
FindByID(ctx context.Context, id string) (*domain.Tenant, error)
|
||||
FindBySlug(ctx context.Context, slug string) (*domain.Tenant, error)
|
||||
FindByName(ctx context.Context, name string) (*domain.Tenant, error)
|
||||
FindByDomain(ctx context.Context, domainName string) (*domain.Tenant, error)
|
||||
FindByIDs(ctx context.Context, ids []string) ([]domain.Tenant, error)
|
||||
AddDomain(ctx context.Context, tenantID string, domainName string, verified bool) error
|
||||
List(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error)
|
||||
ListByType(ctx context.Context, tenantType string) ([]domain.Tenant, error)
|
||||
DeleteBulk(ctx context.Context, ids []string) error
|
||||
}
|
||||
|
||||
type tenantRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewTenantRepository(db *gorm.DB) TenantRepository {
|
||||
return &tenantRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *tenantRepository) Create(ctx context.Context, tenant *domain.Tenant) error {
|
||||
tenant.Slug = strings.ToLower(strings.TrimSpace(tenant.Slug))
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if tenant.Slug != "" {
|
||||
suffix := "-deleted-" + strconv.FormatInt(time.Now().UTC().UnixNano(), 10)
|
||||
if err := tx.Unscoped().
|
||||
Model(&domain.Tenant{}).
|
||||
Where("slug = ? AND deleted_at IS NOT NULL", tenant.Slug).
|
||||
Update("slug", gorm.Expr("slug || ?", suffix)).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Create(tenant).Error
|
||||
})
|
||||
}
|
||||
|
||||
func (r *tenantRepository) Update(ctx context.Context, tenant *domain.Tenant) error {
|
||||
return r.db.WithContext(ctx).Save(tenant).Error
|
||||
}
|
||||
|
||||
func (r *tenantRepository) FindByID(ctx context.Context, id string) (*domain.Tenant, error) {
|
||||
var tenant domain.Tenant
|
||||
if err := r.db.WithContext(ctx).Preload("Domains").First(&tenant, "id = ?", id).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tenant, nil
|
||||
}
|
||||
|
||||
func (r *tenantRepository) FindBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
|
||||
var tenant domain.Tenant
|
||||
if err := r.db.WithContext(ctx).Preload("Domains").Where("slug = ?", strings.ToLower(slug)).First(&tenant).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tenant, nil
|
||||
}
|
||||
|
||||
func (r *tenantRepository) FindByName(ctx context.Context, name string) (*domain.Tenant, error) {
|
||||
var tenant domain.Tenant
|
||||
if err := r.db.WithContext(ctx).Preload("Domains").Where("name = ?", name).First(&tenant).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tenant, nil
|
||||
}
|
||||
|
||||
func (r *tenantRepository) FindByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
|
||||
var tenantDomain domain.TenantDomain
|
||||
if err := r.db.WithContext(ctx).Where("domain = ?", domainName).First(&tenantDomain).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tenant domain.Tenant
|
||||
if err := r.db.WithContext(ctx).Preload("Domains").First(&tenant, "id = ?", tenantDomain.TenantID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tenant, nil
|
||||
}
|
||||
|
||||
func (r *tenantRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.Tenant, error) {
|
||||
var tenants []domain.Tenant
|
||||
if len(ids) == 0 {
|
||||
return tenants, nil
|
||||
}
|
||||
if err := r.db.WithContext(ctx).Preload("Domains").Where("id IN ?", ids).Find(&tenants).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tenants, nil
|
||||
}
|
||||
|
||||
func (r *tenantRepository) AddDomain(ctx context.Context, tenantID string, domainName string, verified bool) error {
|
||||
var existing domain.TenantDomain
|
||||
err := r.db.WithContext(ctx).Unscoped().
|
||||
Where("tenant_id = ? AND domain = ?", tenantID, domainName).
|
||||
First(&existing).Error
|
||||
if err == nil {
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&existing).Updates(map[string]any{
|
||||
"verified": verified,
|
||||
"deleted_at": nil,
|
||||
}).Error
|
||||
}
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
td := domain.TenantDomain{
|
||||
TenantID: tenantID,
|
||||
Domain: domainName,
|
||||
Verified: verified,
|
||||
}
|
||||
return r.db.WithContext(ctx).Create(&td).Error
|
||||
}
|
||||
|
||||
func (r *tenantRepository) List(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error) {
|
||||
var tenants []domain.Tenant
|
||||
var total int64
|
||||
db := r.db.WithContext(ctx).Model(&domain.Tenant{})
|
||||
|
||||
if parentID != "" {
|
||||
db = db.Where("parent_id = ?", parentID)
|
||||
}
|
||||
|
||||
if search != "" {
|
||||
searchTerm := "%" + strings.ToLower(search) + "%"
|
||||
db = db.Where("LOWER(name) LIKE ? OR LOWER(slug) LIKE ? OR LOWER(description) LIKE ?", searchTerm, searchTerm, searchTerm)
|
||||
}
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if err := db.Order("created_at desc, id desc").Limit(limit).Offset(offset).Preload("Domains").Find(&tenants).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return tenants, total, nil
|
||||
}
|
||||
|
||||
func (r *tenantRepository) ListByType(ctx context.Context, tenantType string) ([]domain.Tenant, error) {
|
||||
var tenants []domain.Tenant
|
||||
if err := r.db.WithContext(ctx).Where("type = ?", tenantType).Preload("Domains").Find(&tenants).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tenants, nil
|
||||
}
|
||||
|
||||
func (r *tenantRepository) DeleteBulk(ctx context.Context, ids []string) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// 1. Release slugs for all target tenants to allow reuse
|
||||
suffix := "-deleted-" + time.Now().Format("20060102150405")
|
||||
if err := tx.Model(&domain.Tenant{}).Where("id IN ?", ids).
|
||||
Update("slug", gorm.Expr("slug || ?", suffix)).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. Soft delete tenants
|
||||
if err := tx.Where("id IN ?", ids).Delete(&domain.Tenant{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 3. Also delete related UserGroups if any (Type USER_GROUP tenants have records in user_groups table)
|
||||
if err := tx.Where("id IN ?", ids).Delete(&domain.UserGroup{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
192
baron-sso/backend/internal/repository/tenant_repository_test.go
Normal file
192
baron-sso/backend/internal/repository/tenant_repository_test.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTenantRepository(t *testing.T) {
|
||||
repo := NewTenantRepository(testDB)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Create and FindByID", func(t *testing.T) {
|
||||
tenant := &domain.Tenant{
|
||||
Name: "Test Tenant",
|
||||
Slug: "test-tenant",
|
||||
Type: domain.TenantTypeCompany,
|
||||
}
|
||||
|
||||
err := repo.Create(ctx, tenant)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, tenant.ID)
|
||||
|
||||
found, err := repo.FindByID(ctx, tenant.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tenant.Name, found.Name)
|
||||
assert.Equal(t, tenant.Slug, found.Slug)
|
||||
})
|
||||
|
||||
t.Run("FindBySlug", func(t *testing.T) {
|
||||
tenant := &domain.Tenant{
|
||||
Name: "Slug Test",
|
||||
Slug: "slug-test",
|
||||
Type: domain.TenantTypeCompany,
|
||||
}
|
||||
_ = repo.Create(ctx, tenant)
|
||||
|
||||
found, err := repo.FindBySlug(ctx, "slug-test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tenant.ID, found.ID)
|
||||
})
|
||||
|
||||
t.Run("AddDomain and FindByDomain", func(t *testing.T) {
|
||||
tenant := &domain.Tenant{
|
||||
Name: "Domain Test",
|
||||
Slug: "domain-test",
|
||||
Type: domain.TenantTypeCompany,
|
||||
}
|
||||
_ = repo.Create(ctx, tenant)
|
||||
|
||||
err := repo.AddDomain(ctx, tenant.ID, "test-domain.com", true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByDomain(ctx, "test-domain.com")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tenant.ID, found.ID)
|
||||
assert.Len(t, found.Domains, 1)
|
||||
assert.Equal(t, "test-domain.com", found.Domains[0].Domain)
|
||||
})
|
||||
|
||||
t.Run("AddDomain allows same domain on multiple tenants", func(t *testing.T) {
|
||||
first := &domain.Tenant{
|
||||
Name: "Saman Existing",
|
||||
Slug: "saman-existing",
|
||||
Type: domain.TenantTypeCompany,
|
||||
}
|
||||
second := &domain.Tenant{
|
||||
Name: "Saman Current",
|
||||
Slug: "saman-current",
|
||||
Type: domain.TenantTypeCompany,
|
||||
}
|
||||
assert.NoError(t, repo.Create(ctx, first))
|
||||
assert.NoError(t, repo.Create(ctx, second))
|
||||
|
||||
assert.NoError(t, repo.AddDomain(ctx, first.ID, "samaneng.com", true))
|
||||
assert.NoError(t, repo.AddDomain(ctx, second.ID, "samaneng.com", true))
|
||||
|
||||
var rows []domain.TenantDomain
|
||||
err := testDB.Where("domain = ?", "samaneng.com").Find(&rows).Error
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, rows, 2)
|
||||
})
|
||||
|
||||
t.Run("AddDomain restores deleted tenant domain", func(t *testing.T) {
|
||||
tenant := &domain.Tenant{
|
||||
Name: "Domain Restore",
|
||||
Slug: "domain-restore",
|
||||
Type: domain.TenantTypeCompany,
|
||||
}
|
||||
assert.NoError(t, repo.Create(ctx, tenant))
|
||||
assert.NoError(t, repo.AddDomain(ctx, tenant.ID, "restore.samaneng.com", true))
|
||||
assert.NoError(t, testDB.Where("tenant_id = ? AND domain = ?", tenant.ID, "restore.samaneng.com").Delete(&domain.TenantDomain{}).Error)
|
||||
|
||||
assert.NoError(t, repo.AddDomain(ctx, tenant.ID, "restore.samaneng.com", true))
|
||||
|
||||
var rows []domain.TenantDomain
|
||||
err := testDB.Where("tenant_id = ? AND domain = ?", tenant.ID, "restore.samaneng.com").Find(&rows).Error
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, rows, 1) {
|
||||
assert.True(t, rows[0].Verified)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Update", func(t *testing.T) {
|
||||
tenant := &domain.Tenant{
|
||||
Name: "Before Update",
|
||||
Slug: "before-update",
|
||||
Type: domain.TenantTypeCompany,
|
||||
}
|
||||
_ = repo.Create(ctx, tenant)
|
||||
|
||||
tenant.Name = "After Update"
|
||||
err := repo.Update(ctx, tenant)
|
||||
assert.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByID(ctx, tenant.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "After Update", found.Name)
|
||||
})
|
||||
|
||||
t.Run("Hierarchy", func(t *testing.T) {
|
||||
parent := &domain.Tenant{
|
||||
Name: "Parent Tenant",
|
||||
Slug: "parent-hierarchy",
|
||||
Type: domain.TenantTypeCompanyGroup,
|
||||
}
|
||||
err := repo.Create(ctx, parent)
|
||||
assert.NoError(t, err)
|
||||
|
||||
child := &domain.Tenant{
|
||||
Name: "Child Tenant",
|
||||
Slug: "child-hierarchy",
|
||||
Type: domain.TenantTypeCompany,
|
||||
ParentID: &parent.ID,
|
||||
}
|
||||
err = repo.Create(ctx, child)
|
||||
assert.NoError(t, err)
|
||||
|
||||
foundChild, err := repo.FindByID(ctx, child.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, parent.ID, *foundChild.ParentID)
|
||||
})
|
||||
|
||||
t.Run("Unique Constraint on Slug", func(t *testing.T) {
|
||||
slug := "unique-slug-test"
|
||||
tenant1 := &domain.Tenant{
|
||||
Name: "First",
|
||||
Slug: slug,
|
||||
Type: domain.TenantTypeCompany,
|
||||
}
|
||||
err := repo.Create(ctx, tenant1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
tenant2 := &domain.Tenant{
|
||||
Name: "Second",
|
||||
Slug: slug,
|
||||
Type: domain.TenantTypeCompany,
|
||||
}
|
||||
err = repo.Create(ctx, tenant2)
|
||||
assert.Error(t, err) // Should fail due to UNIQUE constraint
|
||||
})
|
||||
|
||||
t.Run("Create reuses slug held by legacy soft-deleted tenant", func(t *testing.T) {
|
||||
slug := "legacy-soft-delete-reuse"
|
||||
require.NoError(t, testDB.Unscoped().Where("slug = ?", slug).Delete(&domain.Tenant{}).Error)
|
||||
|
||||
legacy := &domain.Tenant{
|
||||
Name: "Legacy Deleted",
|
||||
Slug: slug,
|
||||
Type: domain.TenantTypeCompany,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, legacy))
|
||||
require.NoError(t, testDB.Delete(&domain.Tenant{}, "id = ?", legacy.ID).Error)
|
||||
|
||||
_, err := repo.FindBySlug(ctx, slug)
|
||||
require.Error(t, err)
|
||||
|
||||
replacement := &domain.Tenant{
|
||||
Name: "Replacement",
|
||||
Slug: slug,
|
||||
Type: domain.TenantTypeCompany,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, replacement))
|
||||
|
||||
found, err := repo.FindBySlug(ctx, slug)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, replacement.ID, found.ID)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UserGroupRepository interface {
|
||||
Create(ctx context.Context, group *domain.UserGroup) error
|
||||
Update(ctx context.Context, group *domain.UserGroup) error
|
||||
Delete(ctx context.Context, id string) error
|
||||
FindByID(ctx context.Context, id string) (*domain.UserGroup, error)
|
||||
ListByTenantID(ctx context.Context, tenantID string) ([]domain.UserGroup, error)
|
||||
}
|
||||
|
||||
type userGroupRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewUserGroupRepository(db *gorm.DB) UserGroupRepository {
|
||||
return &userGroupRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *userGroupRepository) Create(ctx context.Context, group *domain.UserGroup) error {
|
||||
return r.db.WithContext(ctx).Create(group).Error
|
||||
}
|
||||
|
||||
func (r *userGroupRepository) Update(ctx context.Context, group *domain.UserGroup) error {
|
||||
return r.db.WithContext(ctx).Save(group).Error
|
||||
}
|
||||
|
||||
func (r *userGroupRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&domain.UserGroup{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func (r *userGroupRepository) FindByID(ctx context.Context, id string) (*domain.UserGroup, error) {
|
||||
var group domain.UserGroup
|
||||
// Using Where to be more explicit and avoid issues with GORM's default primary key handling if ID is string/uuid
|
||||
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&group).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
func (r *userGroupRepository) ListByTenantID(ctx context.Context, tenantID string) ([]domain.UserGroup, error) {
|
||||
var groups []domain.UserGroup
|
||||
if err := r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Find(&groups).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func CountOrphanUserTenantMemberships(ctx context.Context, db *gorm.DB) (int64, error) {
|
||||
var count int64
|
||||
err := db.WithContext(ctx).Raw(`
|
||||
SELECT COUNT(*)
|
||||
FROM users AS u
|
||||
WHERE u.deleted_at IS NULL
|
||||
AND (
|
||||
u.tenant_id IS NOT NULL
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM tenants AS t
|
||||
WHERE t.id = u.tenant_id
|
||||
AND t.deleted_at IS NULL
|
||||
)
|
||||
)
|
||||
`).Scan(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func ClearOrphanUserTenantMemberships(ctx context.Context, db *gorm.DB) (int64, error) {
|
||||
result := db.WithContext(ctx).Exec(`
|
||||
WITH orphan_users AS (
|
||||
SELECT u.id
|
||||
FROM users AS u
|
||||
WHERE u.deleted_at IS NULL
|
||||
AND (
|
||||
u.tenant_id IS NOT NULL
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM tenants AS t
|
||||
WHERE t.id = u.tenant_id
|
||||
AND t.deleted_at IS NULL
|
||||
)
|
||||
)
|
||||
)
|
||||
UPDATE users AS u
|
||||
SET tenant_id = NULL,
|
||||
updated_at = NOW()
|
||||
FROM orphan_users AS ou
|
||||
WHERE u.id = ou.id
|
||||
`)
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClearOrphanUserTenantMemberships(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := NewUserRepository(testDB)
|
||||
tenantRepo := NewTenantRepository(testDB)
|
||||
|
||||
require.NoError(t, testDB.Exec("DELETE FROM user_login_ids").Error)
|
||||
require.NoError(t, testDB.Exec("DELETE FROM users").Error)
|
||||
require.NoError(t, testDB.Exec("DELETE FROM tenant_domains").Error)
|
||||
require.NoError(t, testDB.Unscoped().Where("slug IN ?", []string{"orphan-active", "orphan-deleted"}).Delete(&domain.Tenant{}).Error)
|
||||
|
||||
activeTenant := &domain.Tenant{Name: "Active Tenant", Slug: "orphan-active", Type: domain.TenantTypeCompany}
|
||||
deletedTenant := &domain.Tenant{Name: "Deleted Tenant", Slug: "orphan-deleted", Type: domain.TenantTypeCompany}
|
||||
require.NoError(t, tenantRepo.Create(ctx, activeTenant))
|
||||
require.NoError(t, tenantRepo.Create(ctx, deletedTenant))
|
||||
require.NoError(t, testDB.Delete(&domain.Tenant{}, "id = ?", deletedTenant.ID).Error)
|
||||
|
||||
activeUser := &domain.User{
|
||||
Email: "active-membership@example.com",
|
||||
Name: "Active Membership",
|
||||
Role: "user",
|
||||
TenantID: &activeTenant.ID,
|
||||
}
|
||||
orphanUser := &domain.User{
|
||||
Email: "orphan-membership@example.com",
|
||||
Name: "Orphan Membership",
|
||||
Role: "user",
|
||||
TenantID: &deletedTenant.ID,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, activeUser))
|
||||
require.NoError(t, repo.Create(ctx, orphanUser))
|
||||
|
||||
count, err := CountOrphanUserTenantMemberships(ctx, testDB)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), count)
|
||||
|
||||
affected, err := ClearOrphanUserTenantMemberships(ctx, testDB)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), affected)
|
||||
|
||||
foundActive, err := repo.FindByEmail(ctx, activeUser.Email)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, foundActive.TenantID)
|
||||
require.NotNil(t, foundActive.Tenant)
|
||||
assert.Equal(t, activeTenant.ID, *foundActive.TenantID)
|
||||
|
||||
foundOrphan, err := repo.FindByEmail(ctx, orphanUser.Email)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, foundOrphan.TenantID)
|
||||
|
||||
count, err = CountOrphanUserTenantMemberships(ctx, testDB)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count)
|
||||
}
|
||||
@@ -0,0 +1,227 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type UserProjectionRepository interface {
|
||||
IsReady(ctx context.Context) (bool, error)
|
||||
GetStatus(ctx context.Context) (domain.UserProjectionStatus, error)
|
||||
CountTenantMembers(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error)
|
||||
CountTenantMembersRecursive(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error)
|
||||
ReplaceAllFromKratos(ctx context.Context, users []domain.User) error
|
||||
MarkFailed(ctx context.Context, syncErr error) error
|
||||
}
|
||||
|
||||
type userProjectionRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewUserProjectionRepository(db *gorm.DB) UserProjectionRepository {
|
||||
return &userProjectionRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *userProjectionRepository) IsReady(ctx context.Context) (bool, error) {
|
||||
status, err := r.GetStatus(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return status.Ready, nil
|
||||
}
|
||||
|
||||
func (r *userProjectionRepository) GetStatus(ctx context.Context) (domain.UserProjectionStatus, error) {
|
||||
var projectedUsers int64
|
||||
if err := r.db.WithContext(ctx).Model(&domain.User{}).Count(&projectedUsers).Error; err != nil {
|
||||
return domain.UserProjectionStatus{}, err
|
||||
}
|
||||
|
||||
var state domain.UserProjectionState
|
||||
err := r.db.WithContext(ctx).First(&state, "name = ?", domain.UserProjectionNameKratos).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return domain.UserProjectionStatus{
|
||||
Name: domain.UserProjectionNameKratos,
|
||||
Status: domain.UserProjectionStatusFailed,
|
||||
Ready: false,
|
||||
ProjectedUsers: projectedUsers,
|
||||
}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return domain.UserProjectionStatus{}, err
|
||||
}
|
||||
return domain.UserProjectionStatus{
|
||||
Name: state.Name,
|
||||
Status: state.Status,
|
||||
Ready: state.Status == domain.UserProjectionStatusReady && state.LastSyncedAt != nil,
|
||||
LastSyncedAt: state.LastSyncedAt,
|
||||
LastError: state.LastError,
|
||||
UpdatedAt: &state.UpdatedAt,
|
||||
ProjectedUsers: projectedUsers,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *userProjectionRepository) CountTenantMembers(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
|
||||
counts := make(map[string]int64, len(tenants))
|
||||
for _, tenant := range tenants {
|
||||
counts[tenant.ID] = 0
|
||||
}
|
||||
if len(tenants) == 0 {
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
valuePlaceholders := make([]string, 0, len(tenants))
|
||||
args := make([]any, 0, len(tenants)*2)
|
||||
for _, tenant := range tenants {
|
||||
valuePlaceholders = append(valuePlaceholders, "(?, ?)")
|
||||
args = append(args, strings.TrimSpace(tenant.ID), strings.TrimSpace(tenant.Slug))
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
WITH requested(tenant_id, slug) AS (
|
||||
VALUES %s
|
||||
)
|
||||
SELECT requested.tenant_id, COUNT(DISTINCT users.id) AS count
|
||||
FROM requested
|
||||
LEFT JOIN users ON users.deleted_at IS NULL AND (
|
||||
users.tenant_id::text = requested.tenant_id
|
||||
)
|
||||
GROUP BY requested.tenant_id
|
||||
`, strings.Join(valuePlaceholders, ","))
|
||||
|
||||
type result struct {
|
||||
TenantID string
|
||||
Count int64
|
||||
}
|
||||
var rows []result
|
||||
if err := r.db.WithContext(ctx).Raw(query, args...).Scan(&rows).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, row := range rows {
|
||||
counts[row.TenantID] = row.Count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (r *userProjectionRepository) CountTenantMembersRecursive(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
|
||||
counts := make(map[string]int64, len(tenants))
|
||||
for _, tenant := range tenants {
|
||||
counts[tenant.ID] = 0
|
||||
}
|
||||
if len(tenants) == 0 {
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
valuePlaceholders := make([]string, 0, len(tenants))
|
||||
args := make([]any, 0, len(tenants))
|
||||
for _, tenant := range tenants {
|
||||
valuePlaceholders = append(valuePlaceholders, "(?)")
|
||||
args = append(args, strings.TrimSpace(tenant.ID))
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
WITH RECURSIVE requested(tenant_id) AS (
|
||||
VALUES %s
|
||||
),
|
||||
descendants(root_tenant_id, tenant_id) AS (
|
||||
SELECT requested.tenant_id, requested.tenant_id
|
||||
FROM requested
|
||||
UNION ALL
|
||||
SELECT descendants.root_tenant_id, child.id::text
|
||||
FROM descendants
|
||||
JOIN tenants child
|
||||
ON child.parent_id::text = descendants.tenant_id
|
||||
AND child.deleted_at IS NULL
|
||||
)
|
||||
SELECT requested.tenant_id, COUNT(DISTINCT users.id) AS count
|
||||
FROM requested
|
||||
LEFT JOIN descendants
|
||||
ON descendants.root_tenant_id = requested.tenant_id
|
||||
LEFT JOIN users
|
||||
ON users.deleted_at IS NULL
|
||||
AND users.tenant_id::text = descendants.tenant_id
|
||||
GROUP BY requested.tenant_id
|
||||
`, strings.Join(valuePlaceholders, ","))
|
||||
|
||||
type result struct {
|
||||
TenantID string
|
||||
Count int64
|
||||
}
|
||||
var rows []result
|
||||
if err := r.db.WithContext(ctx).Raw(query, args...).Scan(&rows).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, row := range rows {
|
||||
counts[row.TenantID] = row.Count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (r *userProjectionRepository) ReplaceAllFromKratos(ctx context.Context, users []domain.User) error {
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for i := range users {
|
||||
users[i].DeletedAt = gorm.DeletedAt{}
|
||||
if users[i].CreatedAt.IsZero() {
|
||||
users[i].CreatedAt = now
|
||||
}
|
||||
if users[i].UpdatedAt.IsZero() {
|
||||
users[i].UpdatedAt = now
|
||||
}
|
||||
}
|
||||
|
||||
if len(users) > 0 {
|
||||
// [FIX] Handle email conflicts before bulk upsert
|
||||
for _, u := range users {
|
||||
if u.Email != "" {
|
||||
// Hard-delete any record with same email but different ID to clear unique constraint
|
||||
_ = tx.Unscoped().Where("email = ? AND id != ?", u.Email, u.ID).Delete(&domain.User{}).Error
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "id"}},
|
||||
UpdateAll: true,
|
||||
}).Create(&users).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return upsertUserProjectionState(tx, domain.UserProjectionStatusReady, &now, "")
|
||||
})
|
||||
}
|
||||
|
||||
func (r *userProjectionRepository) MarkFailed(ctx context.Context, syncErr error) error {
|
||||
message := ""
|
||||
if syncErr != nil {
|
||||
message = syncErr.Error()
|
||||
}
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
return upsertUserProjectionState(tx, domain.UserProjectionStatusFailed, nil, message)
|
||||
})
|
||||
}
|
||||
|
||||
func upsertUserProjectionState(tx *gorm.DB, status string, syncedAt *time.Time, lastError string) error {
|
||||
state := domain.UserProjectionState{
|
||||
Name: domain.UserProjectionNameKratos,
|
||||
Status: status,
|
||||
LastSyncedAt: syncedAt,
|
||||
LastError: lastError,
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
return tx.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "name"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{
|
||||
"status",
|
||||
"last_synced_at",
|
||||
"last_error",
|
||||
"updated_at",
|
||||
}),
|
||||
}).Create(&state).Error
|
||||
}
|
||||
@@ -0,0 +1,168 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUserProjectionRepository_ReplaceAllFromKratosMarksReadyWithoutDeletingUsersMissingFromPartialList(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := NewUserProjectionRepository(testDB)
|
||||
|
||||
require.NoError(t, testDB.Exec("DELETE FROM user_projection_states").Error)
|
||||
require.NoError(t, testDB.Exec("DELETE FROM user_login_ids").Error)
|
||||
require.NoError(t, testDB.Exec("DELETE FROM users").Error)
|
||||
|
||||
tenantID := "10000000-0000-0000-0000-000000000001"
|
||||
tenantSlug := "projection-saman"
|
||||
require.NoError(t, testDB.Create(&domain.Tenant{
|
||||
ID: tenantID,
|
||||
Name: "Projection Saman",
|
||||
Slug: tenantSlug,
|
||||
Type: domain.TenantTypeCompany,
|
||||
Status: domain.TenantStatusActive,
|
||||
}).Error)
|
||||
existing := &domain.User{
|
||||
ID: "00000000-0000-0000-0000-000000000099",
|
||||
Email: "existing@example.com",
|
||||
Name: "Existing",
|
||||
CompanyCode: tenantSlug,
|
||||
TenantID: &tenantID,
|
||||
}
|
||||
require.NoError(t, NewUserRepository(testDB).Create(ctx, existing))
|
||||
|
||||
users := []domain.User{
|
||||
{
|
||||
ID: "00000000-0000-0000-0000-000000000101",
|
||||
Email: "one@example.com",
|
||||
Name: "One",
|
||||
CompanyCode: tenantSlug,
|
||||
TenantID: &tenantID,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
},
|
||||
{
|
||||
ID: "00000000-0000-0000-0000-000000000102",
|
||||
Email: "two@example.com",
|
||||
Name: "Two",
|
||||
TenantID: &tenantID,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
},
|
||||
}
|
||||
|
||||
require.NoError(t, repo.ReplaceAllFromKratos(ctx, users))
|
||||
|
||||
ready, err := repo.IsReady(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
counts, err := repo.CountTenantMembers(ctx, []domain.Tenant{
|
||||
{ID: tenantID, Slug: tenantSlug},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(3), counts[tenantID])
|
||||
|
||||
var activeCount int64
|
||||
require.NoError(t, testDB.Model(&domain.User{}).Count(&activeCount).Error)
|
||||
assert.Equal(t, int64(3), activeCount)
|
||||
|
||||
var existingCount int64
|
||||
require.NoError(t, testDB.Model(&domain.User{}).Where("id = ?", existing.ID).Count(&existingCount).Error)
|
||||
assert.Equal(t, int64(1), existingCount)
|
||||
|
||||
var existingRow domain.User
|
||||
require.NoError(t, testDB.Unscoped().First(&existingRow, "id = ?", existing.ID).Error)
|
||||
assert.False(t, existingRow.DeletedAt.Valid)
|
||||
}
|
||||
|
||||
func TestUserProjectionRepository_CountTenantMembersRecursiveIncludesDescendantsAndExcludesSoftDeletedUsers(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := NewUserProjectionRepository(testDB)
|
||||
|
||||
parentID := "20000000-0000-0000-0000-000000000001"
|
||||
childID := "20000000-0000-0000-0000-000000000002"
|
||||
grandchildID := "20000000-0000-0000-0000-000000000003"
|
||||
siblingID := "20000000-0000-0000-0000-000000000004"
|
||||
tenantIDs := []string{parentID, childID, grandchildID, siblingID}
|
||||
|
||||
require.NoError(t, testDB.Exec("DELETE FROM user_login_ids").Error)
|
||||
require.NoError(t, testDB.Exec("DELETE FROM users").Error)
|
||||
require.NoError(t, testDB.Unscoped().Where("id IN ?", tenantIDs).Delete(&domain.Tenant{}).Error)
|
||||
|
||||
require.NoError(t, testDB.Create(&domain.Tenant{
|
||||
ID: parentID,
|
||||
Name: "Recursive Parent",
|
||||
Slug: "recursive-parent",
|
||||
Type: domain.TenantTypeCompany,
|
||||
Status: domain.TenantStatusActive,
|
||||
}).Error)
|
||||
require.NoError(t, testDB.Create(&domain.Tenant{
|
||||
ID: childID,
|
||||
Name: "Recursive Child",
|
||||
Slug: "recursive-child",
|
||||
Type: domain.TenantTypeOrganization,
|
||||
Status: domain.TenantStatusActive,
|
||||
ParentID: &parentID,
|
||||
}).Error)
|
||||
require.NoError(t, testDB.Create(&domain.Tenant{
|
||||
ID: grandchildID,
|
||||
Name: "Recursive Grandchild",
|
||||
Slug: "recursive-grandchild",
|
||||
Type: domain.TenantTypeUserGroup,
|
||||
Status: domain.TenantStatusActive,
|
||||
ParentID: &childID,
|
||||
}).Error)
|
||||
require.NoError(t, testDB.Create(&domain.Tenant{
|
||||
ID: siblingID,
|
||||
Name: "Recursive Sibling",
|
||||
Slug: "recursive-sibling",
|
||||
Type: domain.TenantTypeCompany,
|
||||
Status: domain.TenantStatusActive,
|
||||
}).Error)
|
||||
|
||||
users := []domain.User{
|
||||
{ID: "21000000-0000-0000-0000-000000000001", Email: "parent@example.com", Name: "Parent", TenantID: &parentID},
|
||||
{ID: "21000000-0000-0000-0000-000000000002", Email: "child@example.com", Name: "Child", TenantID: &childID},
|
||||
{ID: "21000000-0000-0000-0000-000000000003", Email: "grandchild@example.com", Name: "Grandchild", TenantID: &grandchildID},
|
||||
{ID: "21000000-0000-0000-0000-000000000004", Email: "deleted-grandchild@example.com", Name: "Deleted Grandchild", TenantID: &grandchildID},
|
||||
{ID: "21000000-0000-0000-0000-000000000005", Email: "sibling@example.com", Name: "Sibling", TenantID: &siblingID},
|
||||
}
|
||||
for i := range users {
|
||||
require.NoError(t, testDB.Create(&users[i]).Error)
|
||||
}
|
||||
require.NoError(t, testDB.Delete(&domain.User{}, "id = ?", users[3].ID).Error)
|
||||
|
||||
directCounts, err := repo.CountTenantMembers(ctx, []domain.Tenant{{ID: parentID}, {ID: childID}, {ID: grandchildID}, {ID: siblingID}})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), directCounts[parentID])
|
||||
assert.Equal(t, int64(1), directCounts[childID])
|
||||
assert.Equal(t, int64(1), directCounts[grandchildID])
|
||||
assert.Equal(t, int64(1), directCounts[siblingID])
|
||||
|
||||
recursiveCounts, err := repo.CountTenantMembersRecursive(ctx, []domain.Tenant{{ID: parentID}, {ID: childID}, {ID: grandchildID}, {ID: siblingID}})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(3), recursiveCounts[parentID])
|
||||
assert.Equal(t, int64(2), recursiveCounts[childID])
|
||||
assert.Equal(t, int64(1), recursiveCounts[grandchildID])
|
||||
assert.Equal(t, int64(1), recursiveCounts[siblingID])
|
||||
}
|
||||
|
||||
func TestUserProjectionRepository_MarkFailedMakesProjectionNotReady(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := NewUserProjectionRepository(testDB)
|
||||
|
||||
require.NoError(t, testDB.Exec("DELETE FROM user_projection_states").Error)
|
||||
|
||||
require.NoError(t, repo.MarkFailed(ctx, errors.New("kratos down")))
|
||||
|
||||
ready, err := repo.IsReady(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
}
|
||||
335
baron-sso/backend/internal/repository/user_repository.go
Normal file
335
baron-sso/backend/internal/repository/user_repository.go
Normal file
@@ -0,0 +1,335 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/pagination"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type UserRepository interface {
|
||||
Create(ctx context.Context, user *domain.User) error
|
||||
Update(ctx context.Context, user *domain.User) error
|
||||
FindByEmail(ctx context.Context, email string) (*domain.User, error)
|
||||
FindByID(ctx context.Context, id string) (*domain.User, error)
|
||||
FindByIDs(ctx context.Context, ids []string) ([]domain.User, error)
|
||||
ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error)
|
||||
List(ctx context.Context, offset, limit int, search string, tenantIDs []string, cursor string) ([]domain.User, int64, string, error)
|
||||
CountByTenant(ctx context.Context, tenantID string) (int64, error)
|
||||
CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error)
|
||||
CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error)
|
||||
FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error)
|
||||
FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error)
|
||||
Delete(ctx context.Context, id string) error
|
||||
DB() *gorm.DB
|
||||
|
||||
// Multiple identifiers support
|
||||
UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error
|
||||
GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error)
|
||||
IsLoginIDTaken(ctx context.Context, loginID string) (bool, error)
|
||||
FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error)
|
||||
}
|
||||
|
||||
type userRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewUserRepository(db *gorm.DB) UserRepository {
|
||||
return &userRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *userRepository) DB() *gorm.DB {
|
||||
return r.db
|
||||
}
|
||||
|
||||
func (r *userRepository) withTenantMembershipFilter(db *gorm.DB, tenantIDs []string) *gorm.DB {
|
||||
if len(tenantIDs) == 0 {
|
||||
return db
|
||||
}
|
||||
clauses := []string{"tenant_id IN ?"}
|
||||
args := []any{tenantIDs}
|
||||
for _, tenantID := range tenantIDs {
|
||||
tenantID = strings.TrimSpace(tenantID)
|
||||
if tenantID == "" {
|
||||
continue
|
||||
}
|
||||
payload, err := json.Marshal(map[string]any{
|
||||
"additionalAppointments": []map[string]string{
|
||||
{"tenantId": tenantID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
clauses = append(clauses, "metadata @> ?::jsonb")
|
||||
args = append(args, string(payload))
|
||||
}
|
||||
return db.Where("("+strings.Join(clauses, " OR ")+")", args...)
|
||||
}
|
||||
|
||||
func (r *userRepository) Create(ctx context.Context, user *domain.User) error {
|
||||
return r.db.WithContext(ctx).Create(user).Error
|
||||
}
|
||||
|
||||
func (r *userRepository) Update(ctx context.Context, user *domain.User) error {
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// 1. Check for email collision (including soft-deleted)
|
||||
var existing domain.User
|
||||
if err := tx.Unscoped().Where("email = ?", user.Email).First(&existing).Error; err == nil {
|
||||
// If email exists but ID is different, we MUST clear the old one to avoid unique constraint violation
|
||||
if existing.ID != user.ID {
|
||||
// [Restored] Check if the existing user is archived
|
||||
if strings.EqualFold(strings.TrimSpace(existing.Status), domain.UserStatusArchived) {
|
||||
return fmt.Errorf("email is reserved by archived user: %s", user.Email)
|
||||
}
|
||||
|
||||
// HARD DELETE the old record and its associated login IDs to free up the email and identifiers
|
||||
if err := tx.Unscoped().Where("user_id = ?", existing.ID).Delete(&domain.UserLoginID{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Unscoped().Delete(&domain.User{}, "id = ?", existing.ID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Perform Upsert on the new/target ID
|
||||
return tx.Unscoped().Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "id"}},
|
||||
DoUpdates: clause.Assignments(map[string]any{
|
||||
"email": user.Email,
|
||||
"name": user.Name,
|
||||
"phone": user.Phone,
|
||||
"role": user.Role,
|
||||
"status": user.Status,
|
||||
"department": user.Department,
|
||||
"grade": user.Grade,
|
||||
"position": user.Position,
|
||||
"job_title": user.JobTitle,
|
||||
"metadata": user.Metadata,
|
||||
"tenant_id": user.TenantID,
|
||||
"affiliation_type": user.AffiliationType,
|
||||
"updated_at": user.UpdatedAt,
|
||||
"deleted_at": nil, // Ensure it's active
|
||||
}),
|
||||
}).Create(user).Error
|
||||
})
|
||||
}
|
||||
|
||||
func (r *userRepository) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
|
||||
var user domain.User
|
||||
if err := r.db.WithContext(ctx).Preload("Tenant").Where("email = ?", email).First(&user).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) FindByID(ctx context.Context, id string) (*domain.User, error) {
|
||||
var user domain.User
|
||||
if err := r.db.WithContext(ctx).Preload("Tenant").Where("id = ?", id).First(&user).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
|
||||
var users []domain.User
|
||||
if len(ids) == 0 {
|
||||
return users, nil
|
||||
}
|
||||
if err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
|
||||
var users []domain.User
|
||||
if err := r.withTenantMembershipFilter(r.db.WithContext(ctx), []string{tenantID}).Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
|
||||
var count int64
|
||||
err := r.withTenantMembershipFilter(r.db.WithContext(ctx).Model(&domain.User{}), []string{tenantID}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *userRepository) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
|
||||
counts := make(map[string]int64)
|
||||
|
||||
if len(tenantIDs) == 0 {
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
for _, tenantID := range tenantIDs {
|
||||
var count int64
|
||||
if err := r.withTenantMembershipFilter(r.db.WithContext(ctx).Model(&domain.User{}), []string{tenantID}).Count(&count).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts[tenantID] = count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
|
||||
if len(codes) == 0 {
|
||||
return make(map[string]int64), nil
|
||||
}
|
||||
|
||||
type result struct {
|
||||
TenantSlug string
|
||||
Count int64
|
||||
}
|
||||
var results []result
|
||||
|
||||
lowerCodes := lowerStrings(codes)
|
||||
|
||||
if err := r.db.WithContext(ctx).Table("users").
|
||||
Select("LOWER(tenants.slug) AS tenant_slug, count(DISTINCT users.id) AS count").
|
||||
Joins("JOIN tenants ON users.tenant_id = tenants.id").
|
||||
Where("users.deleted_at IS NULL AND LOWER(tenants.slug) IN ?", lowerCodes).
|
||||
Group("LOWER(tenants.slug)").
|
||||
Scan(&results).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
counts := make(map[string]int64)
|
||||
for _, res := range results {
|
||||
counts[strings.ToLower(res.TenantSlug)] = res.Count
|
||||
}
|
||||
|
||||
// Ensure all requested codes are present in results (even if count is 0)
|
||||
for _, code := range codes {
|
||||
lower := strings.ToLower(strings.TrimSpace(code))
|
||||
if _, ok := counts[lower]; !ok {
|
||||
counts[lower] = 0
|
||||
}
|
||||
}
|
||||
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func lowerStrings(arr []string) []string {
|
||||
res := make([]string, len(arr))
|
||||
for i, s := range arr {
|
||||
res[i] = strings.ToLower(strings.TrimSpace(s))
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (r *userRepository) List(ctx context.Context, offset, limit int, search string, tenantIDs []string, cursorRaw string) ([]domain.User, int64, string, error) {
|
||||
var users []domain.User
|
||||
var total int64
|
||||
db := r.db.WithContext(ctx).Model(&domain.User{})
|
||||
|
||||
if len(tenantIDs) > 0 {
|
||||
db = r.withTenantMembershipFilter(db, tenantIDs)
|
||||
}
|
||||
|
||||
if search != "" {
|
||||
searchTerm := "%" + search + "%"
|
||||
db = db.Where("(users.email LIKE ? OR users.name LIKE ? OR users.metadata::text LIKE ?)",
|
||||
searchTerm, searchTerm, searchTerm)
|
||||
}
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
|
||||
if cursorRaw != "" {
|
||||
cursor, err := pagination.Decode(cursorRaw)
|
||||
if err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
db = pagination.ApplyCreatedAtIDCursor(db, cursor, "created_at", "id")
|
||||
} else {
|
||||
db = db.Offset(offset)
|
||||
}
|
||||
|
||||
if err := db.Order("created_at desc, id desc").Limit(limit + 1).Preload("Tenant").Find(&users).Error; err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
|
||||
var items []domain.User
|
||||
var nextCursor string
|
||||
if len(users) > limit {
|
||||
items = users[:limit]
|
||||
last := items[limit-1]
|
||||
nextCursor = pagination.Encode(last.CreatedAt, last.ID)
|
||||
} else {
|
||||
items = users
|
||||
}
|
||||
|
||||
return items, total, nextCursor, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&domain.User{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// [FIX] Use Unscoped to permanently delete existing login IDs for this user
|
||||
// This prevents unique constraint violations with soft-deleted records
|
||||
if err := tx.Unscoped().Where("user_id = ?", userID).Delete(&domain.UserLoginID{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Insert new login IDs if any
|
||||
if len(loginIDs) > 0 {
|
||||
if err := tx.Create(&loginIDs).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *userRepository) GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error) {
|
||||
var results []domain.UserLoginID
|
||||
if err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&results).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) IsLoginIDTaken(ctx context.Context, loginID string) (bool, error) {
|
||||
var count int64
|
||||
if err := r.db.WithContext(ctx).Model(&domain.UserLoginID{}).Where("login_id = ?", loginID).Count(&count).Error; err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error) {
|
||||
var record domain.UserLoginID
|
||||
if err := r.db.WithContext(ctx).Where("login_id = ?", loginID).First(&record).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
return record.TenantID, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
|
||||
var users []domain.User
|
||||
err := r.withTenantMembershipFilter(r.db.WithContext(ctx), tenantIDs).Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
func (r *userRepository) FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error) {
|
||||
var users []domain.User
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("JOIN tenants ON users.tenant_id = tenants.id").
|
||||
Where("LOWER(tenants.slug) IN ?", lowerStrings(codes)).
|
||||
Preload("Tenant").
|
||||
Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
273
baron-sso/backend/internal/repository/user_repository_test.go
Normal file
273
baron-sso/backend/internal/repository/user_repository_test.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUserRepository(t *testing.T) {
|
||||
repo := NewUserRepository(testDB)
|
||||
ctx := context.Background()
|
||||
|
||||
// Ensure User table exists and clean for tests
|
||||
_ = testDB.AutoMigrate(&domain.User{})
|
||||
|
||||
t.Run("Create and FindByEmail", func(t *testing.T) {
|
||||
user := &domain.User{
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
Role: "user",
|
||||
}
|
||||
|
||||
err := repo.Create(ctx, user)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, user.ID)
|
||||
|
||||
found, err := repo.FindByEmail(ctx, "test@example.com")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, user.ID, found.ID)
|
||||
assert.Equal(t, "Test User", found.Name)
|
||||
})
|
||||
|
||||
t.Run("Update User Info", func(t *testing.T) {
|
||||
user := &domain.User{
|
||||
Email: "update@example.com",
|
||||
Name: "Before Update",
|
||||
Role: "user",
|
||||
}
|
||||
_ = repo.Create(ctx, user)
|
||||
|
||||
user.Name = "After Update"
|
||||
user.Phone = "010-1234-5678"
|
||||
err := repo.Update(ctx, user)
|
||||
assert.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByEmail(ctx, "update@example.com")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "After Update", found.Name)
|
||||
assert.Equal(t, "010-1234-5678", found.Phone)
|
||||
})
|
||||
|
||||
t.Run("Update preserves archived email reservation", func(t *testing.T) {
|
||||
testDB.Exec("DELETE FROM user_login_ids")
|
||||
testDB.Exec("DELETE FROM users")
|
||||
|
||||
archived := &domain.User{
|
||||
ID: "00000000-0000-0000-0000-00000000a001",
|
||||
Email: "reserved@example.com",
|
||||
Name: "Archived User",
|
||||
Role: domain.RoleUser,
|
||||
Status: domain.UserStatusArchived,
|
||||
}
|
||||
replacement := &domain.User{
|
||||
ID: "00000000-0000-0000-0000-00000000a002",
|
||||
Email: "reserved@example.com",
|
||||
Name: "Replacement User",
|
||||
Role: domain.RoleUser,
|
||||
Status: domain.UserStatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, archived))
|
||||
|
||||
err := repo.Update(ctx, replacement)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "archived user")
|
||||
found, err := repo.FindByEmail(ctx, archived.Email)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, archived.ID, found.ID)
|
||||
require.Equal(t, domain.UserStatusArchived, found.Status)
|
||||
})
|
||||
|
||||
t.Run("List Users with Search", func(t *testing.T) {
|
||||
// Add some users
|
||||
_ = repo.Create(ctx, &domain.User{Email: "alice@test.com", Name: "Alice", Role: "user"})
|
||||
_ = repo.Create(ctx, &domain.User{Email: "bob@test.com", Name: "Bob", Role: "user"})
|
||||
|
||||
users, total, _, err := repo.List(ctx, 0, 10, "Alice", []string{}, "")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, total >= 1)
|
||||
assert.Equal(t, "Alice", users[0].Name)
|
||||
})
|
||||
|
||||
t.Run("Delete User", func(t *testing.T) {
|
||||
user := &domain.User{Email: "delete@example.com", Name: "To Delete"}
|
||||
_ = repo.Create(ctx, user)
|
||||
|
||||
err := repo.Delete(ctx, user.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByEmail(ctx, "delete@example.com")
|
||||
assert.Error(t, err) // Should not be found
|
||||
assert.Nil(t, found)
|
||||
})
|
||||
|
||||
t.Run("CountByCompanyCodes", func(t *testing.T) {
|
||||
// Clean start for this subtest
|
||||
testDB.Exec("DELETE FROM user_login_ids")
|
||||
testDB.Exec("DELETE FROM users")
|
||||
testDB.Exec("DELETE FROM tenant_domains")
|
||||
tenantA := createUserRepositoryTestTenant(t, "tenant-a")
|
||||
tenantB := createUserRepositoryTestTenant(t, "tenant-b")
|
||||
|
||||
users := []domain.User{
|
||||
{Email: "u1@a.com", Name: "U1", TenantID: &tenantA.ID},
|
||||
{Email: "u2@a.com", Name: "U2", TenantID: &tenantA.ID},
|
||||
{Email: "u3@b.com", Name: "U3", TenantID: &tenantB.ID},
|
||||
{Email: "u4@none.com", Name: "U4"},
|
||||
}
|
||||
for _, u := range users {
|
||||
_ = repo.Create(ctx, &u)
|
||||
}
|
||||
|
||||
counts, err := repo.CountByCompanyCodes(ctx, []string{"tenant-a", "tenant-b", "tenant-c"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(2), counts["tenant-a"])
|
||||
assert.Equal(t, int64(1), counts["tenant-b"])
|
||||
assert.Equal(t, int64(0), counts["tenant-c"])
|
||||
})
|
||||
|
||||
t.Run("CountByCompanyCodes excludes soft deleted cache rows", func(t *testing.T) {
|
||||
testDB.Exec("DELETE FROM user_login_ids")
|
||||
testDB.Exec("DELETE FROM users")
|
||||
testDB.Exec("DELETE FROM tenant_domains")
|
||||
tenantA := createUserRepositoryTestTenant(t, "tenant-a")
|
||||
|
||||
active := &domain.User{Email: "active@a.com", Name: "Active", TenantID: &tenantA.ID}
|
||||
deleted := &domain.User{Email: "deleted@a.com", Name: "Deleted", TenantID: &tenantA.ID}
|
||||
secondDeleted := &domain.User{Email: "second-deleted@a.com", Name: "Second Deleted", TenantID: &tenantA.ID}
|
||||
|
||||
assert.NoError(t, repo.Create(ctx, active))
|
||||
assert.NoError(t, repo.Create(ctx, deleted))
|
||||
assert.NoError(t, repo.Create(ctx, secondDeleted))
|
||||
assert.NoError(t, repo.Delete(ctx, deleted.ID))
|
||||
assert.NoError(t, repo.Delete(ctx, secondDeleted.ID))
|
||||
|
||||
counts, err := repo.CountByCompanyCodes(ctx, []string{"tenant-a"})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), counts["tenant-a"])
|
||||
})
|
||||
|
||||
t.Run("Multi-Identifier Support", func(t *testing.T) {
|
||||
_ = testDB.AutoMigrate(&domain.UserLoginID{})
|
||||
testDB.Exec("DELETE FROM user_login_ids")
|
||||
testDB.Exec("DELETE FROM users")
|
||||
|
||||
user := &domain.User{Email: "multi@test.com", Name: "Multi"}
|
||||
_ = repo.Create(ctx, user)
|
||||
|
||||
t1 := "00000000-0000-0000-0000-000000000001"
|
||||
t2 := "00000000-0000-0000-0000-000000000002"
|
||||
|
||||
loginIDs := []domain.UserLoginID{
|
||||
{UserID: user.ID, TenantID: t1, FieldKey: "emp_id", LoginID: "E001"},
|
||||
{UserID: user.ID, TenantID: t2, FieldKey: "student_id", LoginID: "S001"},
|
||||
}
|
||||
|
||||
err := repo.UpdateUserLoginIDs(ctx, user.ID, loginIDs)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Get and Verify
|
||||
saved, err := repo.GetUserLoginIDs(ctx, user.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, saved, 2)
|
||||
|
||||
// IsLoginIDTaken
|
||||
taken, err := repo.IsLoginIDTaken(ctx, "E001")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, taken)
|
||||
|
||||
taken, err = repo.IsLoginIDTaken(ctx, "UNKNOWN")
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, taken)
|
||||
|
||||
// FindTenantIDByLoginID
|
||||
tid, err := repo.FindTenantIDByLoginID(ctx, "S001")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, t2, tid)
|
||||
|
||||
// Update (Replace)
|
||||
newList := []domain.UserLoginID{
|
||||
{UserID: user.ID, TenantID: t1, FieldKey: "emp_id", LoginID: "E002"},
|
||||
}
|
||||
err = repo.UpdateUserLoginIDs(ctx, user.ID, newList)
|
||||
assert.NoError(t, err)
|
||||
|
||||
saved, _ = repo.GetUserLoginIDs(ctx, user.ID)
|
||||
assert.Len(t, saved, 1)
|
||||
assert.Equal(t, "E002", saved[0].LoginID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserRepository_ListIncludesAdditionalTenantAppointments(t *testing.T) {
|
||||
repo := NewUserRepository(testDB)
|
||||
ctx := context.Background()
|
||||
require.NoError(t, testDB.Exec("DELETE FROM user_login_ids").Error)
|
||||
require.NoError(t, testDB.Exec("DELETE FROM users").Error)
|
||||
|
||||
primaryTenant := createUserRepositoryTestTenant(t, "repo-primary-tenant")
|
||||
additionalTenant := createUserRepositoryTestTenant(t, "repo-additional-tenant")
|
||||
primaryTenantID := primaryTenant.ID
|
||||
additionalTenantID := additionalTenant.ID
|
||||
users := []domain.User{
|
||||
{
|
||||
ID: uuid.NewString(),
|
||||
Email: "primary-member@example.com",
|
||||
Name: "Primary Member",
|
||||
Role: domain.RoleUser,
|
||||
TenantID: &additionalTenantID,
|
||||
},
|
||||
{
|
||||
ID: uuid.NewString(),
|
||||
Email: "additional-member@example.com",
|
||||
Name: "Additional Member",
|
||||
Role: domain.RoleUser,
|
||||
TenantID: &primaryTenantID,
|
||||
Metadata: domain.JSONMap{
|
||||
"additionalAppointments": []any{
|
||||
map[string]any{
|
||||
"tenantId": additionalTenant.ID,
|
||||
"tenantSlug": additionalTenant.Slug,
|
||||
"tenantName": additionalTenant.Name,
|
||||
"isPrimary": false,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for i := range users {
|
||||
require.NoError(t, repo.Create(ctx, &users[i]))
|
||||
}
|
||||
|
||||
listed, total, _, err := repo.List(ctx, 0, 20, "", []string{additionalTenant.ID}, "")
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), total)
|
||||
require.Len(t, listed, 2)
|
||||
emails := []string{listed[0].Email, listed[1].Email}
|
||||
assert.Contains(t, emails, "primary-member@example.com")
|
||||
assert.Contains(t, emails, "additional-member@example.com")
|
||||
|
||||
counts, err := repo.CountByTenantIDs(ctx, []string{additionalTenant.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(2), counts[additionalTenant.ID])
|
||||
}
|
||||
|
||||
func createUserRepositoryTestTenant(t *testing.T, slug string) domain.Tenant {
|
||||
t.Helper()
|
||||
require.NoError(t, testDB.Unscoped().Where("slug = ?", slug).Delete(&domain.Tenant{}).Error)
|
||||
tenant := domain.Tenant{
|
||||
ID: uuid.NewString(),
|
||||
Name: "Tenant " + slug,
|
||||
Slug: slug,
|
||||
Type: domain.TenantTypeCompany,
|
||||
Status: domain.TenantStatusActive,
|
||||
}
|
||||
require.NoError(t, testDB.Create(&tenant).Error)
|
||||
return tenant
|
||||
}
|
||||
@@ -0,0 +1,208 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type WorksmobileOutboxRepository interface {
|
||||
Create(ctx context.Context, item *domain.WorksmobileOutbox) error
|
||||
ListRecent(ctx context.Context, limit int) ([]domain.WorksmobileOutbox, error)
|
||||
ListCredentialBatchJobs(ctx context.Context, tenantRootID, credentialBatchID string) ([]domain.WorksmobileOutbox, error)
|
||||
UpdatePayload(ctx context.Context, id string, payload domain.JSONMap) error
|
||||
DeletePendingByTenantRoot(ctx context.Context, tenantRootID string) (int64, error)
|
||||
ListReady(ctx context.Context, limit int) ([]domain.WorksmobileOutbox, error)
|
||||
FindByID(ctx context.Context, id string) (*domain.WorksmobileOutbox, error)
|
||||
MarkRetry(ctx context.Context, id string) error
|
||||
MarkProcessing(ctx context.Context, id string) (bool, error)
|
||||
MarkProcessed(ctx context.Context, id string) error
|
||||
MarkFailed(ctx context.Context, id string, message string, nextAttemptAt time.Time) error
|
||||
}
|
||||
|
||||
type worksmobileOutboxRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewWorksmobileOutboxRepository(db *gorm.DB) WorksmobileOutboxRepository {
|
||||
return &worksmobileOutboxRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *worksmobileOutboxRepository) Create(ctx context.Context, item *domain.WorksmobileOutbox) error {
|
||||
if item.Payload == nil {
|
||||
item.Payload = domain.JSONMap{}
|
||||
}
|
||||
if item.Status == "" {
|
||||
item.Status = domain.WorksmobileOutboxStatusPending
|
||||
}
|
||||
return r.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "dedupe_key"}},
|
||||
DoUpdates: clause.Assignments(map[string]any{
|
||||
"payload": item.Payload,
|
||||
"status": domain.WorksmobileOutboxStatusPending,
|
||||
"last_error": "",
|
||||
"next_attempt_at": nil,
|
||||
"updated_at": time.Now(),
|
||||
}),
|
||||
}).Create(item).Error
|
||||
}
|
||||
|
||||
func (r *worksmobileOutboxRepository) ListRecent(ctx context.Context, limit int) ([]domain.WorksmobileOutbox, error) {
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 50
|
||||
}
|
||||
var rows []domain.WorksmobileOutbox
|
||||
err := r.db.WithContext(ctx).Order("created_at desc").Limit(limit).Find(&rows).Error
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (r *worksmobileOutboxRepository) ListCredentialBatchJobs(ctx context.Context, tenantRootID, credentialBatchID string) ([]domain.WorksmobileOutbox, error) {
|
||||
query := r.db.WithContext(ctx).
|
||||
Where("resource_type = ? AND payload ->> 'tenantRootId' = ? AND coalesce(payload ->> 'credentialBatchId', '') <> ?", domain.WorksmobileResourceUser, tenantRootID, "")
|
||||
if credentialBatchID != "" {
|
||||
query = query.Where("payload ->> 'credentialBatchId' = ?", credentialBatchID)
|
||||
}
|
||||
var rows []domain.WorksmobileOutbox
|
||||
err := query.Order("created_at desc").Find(&rows).Error
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (r *worksmobileOutboxRepository) UpdatePayload(ctx context.Context, id string, payload domain.JSONMap) error {
|
||||
return r.db.WithContext(ctx).Model(&domain.WorksmobileOutbox{}).Where("id = ?", id).Updates(map[string]any{
|
||||
"payload": payload,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *worksmobileOutboxRepository) DeletePendingByTenantRoot(ctx context.Context, tenantRootID string) (int64, error) {
|
||||
result := r.db.WithContext(ctx).
|
||||
Where("status = ? AND payload ->> 'tenantRootId' = ?", domain.WorksmobileOutboxStatusPending, tenantRootID).
|
||||
Delete(&domain.WorksmobileOutbox{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
func (r *worksmobileOutboxRepository) ListReady(ctx context.Context, limit int) ([]domain.WorksmobileOutbox, error) {
|
||||
if limit <= 0 || limit > 100 {
|
||||
limit = 20
|
||||
}
|
||||
var rows []domain.WorksmobileOutbox
|
||||
err := r.db.WithContext(ctx).Raw(`
|
||||
WITH RECURSIVE candidates AS (
|
||||
SELECT
|
||||
*,
|
||||
NULLIF(payload #>> '{request,orgUnitExternalKey}', '') AS org_external_key,
|
||||
CASE
|
||||
WHEN payload #>> '{request,parentOrgUnitId}' LIKE 'externalKey:%'
|
||||
THEN NULLIF(substr(payload #>> '{request,parentOrgUnitId}', length('externalKey:') + 1), '')
|
||||
ELSE ''
|
||||
END AS parent_external_key
|
||||
FROM worksmobile_outboxes
|
||||
WHERE status = ? AND (next_attempt_at IS NULL OR next_attempt_at <= ?)
|
||||
),
|
||||
ready AS (
|
||||
SELECT candidates.*
|
||||
FROM candidates
|
||||
WHERE NOT (
|
||||
candidates.resource_type = ?
|
||||
AND candidates.action = ?
|
||||
AND candidates.parent_external_key <> ''
|
||||
AND EXISTS (
|
||||
SELECT 1
|
||||
FROM worksmobile_outboxes parent_job
|
||||
WHERE parent_job.resource_type = ?
|
||||
AND parent_job.action = ?
|
||||
AND parent_job.status <> ?
|
||||
AND NULLIF(parent_job.payload #>> '{request,orgUnitExternalKey}', '') = candidates.parent_external_key
|
||||
)
|
||||
)
|
||||
),
|
||||
org_depth AS (
|
||||
SELECT id, org_external_key, parent_external_key, 0 AS depth
|
||||
FROM ready
|
||||
UNION ALL
|
||||
SELECT child.id, child.org_external_key, child.parent_external_key, parent.depth + 1
|
||||
FROM ready child
|
||||
JOIN org_depth parent ON child.parent_external_key = parent.org_external_key
|
||||
WHERE child.resource_type = ? AND child.action = ? AND parent.depth < 64
|
||||
)
|
||||
SELECT ready.*
|
||||
FROM ready
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT max(depth) AS dependency_depth
|
||||
FROM org_depth
|
||||
WHERE org_depth.id = ready.id
|
||||
) AS depth_rank ON true
|
||||
ORDER BY
|
||||
CASE
|
||||
WHEN ready.resource_type = ? AND ready.action = ? THEN 0
|
||||
WHEN ready.resource_type = ? THEN 1
|
||||
ELSE 2
|
||||
END ASC,
|
||||
COALESCE(depth_rank.dependency_depth, 0) ASC,
|
||||
ready.created_at ASC
|
||||
LIMIT ?
|
||||
`,
|
||||
domain.WorksmobileOutboxStatusPending,
|
||||
time.Now(),
|
||||
domain.WorksmobileResourceOrgUnit,
|
||||
domain.WorksmobileActionUpsert,
|
||||
domain.WorksmobileResourceOrgUnit,
|
||||
domain.WorksmobileActionUpsert,
|
||||
domain.WorksmobileOutboxStatusProcessed,
|
||||
domain.WorksmobileResourceOrgUnit,
|
||||
domain.WorksmobileActionUpsert,
|
||||
domain.WorksmobileResourceOrgUnit,
|
||||
domain.WorksmobileActionUpsert,
|
||||
domain.WorksmobileResourceUser,
|
||||
limit,
|
||||
).Scan(&rows).Error
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (r *worksmobileOutboxRepository) FindByID(ctx context.Context, id string) (*domain.WorksmobileOutbox, error) {
|
||||
var row domain.WorksmobileOutbox
|
||||
if err := r.db.WithContext(ctx).First(&row, "id = ?", id).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &row, nil
|
||||
}
|
||||
|
||||
func (r *worksmobileOutboxRepository) MarkRetry(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Model(&domain.WorksmobileOutbox{}).Where("id = ?", id).Updates(map[string]any{
|
||||
"status": domain.WorksmobileOutboxStatusPending,
|
||||
"last_error": "",
|
||||
"next_attempt_at": nil,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *worksmobileOutboxRepository) MarkProcessing(ctx context.Context, id string) (bool, error) {
|
||||
result := r.db.WithContext(ctx).Model(&domain.WorksmobileOutbox{}).Where("id = ? AND status = ?", id, domain.WorksmobileOutboxStatusPending).Updates(map[string]any{
|
||||
"status": domain.WorksmobileOutboxStatusProcessing,
|
||||
"updated_at": time.Now(),
|
||||
})
|
||||
return result.RowsAffected > 0, result.Error
|
||||
}
|
||||
|
||||
func (r *worksmobileOutboxRepository) MarkProcessed(ctx context.Context, id string) error {
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Model(&domain.WorksmobileOutbox{}).Where("id = ?", id).Updates(map[string]any{
|
||||
"status": domain.WorksmobileOutboxStatusProcessed,
|
||||
"last_error": "",
|
||||
"processed_at": &now,
|
||||
"updated_at": now,
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *worksmobileOutboxRepository) MarkFailed(ctx context.Context, id string, message string, nextAttemptAt time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&domain.WorksmobileOutbox{}).Where("id = ?", id).Updates(map[string]any{
|
||||
"status": domain.WorksmobileOutboxStatusFailed,
|
||||
"retry_count": gorm.Expr("retry_count + 1"),
|
||||
"last_error": message,
|
||||
"next_attempt_at": &nextAttemptAt,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
@@ -0,0 +1,125 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWorksmobileOutboxRepositoryDeletePendingByTenantRoot(t *testing.T) {
|
||||
repo := NewWorksmobileOutboxRepository(testDB)
|
||||
ctx := context.Background()
|
||||
|
||||
require.NoError(t, testDB.Exec("DELETE FROM worksmobile_outboxes").Error)
|
||||
|
||||
rows := []domain.WorksmobileOutbox{
|
||||
{
|
||||
ID: "00000000-0000-0000-0000-000000000101",
|
||||
ResourceType: domain.WorksmobileResourceUser,
|
||||
ResourceID: "user-pending",
|
||||
Action: domain.WorksmobileActionUpsert,
|
||||
Status: domain.WorksmobileOutboxStatusPending,
|
||||
DedupeKey: "pending-root",
|
||||
Payload: domain.JSONMap{"tenantRootId": "root-1"},
|
||||
},
|
||||
{
|
||||
ID: "00000000-0000-0000-0000-000000000102",
|
||||
ResourceType: domain.WorksmobileResourceUser,
|
||||
ResourceID: "user-other-root",
|
||||
Action: domain.WorksmobileActionUpsert,
|
||||
Status: domain.WorksmobileOutboxStatusPending,
|
||||
DedupeKey: "pending-other-root",
|
||||
Payload: domain.JSONMap{"tenantRootId": "root-2"},
|
||||
},
|
||||
{
|
||||
ID: "00000000-0000-0000-0000-000000000103",
|
||||
ResourceType: domain.WorksmobileResourceUser,
|
||||
ResourceID: "user-failed",
|
||||
Action: domain.WorksmobileActionUpsert,
|
||||
Status: domain.WorksmobileOutboxStatusFailed,
|
||||
DedupeKey: "failed-root",
|
||||
Payload: domain.JSONMap{"tenantRootId": "root-1"},
|
||||
},
|
||||
{
|
||||
ID: "00000000-0000-0000-0000-000000000104",
|
||||
ResourceType: domain.WorksmobileResourceOrgUnit,
|
||||
ResourceID: "org-processed",
|
||||
Action: domain.WorksmobileActionUpsert,
|
||||
Status: domain.WorksmobileOutboxStatusProcessed,
|
||||
DedupeKey: "processed-root",
|
||||
Payload: domain.JSONMap{"tenantRootId": "root-1"},
|
||||
},
|
||||
}
|
||||
for i := range rows {
|
||||
require.NoError(t, repo.Create(ctx, &rows[i]))
|
||||
}
|
||||
|
||||
deleted, err := repo.DeletePendingByTenantRoot(ctx, "root-1")
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), deleted)
|
||||
var remaining []domain.WorksmobileOutbox
|
||||
require.NoError(t, testDB.Order("id asc").Find(&remaining).Error)
|
||||
require.Len(t, remaining, 3)
|
||||
require.Equal(t, "00000000-0000-0000-0000-000000000102", remaining[0].ID)
|
||||
require.Equal(t, "00000000-0000-0000-0000-000000000103", remaining[1].ID)
|
||||
require.Equal(t, "00000000-0000-0000-0000-000000000104", remaining[2].ID)
|
||||
}
|
||||
|
||||
func TestWorksmobileOutboxRepositoryListReadyWaitsForPendingOrgUnitParent(t *testing.T) {
|
||||
repo := NewWorksmobileOutboxRepository(testDB)
|
||||
ctx := context.Background()
|
||||
|
||||
require.NoError(t, testDB.Exec("DELETE FROM worksmobile_outboxes").Error)
|
||||
|
||||
baseTime := time.Date(2026, 6, 2, 15, 21, 0, 0, time.UTC)
|
||||
child := domain.WorksmobileOutbox{
|
||||
ID: "00000000-0000-0000-0000-000000000201",
|
||||
ResourceType: domain.WorksmobileResourceOrgUnit,
|
||||
ResourceID: "child-tenant",
|
||||
Action: domain.WorksmobileActionUpsert,
|
||||
Status: domain.WorksmobileOutboxStatusPending,
|
||||
DedupeKey: "orgunit:upsert:child-tenant",
|
||||
Payload: domain.JSONMap{
|
||||
"request": map[string]any{
|
||||
"orgUnitExternalKey": "child-tenant",
|
||||
"parentOrgUnitId": "externalKey:parent-tenant",
|
||||
},
|
||||
},
|
||||
CreatedAt: baseTime,
|
||||
UpdatedAt: baseTime,
|
||||
}
|
||||
parent := domain.WorksmobileOutbox{
|
||||
ID: "00000000-0000-0000-0000-000000000202",
|
||||
ResourceType: domain.WorksmobileResourceOrgUnit,
|
||||
ResourceID: "parent-tenant",
|
||||
Action: domain.WorksmobileActionUpsert,
|
||||
Status: domain.WorksmobileOutboxStatusPending,
|
||||
DedupeKey: "orgunit:upsert:parent-tenant",
|
||||
Payload: domain.JSONMap{
|
||||
"request": map[string]any{
|
||||
"orgUnitExternalKey": "parent-tenant",
|
||||
},
|
||||
},
|
||||
CreatedAt: baseTime.Add(time.Second),
|
||||
UpdatedAt: baseTime.Add(time.Second),
|
||||
}
|
||||
require.NoError(t, testDB.Create(&child).Error)
|
||||
require.NoError(t, testDB.Create(&parent).Error)
|
||||
|
||||
rows, err := repo.ListReady(ctx, 10)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 1)
|
||||
require.Equal(t, "parent-tenant", rows[0].ResourceID)
|
||||
|
||||
require.NoError(t, repo.MarkProcessed(ctx, parent.ID))
|
||||
rows, err = repo.ListReady(ctx, 10)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 1)
|
||||
require.Equal(t, "child-tenant", rows[0].ResourceID)
|
||||
}
|
||||
Reference in New Issue
Block a user