첫 커밋: 로컬 프로젝트 업로드

This commit is contained in:
2026-06-10 15:51:34 +09:00
commit 6a8dbeb2e9
1211 changed files with 312864 additions and 0 deletions

View File

@@ -0,0 +1,449 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"encoding/json"
"fmt"
"time"
"github.com/ClickHouse/clickhouse-go/v2"
"github.com/ClickHouse/clickhouse-go/v2/lib/driver"
)
type ClickHouseRepository struct {
conn driver.Conn
}
func NewClickHouseRepository(host string, port int, user, password, db string) (*ClickHouseRepository, error) {
// 1. Connect to 'default' database first to ensure target DB exists
tmpConn, err := clickhouse.Open(&clickhouse.Options{
Addr: []string{fmt.Sprintf("%s:%d", host, port)},
Auth: clickhouse.Auth{
Database: "default",
Username: user,
Password: password,
},
})
if err == nil {
_ = tmpConn.Exec(context.Background(), fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", db))
_ = tmpConn.Close()
}
// 2. Now connect to the target database
conn, err := clickhouse.Open(&clickhouse.Options{
Addr: []string{fmt.Sprintf("%s:%d", host, port)},
Auth: clickhouse.Auth{
Database: db,
Username: user,
Password: password,
},
Debug: false,
})
if err != nil {
return nil, fmt.Errorf("failed to open clickhouse connection: %w", err)
}
if err := conn.Ping(context.Background()); err != nil {
return nil, fmt.Errorf("failed to ping clickhouse: %w", err)
}
// Ensure Table Exists
// Note: In production, use migrations.
query := `
CREATE TABLE IF NOT EXISTS audit_logs (
event_id String,
timestamp DateTime DEFAULT now(),
user_id String,
tenant_id String,
event_type String,
status String,
ip_address String,
user_agent String,
device_id String,
details String
) ENGINE = MergeTree()
ORDER BY timestamp
`
if err := conn.Exec(context.Background(), query); err != nil {
return nil, fmt.Errorf("failed to create table: %w", err)
}
alterQuery := `
ALTER TABLE audit_logs
ADD COLUMN IF NOT EXISTS tenant_id String,
ADD COLUMN IF NOT EXISTS event_id String
`
if err := conn.Exec(context.Background(), alterQuery); err != nil {
return nil, fmt.Errorf("failed to alter table: %w", err)
}
if err := ensureRPUsageTables(context.Background(), conn); err != nil {
return nil, fmt.Errorf("failed to create rp usage tables: %w", err)
}
return &ClickHouseRepository{conn: conn}, nil
}
func ensureRPUsageTables(ctx context.Context, conn driver.Conn) error {
factQuery := `
CREATE TABLE IF NOT EXISTS rp_usage_events (
event_id String,
occurred_at DateTime64(3) DEFAULT now64(3),
event_type String,
subject String,
tenant_id String,
tenant_type String,
client_id String,
client_name String,
session_id String,
scopes Array(String),
source String,
correlation_id String,
payload String
) ENGINE = MergeTree()
ORDER BY (occurred_at, event_id)
`
if err := conn.Exec(ctx, factQuery); err != nil {
return err
}
aggregateQuery := `
CREATE TABLE IF NOT EXISTS rp_usage_daily_aggregate (
event_date Date,
tenant_id String,
tenant_type String,
client_id String,
client_name String,
event_type String,
events_count AggregateFunction(count),
unique_subjects AggregateFunction(uniqExact, String)
) ENGINE = AggregatingMergeTree()
ORDER BY (event_date, tenant_id, client_id, event_type)
`
if err := conn.Exec(ctx, aggregateQuery); err != nil {
return err
}
viewQuery := `
CREATE MATERIALIZED VIEW IF NOT EXISTS rp_usage_daily_aggregate_mv
TO rp_usage_daily_aggregate
AS
SELECT
toDate(occurred_at) AS event_date,
tenant_id,
tenant_type,
client_id,
any(client_name) AS client_name,
event_type,
countState() AS events_count,
uniqExactState(subject) AS unique_subjects
FROM rp_usage_events
WHERE tenant_type IN ('COMPANY', 'ORGANIZATION')
GROUP BY event_date, tenant_id, tenant_type, client_id, event_type
`
return conn.Exec(ctx, viewQuery)
}
func (r *ClickHouseRepository) Create(log *domain.AuditLog) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if log.Timestamp.IsZero() {
log.Timestamp = time.Now()
}
query := `
INSERT INTO audit_logs (event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
return r.conn.Exec(ctx, query,
log.EventID,
log.Timestamp,
log.UserID,
log.TenantID,
log.EventType,
log.Status,
log.IPAddress,
log.UserAgent,
log.DeviceID,
log.Details,
)
}
func (r *ClickHouseRepository) CreateRPUsageEvent(ctx context.Context, event domain.RPUsageEvent) error {
if r == nil || r.conn == nil {
return fmt.Errorf("clickhouse connection is nil")
}
if event.OccurredAt.IsZero() {
event.OccurredAt = time.Now()
}
payloadBytes, err := json.Marshal(event.Payload)
if err != nil {
return fmt.Errorf("failed to marshal rp usage payload: %w", err)
}
query := `
INSERT INTO rp_usage_events (
event_id, occurred_at, event_type, subject, tenant_id, tenant_type,
client_id, client_name, session_id, scopes, source, correlation_id, payload
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
return r.conn.Exec(ctx, query,
event.ID,
event.OccurredAt,
event.EventType,
event.Subject,
event.TenantID,
event.TenantType,
event.ClientID,
event.ClientName,
event.SessionID,
[]string(event.Scopes),
event.Source,
event.CorrelationID,
string(payloadBytes),
)
}
func (r *ClickHouseRepository) FindRPUsage(ctx context.Context, rpQuery domain.RPUsageQuery) ([]domain.RPUsageDailyMetric, error) {
if r == nil || r.conn == nil {
return nil, fmt.Errorf("clickhouse connection is nil")
}
days := rpQuery.Days
if days <= 0 || days > 90 {
days = 14
}
periodExpr := "event_date"
switch rpQuery.Period {
case "week":
periodExpr = "toMonday(event_date)"
case "month":
periodExpr = "toStartOfMonth(event_date)"
case "day", "":
periodExpr = "event_date"
default:
periodExpr = "event_date"
}
query := fmt.Sprintf(`
SELECT
date,
tenant_id,
tenant_type,
client_id,
any(client_name) AS client_name,
sumIf(events, event_type = ?) AS login_requests,
sumIf(events, event_type != ?) AS other_requests,
max(unique_subjects) AS unique_subjects
FROM (
SELECT
toString(%s) AS date,
tenant_id,
tenant_type,
client_id,
any(client_name) AS client_name,
event_type,
countMerge(events_count) AS events,
uniqExactMerge(unique_subjects) AS unique_subjects
FROM rp_usage_daily_aggregate
WHERE event_date >= today() - ?
AND tenant_type IN ('COMPANY', 'ORGANIZATION')
`, periodExpr)
args := []any{domain.RPUsageEventTypeAuthorizationGranted, domain.RPUsageEventTypeAuthorizationGranted, days - 1}
if rpQuery.TenantID != "" {
query += " AND tenant_id = ?\n"
args = append(args, rpQuery.TenantID)
}
query += fmt.Sprintf(`
GROUP BY %s, tenant_id, tenant_type, client_id, event_type
)
GROUP BY date, tenant_id, tenant_type, client_id
ORDER BY date ASC, tenant_id ASC, client_id ASC
`, periodExpr)
rows, err := r.conn.Query(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("failed to query rp usage daily aggregate: %w", err)
}
defer rows.Close()
metrics := make([]domain.RPUsageDailyMetric, 0)
for rows.Next() {
var metric domain.RPUsageDailyMetric
if err := rows.Scan(
&metric.Date,
&metric.TenantID,
&metric.TenantType,
&metric.ClientID,
&metric.ClientName,
&metric.LoginRequests,
&metric.OtherRequests,
&metric.UniqueSubjects,
); err != nil {
return nil, fmt.Errorf("failed to scan rp usage daily aggregate: %w", err)
}
if metric.ClientName == "" {
metric.ClientName = metric.ClientID
}
metrics = append(metrics, metric)
}
return metrics, nil
}
func (r *ClickHouseRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor, tenantID string) ([]domain.AuditLog, error) {
if limit <= 0 {
limit = 50
}
query := `
SELECT event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details
FROM audit_logs
WHERE 1=1
`
args := make([]any, 0, 5)
if tenantID != "" {
query += " AND (tenant_id = ? OR (tenant_id = '' AND JSONExtractString(details, 'tenant_id') = ?))"
args = append(args, tenantID, tenantID)
}
if cursor != nil {
query += `
AND ((timestamp < ?) OR (timestamp = ? AND event_id < ?))
`
args = append(args, cursor.Timestamp, cursor.Timestamp, cursor.EventID)
}
query += `
ORDER BY timestamp DESC, event_id DESC
LIMIT ?
`
args = append(args, limit)
rows, err := r.conn.Query(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("failed to query audit logs: %w", err)
}
defer rows.Close()
var logs []domain.AuditLog
for rows.Next() {
var log domain.AuditLog
if err := rows.Scan(
&log.EventID,
&log.Timestamp,
&log.UserID,
&log.TenantID,
&log.EventType,
&log.Status,
&log.IPAddress,
&log.UserAgent,
&log.DeviceID,
&log.Details,
); err != nil {
return nil, fmt.Errorf("failed to scan audit log: %w", err)
}
logs = append(logs, log)
}
return logs, nil
}
func (r *ClickHouseRepository) FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]domain.AuditLog, error) {
if limit <= 0 {
limit = 100
}
query := `
SELECT event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details
FROM audit_logs
WHERE user_id = ? AND event_type IN (?)
ORDER BY timestamp DESC
LIMIT ?
`
rows, err := r.conn.Query(ctx, query, userID, eventTypes, limit)
if err != nil {
return nil, fmt.Errorf("failed to query audit logs by user/events: %w", err)
}
defer rows.Close()
var logs []domain.AuditLog
for rows.Next() {
var log domain.AuditLog
if err := rows.Scan(
&log.EventID,
&log.Timestamp,
&log.UserID,
&log.TenantID,
&log.EventType,
&log.Status,
&log.IPAddress,
&log.UserAgent,
&log.DeviceID,
&log.Details,
); err != nil {
return nil, fmt.Errorf("failed to scan audit log: %w", err)
}
logs = append(logs, log)
}
return logs, nil
}
func (r *ClickHouseRepository) Ping(ctx context.Context) error {
if r.conn == nil {
return fmt.Errorf("clickhouse connection is nil")
}
return r.conn.Ping(ctx)
}
func (r *ClickHouseRepository) CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
query := `
SELECT count()
FROM audit_logs
WHERE status = 'failure' AND timestamp >= ?
`
args := []any{since}
if tenantID != "" {
query += " AND JSONExtractString(details, 'tenant_id') = ?"
args = append(args, tenantID)
}
var count int64
err := r.conn.QueryRow(ctx, query, args...).Scan(&count)
if err != nil {
return 0, fmt.Errorf("failed to count failures: %w", err)
}
return count, nil
}
func (r *ClickHouseRepository) CountEventsSince(ctx context.Context, since time.Time) (int64, error) {
sinceUTC := since.UTC().Format("2006-01-02 15:04:05")
query := fmt.Sprintf(`
SELECT count()
FROM audit_logs
WHERE timestamp >= toDateTime('%s')
`, sinceUTC)
var count int64
err := r.conn.QueryRow(ctx, query).Scan(&count)
if err != nil {
return 0, fmt.Errorf("failed to count audit events: %w", err)
}
return count, nil
}
func (r *ClickHouseRepository) CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
// We use uniqExact(session_id) to count unique sessions that had success events recently.
query := `
SELECT uniqExact(session_id)
FROM audit_logs
WHERE status = 'success' AND timestamp >= ? AND session_id != ''
`
args := []any{since}
if tenantID != "" {
query += " AND JSONExtractString(details, 'tenant_id') = ?"
args = append(args, tenantID)
}
var count int64
err := r.conn.QueryRow(ctx, query, args...).Scan(&count)
if err != nil {
return 0, fmt.Errorf("failed to count active sessions: %w", err)
}
return count, nil
}

View File

@@ -0,0 +1,138 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"errors"
"gorm.io/gorm"
)
type ClientConsentRepository interface {
Upsert(ctx context.Context, consent *domain.ClientConsent) error
Delete(ctx context.Context, subject, clientID string) error
DeleteByClient(ctx context.Context, clientID string) error
List(ctx context.Context, clientID string, limit, offset int) ([]domain.ClientConsentWithTenantInfo, int64, error)
ListByTenant(ctx context.Context, clientID, tenantID string, limit, offset int) ([]domain.ClientConsentWithTenantInfo, int64, error)
ListBySubject(ctx context.Context, subject string) ([]domain.ClientConsent, error)
ListSubjectsByClient(ctx context.Context, clientID string) ([]string, error)
Find(ctx context.Context, clientID, subject string) (*domain.ClientConsent, error)
}
type clientConsentRepo struct {
db *gorm.DB
}
func NewClientConsentRepository(db *gorm.DB) ClientConsentRepository {
return &clientConsentRepo{db: db}
}
func (r *clientConsentRepo) Find(ctx context.Context, clientID, subject string) (*domain.ClientConsent, error) {
var consent domain.ClientConsent
err := r.db.WithContext(ctx).
Where("client_id = ? AND subject = ?", clientID, subject).
First(&consent).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &consent, nil
}
func (r *clientConsentRepo) Upsert(ctx context.Context, consent *domain.ClientConsent) error {
return r.db.WithContext(ctx).Unscoped().
Where("client_id = ? AND subject = ?", consent.ClientID, consent.Subject).
Assign(map[string]any{
"granted_scopes": consent.GrantedScopes,
"updated_at": gorm.Expr("NOW()"),
"deleted_at": nil,
}).
FirstOrCreate(consent).Error
}
func (r *clientConsentRepo) Delete(ctx context.Context, subject, clientID string) error {
return r.db.WithContext(ctx).
Where("subject = ? AND client_id = ?", subject, clientID).
Delete(&domain.ClientConsent{}).Error
}
func (r *clientConsentRepo) DeleteByClient(ctx context.Context, clientID string) error {
return r.db.WithContext(ctx).
Where("client_id = ?", clientID).
Delete(&domain.ClientConsent{}).Error
}
func (r *clientConsentRepo) List(ctx context.Context, clientID string, limit, offset int) ([]domain.ClientConsentWithTenantInfo, int64, error) {
var consents []domain.ClientConsentWithTenantInfo
var total int64
// Base query for counting
countQuery := r.db.WithContext(ctx).Unscoped().Model(&domain.ClientConsent{}).Where("client_id = ?", clientID)
if err := countQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// Query for fetching data
query := r.db.WithContext(ctx).Unscoped().
Model(&domain.ClientConsent{}).
Select("client_consents.*, users.tenant_id, tenants.name as tenant_name").
Joins("LEFT JOIN users ON users.id::text = client_consents.subject").
Joins("LEFT JOIN tenants ON tenants.id = users.tenant_id").
Where("client_consents.client_id = ?", clientID)
err := query.Limit(limit).Offset(offset).Order("client_consents.updated_at DESC").Scan(&consents).Error
return consents, total, err
}
func (r *clientConsentRepo) ListByTenant(ctx context.Context, clientID, tenantID string, limit, offset int) ([]domain.ClientConsentWithTenantInfo, int64, error) {
var consents []domain.ClientConsentWithTenantInfo
var total int64
// Base query for counting
countQuery := r.db.WithContext(ctx).Unscoped().
Model(&domain.ClientConsent{}).
Joins("JOIN users ON users.id::text = client_consents.subject").
Where("client_consents.client_id = ? AND users.tenant_id = ?", clientID, tenantID)
if err := countQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// Query for fetching data
query := r.db.WithContext(ctx).Unscoped().
Model(&domain.ClientConsent{}).
Select("client_consents.*, users.tenant_id, tenants.name as tenant_name").
Joins("JOIN users ON users.id::text = client_consents.subject").
Joins("JOIN tenants ON tenants.id = users.tenant_id").
Where("client_consents.client_id = ? AND users.tenant_id = ?", clientID, tenantID)
err := query.
Limit(limit).
Offset(offset).
Order("client_consents.updated_at DESC").
Scan(&consents).Error
return consents, total, err
}
func (r *clientConsentRepo) ListBySubject(ctx context.Context, subject string) ([]domain.ClientConsent, error) {
var consents []domain.ClientConsent
err := r.db.WithContext(ctx).Unscoped().
Where("subject = ?", subject).
Order("updated_at DESC").
Find(&consents).Error
return consents, err
}
func (r *clientConsentRepo) ListSubjectsByClient(ctx context.Context, clientID string) ([]string, error) {
var subjects []string
err := r.db.WithContext(ctx).Unscoped().
Model(&domain.ClientConsent{}).
Distinct("subject").
Where("client_id = ?", clientID).
Order("subject ASC").
Pluck("subject", &subjects).Error
return subjects, err
}

View File

@@ -0,0 +1,31 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"testing"
"github.com/lib/pq"
"github.com/stretchr/testify/assert"
)
func TestClientConsentRepository_Find_IgnoresSoftDeletedConsent(t *testing.T) {
repo := NewClientConsentRepository(testDB)
ctx := context.Background()
consent := &domain.ClientConsent{
ClientID: "client-soft-delete",
Subject: "user-soft-delete",
GrantedScopes: pq.StringArray{"openid", "profile"},
}
err := repo.Upsert(ctx, consent)
assert.NoError(t, err)
err = repo.Delete(ctx, consent.Subject, consent.ClientID)
assert.NoError(t, err)
found, err := repo.Find(ctx, consent.ClientID, consent.Subject)
assert.NoError(t, err)
assert.Nil(t, found)
}

View File

@@ -0,0 +1,40 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type clientSecretRepository struct {
db *gorm.DB
}
func NewClientSecretRepository(db *gorm.DB) domain.ClientSecretRepository {
return &clientSecretRepository{db: db}
}
func (r *clientSecretRepository) Upsert(ctx context.Context, clientID, secret string) error {
cs := domain.ClientSecret{
ClientID: clientID,
ClientSecret: secret,
}
return r.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "client_id"}},
DoUpdates: clause.AssignmentColumns([]string{"client_secret", "updated_at"}),
}).Create(&cs).Error
}
func (r *clientSecretRepository) GetByID(ctx context.Context, clientID string) (string, error) {
var cs domain.ClientSecret
if err := r.db.WithContext(ctx).Where("client_id = ?", clientID).First(&cs).Error; err != nil {
return "", err
}
return cs.ClientSecret, nil
}
func (r *clientSecretRepository) Delete(ctx context.Context, clientID string) error {
return r.db.WithContext(ctx).Where("client_id = ?", clientID).Delete(&domain.ClientSecret{}).Error
}

View File

@@ -0,0 +1,380 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"slices"
"strings"
"time"
"gorm.io/gorm"
)
type DataIntegrityChecker interface {
CheckDataIntegrity(ctx context.Context) (domain.DataIntegrityReport, error)
ListOrphanUserLoginIDs(ctx context.Context) ([]domain.OrphanUserLoginID, error)
DeleteOrphanUserLoginIDs(ctx context.Context, ids []string) (domain.DeleteOrphanUserLoginIDsResult, error)
}
type dataIntegrityChecker struct {
db *gorm.DB
}
func NewDataIntegrityChecker(db *gorm.DB) DataIntegrityChecker {
return &dataIntegrityChecker{db: db}
}
func (c *dataIntegrityChecker) CheckDataIntegrity(ctx context.Context) (domain.DataIntegrityReport, error) {
return CheckDataIntegrity(ctx, c.db)
}
func (c *dataIntegrityChecker) ListOrphanUserLoginIDs(ctx context.Context) ([]domain.OrphanUserLoginID, error) {
return ListOrphanUserLoginIDs(ctx, c.db, nil)
}
func (c *dataIntegrityChecker) DeleteOrphanUserLoginIDs(ctx context.Context, ids []string) (domain.DeleteOrphanUserLoginIDsResult, error) {
return DeleteOrphanUserLoginIDs(ctx, c.db, ids)
}
func CheckDataIntegrity(ctx context.Context, db *gorm.DB) (domain.DataIntegrityReport, error) {
tenantChecks := []domain.DataIntegrityCheck{
{
Key: "duplicate_tenant_slugs",
Label: "중복 테넌트 slug",
Description: "삭제되지 않은 tenant의 slug를 대소문자 무시 기준으로 검사합니다.",
Severity: "error",
Count: 0,
},
{
Key: "orphan_tenant_parents",
Label: "유령 상위 테넌트 참조",
Description: "tenant.parent_id가 없거나 삭제된 tenant를 참조하는지 검사합니다.",
Severity: "error",
Count: 0,
},
}
userChecks := []domain.DataIntegrityCheck{
{
Key: "orphan_user_tenant_memberships",
Label: "유령 테넌트 사용자 소속",
Description: "users.tenant_id가 없거나 삭제된 tenant를 참조하는지 검사합니다.",
Severity: "error",
Count: 0,
},
{
Key: "orphan_user_login_id_tenants",
Label: "유령 테넌트 로그인 ID",
Description: "user_login_ids.tenant_id가 없거나 삭제된 tenant를 참조하는지 검사합니다.",
Severity: "error",
Count: 0,
},
{
Key: "orphan_user_login_id_users",
Label: "유령 사용자 로그인 ID",
Description: "user_login_ids.user_id가 없거나 삭제된 user를 참조하는지 검사합니다.",
Severity: "error",
Count: 0,
},
}
counts := []struct {
target *int64
query string
}{
{
target: &tenantChecks[0].Count,
query: `
SELECT COUNT(*)
FROM (
SELECT LOWER(TRIM(slug)) AS normalized_slug
FROM tenants
WHERE deleted_at IS NULL
AND status <> 'deleted'
AND TRIM(slug) <> ''
GROUP BY LOWER(TRIM(slug))
HAVING COUNT(*) > 1
) AS duplicate_slugs
`,
},
{
target: &tenantChecks[1].Count,
query: `
SELECT COUNT(*)
FROM tenants AS child
WHERE child.deleted_at IS NULL
AND child.parent_id IS NOT NULL
AND NOT EXISTS (
SELECT 1
FROM tenants AS parent
WHERE parent.id = child.parent_id
AND parent.deleted_at IS NULL
)
`,
},
{
target: &userChecks[0].Count,
query: `
SELECT COUNT(*)
FROM users AS u
WHERE u.deleted_at IS NULL
AND u.tenant_id IS NOT NULL
AND NOT EXISTS (
SELECT 1
FROM tenants AS t
WHERE t.id = u.tenant_id
AND t.deleted_at IS NULL
)
`,
},
{
target: &userChecks[1].Count,
query: `
SELECT COUNT(*)
FROM user_login_ids AS uli
WHERE NOT EXISTS (
SELECT 1
FROM tenants AS t
WHERE t.id = uli.tenant_id
AND t.deleted_at IS NULL
)
`,
},
{
target: &userChecks[2].Count,
query: `
SELECT COUNT(*)
FROM user_login_ids AS uli
WHERE NOT EXISTS (
SELECT 1
FROM users AS u
WHERE u.id = uli.user_id
AND u.deleted_at IS NULL
)
`,
},
}
for _, item := range counts {
if err := db.WithContext(ctx).Raw(item.query).Scan(item.target).Error; err != nil {
return domain.DataIntegrityReport{}, err
}
}
tenantChecks = applyIntegrityStatuses(tenantChecks)
userChecks = applyIntegrityStatuses(userChecks)
sections := []domain.DataIntegritySection{
{
Key: "tenant_integrity",
Label: "테넌트 정합성",
Status: summarizeIntegrityStatus(tenantChecks),
Checks: tenantChecks,
},
{
Key: "user_integrity",
Label: "사용자 정합성",
Status: summarizeIntegrityStatus(userChecks),
Checks: userChecks,
},
}
summary := domain.DataIntegritySummary{}
for _, section := range sections {
for _, check := range section.Checks {
summary.TotalChecks++
switch check.Status {
case domain.DataIntegrityStatusFail:
summary.Failures += check.Count
case domain.DataIntegrityStatusWarning:
summary.Warnings++
default:
summary.Passed++
}
}
}
return domain.DataIntegrityReport{
Status: summarizeSectionStatus(sections),
CheckedAt: time.Now().UTC(),
Summary: summary,
Sections: sections,
}, nil
}
func applyIntegrityStatuses(checks []domain.DataIntegrityCheck) []domain.DataIntegrityCheck {
for i := range checks {
if checks[i].Count > 0 {
checks[i].Status = domain.DataIntegrityStatusFail
} else {
checks[i].Status = domain.DataIntegrityStatusPass
}
}
return checks
}
func summarizeIntegrityStatus(checks []domain.DataIntegrityCheck) domain.DataIntegrityStatus {
status := domain.DataIntegrityStatusPass
for _, check := range checks {
if check.Status == domain.DataIntegrityStatusFail {
return domain.DataIntegrityStatusFail
}
if check.Status == domain.DataIntegrityStatusWarning {
status = domain.DataIntegrityStatusWarning
}
}
return status
}
func summarizeSectionStatus(sections []domain.DataIntegritySection) domain.DataIntegrityStatus {
status := domain.DataIntegrityStatusPass
for _, section := range sections {
if section.Status == domain.DataIntegrityStatusFail {
return domain.DataIntegrityStatusFail
}
if section.Status == domain.DataIntegrityStatusWarning {
status = domain.DataIntegrityStatusWarning
}
}
return status
}
func ListOrphanUserLoginIDs(ctx context.Context, db *gorm.DB, ids []string) ([]domain.OrphanUserLoginID, error) {
type orphanRow struct {
ID string
UserID string
UserEmail string
UserDeletedAt *time.Time
TenantID string
TenantSlug string
TenantDeletedAt *time.Time
FieldKey string
LoginID string
MissingUser bool
DeletedUser bool
MissingTenant bool
DeletedTenant bool
}
query := `
SELECT
uli.id,
uli.user_id,
COALESCE(u.email, '') AS user_email,
u.deleted_at AS user_deleted_at,
uli.tenant_id,
COALESCE(t.slug, '') AS tenant_slug,
t.deleted_at AS tenant_deleted_at,
uli.field_key,
uli.login_id,
(u.id IS NULL) AS missing_user,
(u.id IS NOT NULL AND u.deleted_at IS NOT NULL) AS deleted_user,
(t.id IS NULL) AS missing_tenant,
(t.id IS NOT NULL AND t.deleted_at IS NOT NULL) AS deleted_tenant
FROM user_login_ids AS uli
LEFT JOIN users AS u ON u.id = uli.user_id
LEFT JOIN tenants AS t ON t.id = uli.tenant_id
WHERE (
u.id IS NULL
OR u.deleted_at IS NOT NULL
OR t.id IS NULL
OR t.deleted_at IS NOT NULL
)
`
args := []any{}
if len(ids) > 0 {
query += " AND uli.id IN ?\n"
args = append(args, ids)
}
query += "ORDER BY uli.login_id, uli.id"
var rows []orphanRow
if err := db.WithContext(ctx).Raw(query, args...).Scan(&rows).Error; err != nil {
return nil, err
}
items := make([]domain.OrphanUserLoginID, 0, len(rows))
for _, row := range rows {
reasons := make([]string, 0, 4)
if row.MissingUser {
reasons = append(reasons, "missing_user")
}
if row.DeletedUser {
reasons = append(reasons, "deleted_user")
}
if row.MissingTenant {
reasons = append(reasons, "missing_tenant")
}
if row.DeletedTenant {
reasons = append(reasons, "deleted_tenant")
}
items = append(items, domain.OrphanUserLoginID{
ID: row.ID,
UserID: row.UserID,
UserEmail: row.UserEmail,
UserDeletedAt: row.UserDeletedAt,
TenantID: row.TenantID,
TenantSlug: row.TenantSlug,
TenantDeletedAt: row.TenantDeletedAt,
FieldKey: row.FieldKey,
LoginID: row.LoginID,
Reasons: reasons,
})
}
return items, nil
}
func DeleteOrphanUserLoginIDs(ctx context.Context, db *gorm.DB, ids []string) (domain.DeleteOrphanUserLoginIDsResult, error) {
ids = normalizeIDList(ids)
result := domain.DeleteOrphanUserLoginIDsResult{
Deleted: []domain.OrphanUserLoginID{},
SkippedIDs: []string{},
}
if len(ids) == 0 {
return result, nil
}
err := db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
items, err := ListOrphanUserLoginIDs(ctx, tx, ids)
if err != nil {
return err
}
deletableIDs := make([]string, 0, len(items))
deletableSet := make(map[string]bool, len(items))
for _, item := range items {
deletableIDs = append(deletableIDs, item.ID)
deletableSet[item.ID] = true
}
for _, id := range ids {
if !deletableSet[id] {
result.SkippedIDs = append(result.SkippedIDs, id)
}
}
if len(deletableIDs) == 0 {
return nil
}
deleteResult := tx.Exec("DELETE FROM user_login_ids WHERE id IN ?", deletableIDs)
if deleteResult.Error != nil {
return deleteResult.Error
}
result.Deleted = items
result.DeletedCount = deleteResult.RowsAffected
return nil
})
return result, err
}
func normalizeIDList(ids []string) []string {
normalized := make([]string, 0, len(ids))
seen := map[string]bool{}
for _, id := range ids {
id = strings.TrimSpace(id)
if id == "" || seen[id] {
continue
}
seen[id] = true
normalized = append(normalized, id)
}
slices.Sort(normalized)
return normalized
}

View File

@@ -0,0 +1,312 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/google/uuid"
"github.com/lib/pq"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
)
func TestCheckDataIntegrityDetectsTenantAndUserProblems(t *testing.T) {
ctx := context.Background()
suffix := uuid.NewString()
parent := domain.Tenant{
ID: uuid.NewString(),
Name: "Deleted Parent " + suffix,
Slug: "deleted-parent-" + suffix,
Type: domain.TenantTypeCompany,
Status: domain.TenantStatusActive,
}
child := domain.Tenant{
ID: uuid.NewString(),
Name: "Orphan Child " + suffix,
Slug: "orphan-child-" + suffix,
Type: domain.TenantTypeOrganization,
ParentID: &parent.ID,
Status: domain.TenantStatusActive,
}
dupA := domain.Tenant{
ID: uuid.NewString(),
Name: "Duplicate A " + suffix,
Slug: "Dup-" + suffix,
Type: domain.TenantTypeCompany,
Status: domain.TenantStatusActive,
}
dupB := domain.Tenant{
ID: uuid.NewString(),
Name: "Duplicate B " + suffix,
Slug: "dup-" + suffix,
Type: domain.TenantTypeCompany,
Status: domain.TenantStatusActive,
}
require.NoError(t, testDB.Create(&parent).Error)
require.NoError(t, testDB.Create(&child).Error)
require.NoError(t, testDB.Create(&dupA).Error)
require.NoError(t, testDB.Create(&dupB).Error)
require.NoError(t, testDB.Delete(&domain.Tenant{}, "id = ?", parent.ID).Error)
orphanUser := domain.User{
ID: uuid.NewString(),
Email: "orphan-" + suffix + "@example.com",
Name: "Orphan User",
Role: domain.RoleUser,
TenantID: &parent.ID,
Status: domain.UserStatusActive,
CreatedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(),
}
deletedLoginUser := domain.User{
ID: uuid.NewString(),
Email: "deleted-login-user-" + suffix + "@example.com",
Name: "Deleted Login User",
Role: domain.RoleUser,
TenantID: &child.ID,
Status: domain.UserStatusActive,
CreatedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(),
}
require.NoError(t, testDB.Create(&orphanUser).Error)
require.NoError(t, testDB.Create(&deletedLoginUser).Error)
require.NoError(t, testDB.Create(&domain.UserLoginID{
ID: uuid.NewString(),
UserID: orphanUser.ID,
TenantID: parent.ID,
FieldKey: "emp_id",
LoginID: "EMP-" + suffix,
}).Error)
require.NoError(t, testDB.Create(&domain.UserLoginID{
ID: uuid.NewString(),
UserID: deletedLoginUser.ID,
TenantID: child.ID,
FieldKey: "emp_id",
LoginID: "MISSING-" + suffix,
}).Error)
require.NoError(t, testDB.Delete(&domain.User{}, "id = ?", deletedLoginUser.ID).Error)
report, err := CheckDataIntegrity(ctx, testDB)
require.NoError(t, err)
require.Equal(t, domain.DataIntegrityStatusFail, report.Status)
require.Equal(t, int64(5), report.Summary.Failures) // Reverted back to 5 due to successful soft delete simulation
requireIntegrityCheck(t, report, "tenant_integrity", "duplicate_tenant_slugs", domain.DataIntegrityStatusFail, 1)
requireIntegrityCheck(t, report, "tenant_integrity", "orphan_tenant_parents", domain.DataIntegrityStatusFail, 1)
requireIntegrityCheck(t, report, "user_integrity", "orphan_user_tenant_memberships", domain.DataIntegrityStatusFail, 1)
requireIntegrityCheck(t, report, "user_integrity", "orphan_user_login_id_tenants", domain.DataIntegrityStatusFail, 1)
requireIntegrityCheck(t, report, "user_integrity", "orphan_user_login_id_users", domain.DataIntegrityStatusFail, 1)
}
func TestCheckDataIntegrityDetectsHardOrphanUserLoginIDRows(t *testing.T) {
ctx := context.Background()
suffix := uuid.NewString()
rollback := errors.New("rollback hard orphan fixture")
err := testDB.Transaction(func(tx *gorm.DB) error {
var constraintNames []string
if err := tx.Raw(`
SELECT conname
FROM pg_constraint
WHERE conrelid = 'user_login_ids'::regclass
AND contype = 'f'
`).Scan(&constraintNames).Error; err != nil {
return err
}
for _, constraintName := range constraintNames {
statement := fmt.Sprintf("ALTER TABLE user_login_ids DROP CONSTRAINT %s", pq.QuoteIdentifier(constraintName))
if err := tx.Exec(statement).Error; err != nil {
return err
}
}
before, err := CheckDataIntegrity(ctx, tx)
if err != nil {
return err
}
beforeTenantCount, err := integrityCheckCount(before, "user_integrity", "orphan_user_login_id_tenants")
if err != nil {
return err
}
beforeUserCount, err := integrityCheckCount(before, "user_integrity", "orphan_user_login_id_users")
if err != nil {
return err
}
if err := tx.Create(&domain.UserLoginID{
ID: uuid.NewString(),
UserID: uuid.NewString(),
TenantID: uuid.NewString(),
FieldKey: "emp_id",
LoginID: "HARD-ORPHAN-" + suffix,
}).Error; err != nil {
return err
}
report, err := CheckDataIntegrity(ctx, tx)
if err != nil {
return err
}
if err := expectIntegrityCheck(report, "user_integrity", "orphan_user_login_id_tenants", domain.DataIntegrityStatusFail, beforeTenantCount+1); err != nil {
return err
}
if err := expectIntegrityCheck(report, "user_integrity", "orphan_user_login_id_users", domain.DataIntegrityStatusFail, beforeUserCount+1); err != nil {
return err
}
return rollback
})
require.ErrorIs(t, err, rollback)
}
func TestListAndDeleteOrphanUserLoginIDsOnlyDeletesRevalidatedTargets(t *testing.T) {
ctx := context.Background()
suffix := uuid.NewString()
validTenant := domain.Tenant{
ID: uuid.NewString(),
Name: "Valid Tenant " + suffix,
Slug: "valid-tenant-" + suffix,
Type: domain.TenantTypeCompany,
Status: domain.TenantStatusActive,
}
deletedTenant := domain.Tenant{
ID: uuid.NewString(),
Name: "Deleted Tenant " + suffix,
Slug: "deleted-tenant-" + suffix,
Type: domain.TenantTypeCompany,
Status: domain.TenantStatusActive,
}
require.NoError(t, testDB.Create(&validTenant).Error)
require.NoError(t, testDB.Create(&deletedTenant).Error)
validUser := domain.User{
ID: uuid.NewString(),
Email: "valid-login-" + suffix + "@example.com",
Name: "Valid Login User",
Role: domain.RoleUser,
TenantID: &validTenant.ID,
Status: domain.UserStatusActive,
CreatedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(),
}
deletedUser := domain.User{
ID: uuid.NewString(),
Email: "deleted-login-" + suffix + "@example.com",
Name: "Deleted Login User",
Role: domain.RoleUser,
TenantID: &validTenant.ID,
Status: domain.UserStatusActive,
CreatedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(),
}
require.NoError(t, testDB.Create(&validUser).Error)
require.NoError(t, testDB.Create(&deletedUser).Error)
validLogin := domain.UserLoginID{
ID: uuid.NewString(),
UserID: validUser.ID,
TenantID: validTenant.ID,
FieldKey: "emp_id",
LoginID: "VALID-" + suffix,
}
deletedTenantLogin := domain.UserLoginID{
ID: uuid.NewString(),
UserID: validUser.ID,
TenantID: deletedTenant.ID,
FieldKey: "emp_id",
LoginID: "DELETED-TENANT-" + suffix,
}
deletedUserLogin := domain.UserLoginID{
ID: uuid.NewString(),
UserID: deletedUser.ID,
TenantID: validTenant.ID,
FieldKey: "emp_id",
LoginID: "DELETED-USER-" + suffix,
}
require.NoError(t, testDB.Create(&validLogin).Error)
require.NoError(t, testDB.Create(&deletedTenantLogin).Error)
require.NoError(t, testDB.Create(&deletedUserLogin).Error)
require.NoError(t, testDB.Delete(&domain.Tenant{}, "id = ?", deletedTenant.ID).Error)
require.NoError(t, testDB.Delete(&domain.User{}, "id = ?", deletedUser.ID).Error)
items, err := ListOrphanUserLoginIDs(ctx, testDB, nil)
require.NoError(t, err)
orphanReasons := map[string][]string{}
for _, item := range items {
orphanReasons[item.ID] = item.Reasons
}
require.Equal(t, []string{"deleted_tenant"}, orphanReasons[deletedTenantLogin.ID])
require.Equal(t, []string{"deleted_user"}, orphanReasons[deletedUserLogin.ID])
require.NotContains(t, orphanReasons, validLogin.ID)
result, err := DeleteOrphanUserLoginIDs(ctx, testDB, []string{
deletedTenantLogin.ID,
validLogin.ID,
"00000000-0000-0000-0000-000000000000",
})
require.NoError(t, err)
require.Equal(t, int64(1), result.DeletedCount)
require.Len(t, result.Deleted, 1)
require.Equal(t, deletedTenantLogin.ID, result.Deleted[0].ID)
require.ElementsMatch(t, []string{
validLogin.ID,
"00000000-0000-0000-0000-000000000000",
}, result.SkippedIDs)
var deletedTenantLoginCount int64
require.NoError(t, testDB.Model(&domain.UserLoginID{}).Where("id = ?", deletedTenantLogin.ID).Count(&deletedTenantLoginCount).Error)
require.Equal(t, int64(0), deletedTenantLoginCount)
var validLoginCount int64
require.NoError(t, testDB.Model(&domain.UserLoginID{}).Where("id = ?", validLogin.ID).Count(&validLoginCount).Error)
require.Equal(t, int64(1), validLoginCount)
}
func requireIntegrityCheck(t *testing.T, report domain.DataIntegrityReport, sectionKey, checkKey string, status domain.DataIntegrityStatus, count int64) {
t.Helper()
require.NoError(t, expectIntegrityCheck(report, sectionKey, checkKey, status, count))
}
func expectIntegrityCheck(report domain.DataIntegrityReport, sectionKey, checkKey string, status domain.DataIntegrityStatus, count int64) error {
check, ok := findIntegrityCheck(report, sectionKey, checkKey)
if !ok {
return fmt.Errorf("integrity check %s/%s not found", sectionKey, checkKey)
}
if check.Status != status {
return fmt.Errorf("integrity check %s/%s status = %s, want %s", sectionKey, checkKey, check.Status, status)
}
if check.Count != count {
return fmt.Errorf("integrity check %s/%s count = %d, want %d", sectionKey, checkKey, check.Count, count)
}
return nil
}
func integrityCheckCount(report domain.DataIntegrityReport, sectionKey, checkKey string) (int64, error) {
check, ok := findIntegrityCheck(report, sectionKey, checkKey)
if !ok {
return 0, fmt.Errorf("integrity check %s/%s not found", sectionKey, checkKey)
}
return check.Count, nil
}
func findIntegrityCheck(report domain.DataIntegrityReport, sectionKey, checkKey string) (domain.DataIntegrityCheck, bool) {
for _, section := range report.Sections {
if section.Key != sectionKey {
continue
}
for _, check := range section.Checks {
if check.Key == checkKey {
return check, true
}
}
}
return domain.DataIntegrityCheck{}, false
}

View File

@@ -0,0 +1,10 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
)
type FederationRepository interface {
FindProviderByID(ctx context.Context, providerID string) (*domain.IdentityProviderConfig, error)
}

View File

@@ -0,0 +1,24 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"gorm.io/gorm"
)
type GormFederationRepository struct {
db *gorm.DB
}
func NewGormFederationRepository(db *gorm.DB) *GormFederationRepository {
return &GormFederationRepository{db: db}
}
func (r *GormFederationRepository) FindProviderByID(ctx context.Context, providerID string) (*domain.IdentityProviderConfig, error) {
var provider domain.IdentityProviderConfig
if err := r.db.WithContext(ctx).First(&provider, "id = ?", providerID).Error; err != nil {
return nil, err
}
return &provider, nil
}

View File

@@ -0,0 +1,88 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"time"
"gorm.io/gorm"
)
type KetoOutboxRepository interface {
Create(ctx context.Context, entry *domain.KetoOutbox) error
CreateWithTx(tx *gorm.DB, entry *domain.KetoOutbox) error
FindPending(ctx context.Context, limit int) ([]domain.KetoOutbox, error)
ListCurrentBySubject(ctx context.Context, namespace, subject string) ([]domain.KetoOutbox, error)
UpdateStatus(ctx context.Context, id string, status string, retryCount int, lastError string) error
MarkProcessed(ctx context.Context, id string) error
}
type ketoOutboxRepository struct {
db *gorm.DB
}
func NewKetoOutboxRepository(db *gorm.DB) KetoOutboxRepository {
return &ketoOutboxRepository{db: db}
}
func (r *ketoOutboxRepository) Create(ctx context.Context, entry *domain.KetoOutbox) error {
return r.db.WithContext(ctx).Create(entry).Error
}
func (r *ketoOutboxRepository) CreateWithTx(tx *gorm.DB, entry *domain.KetoOutbox) error {
return tx.Create(entry).Error
}
func (r *ketoOutboxRepository) FindPending(ctx context.Context, limit int) ([]domain.KetoOutbox, error) {
var entries []domain.KetoOutbox
err := r.db.WithContext(ctx).
Where("status = ?", domain.KetoOutboxStatusPending).
Order("created_at asc").
Limit(limit).
Find(&entries).Error
return entries, err
}
func (r *ketoOutboxRepository) ListCurrentBySubject(ctx context.Context, namespace, subject string) ([]domain.KetoOutbox, error) {
var entries []domain.KetoOutbox
if err := r.db.WithContext(ctx).
Where("namespace = ? AND subject = ? AND status <> ?", namespace, subject, domain.KetoOutboxStatusFailed).
Order("created_at desc").
Order("updated_at desc").
Find(&entries).Error; err != nil {
return nil, err
}
current := make([]domain.KetoOutbox, 0, len(entries))
seen := make(map[string]struct{}, len(entries))
for _, entry := range entries {
key := entry.Namespace + "\x00" + entry.Object + "\x00" + entry.Relation + "\x00" + entry.Subject
if _, exists := seen[key]; exists {
continue
}
seen[key] = struct{}{}
if entry.Action == domain.KetoOutboxActionCreate {
current = append(current, entry)
}
}
return current, nil
}
func (r *ketoOutboxRepository) UpdateStatus(ctx context.Context, id string, status string, retryCount int, lastError string) error {
return r.db.WithContext(ctx).Model(&domain.KetoOutbox{}).Where("id = ?", id).Updates(map[string]any{
"status": status,
"retry_count": retryCount,
"last_error": lastError,
"updated_at": time.Now(),
}).Error
}
func (r *ketoOutboxRepository) MarkProcessed(ctx context.Context, id string) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&domain.KetoOutbox{}).Where("id = ?", id).Updates(map[string]any{
"status": domain.KetoOutboxStatusProcessed,
"processed_at": &now,
"updated_at": now,
}).Error
}

View File

@@ -0,0 +1,68 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestKetoOutboxRepository_ListCurrentBySubject(t *testing.T) {
repo := NewKetoOutboxRepository(testDB)
ctx := context.Background()
require.NoError(t, testDB.Exec("DELETE FROM keto_outbox").Error)
entries := []domain.KetoOutbox{
{
Namespace: "RelyingParty",
Object: "client-1",
Relation: "admins",
Subject: "User:user-1",
Action: domain.KetoOutboxActionCreate,
Status: domain.KetoOutboxStatusProcessed,
},
{
Namespace: "RelyingParty",
Object: "client-1",
Relation: "admins",
Subject: "User:user-1",
Action: domain.KetoOutboxActionDelete,
Status: domain.KetoOutboxStatusProcessed,
},
{
Namespace: "RelyingParty",
Object: "client-2",
Relation: "config_editor",
Subject: "User:user-1",
Action: domain.KetoOutboxActionCreate,
Status: domain.KetoOutboxStatusProcessed,
},
{
Namespace: "RelyingParty",
Object: "client-3",
Relation: "audit_viewer",
Subject: "User:user-1",
Action: domain.KetoOutboxActionCreate,
Status: domain.KetoOutboxStatusFailed,
},
{
Namespace: "Tenant",
Object: "tenant-1",
Relation: "members",
Subject: "User:user-1",
Action: domain.KetoOutboxActionCreate,
Status: domain.KetoOutboxStatusProcessed,
},
}
for i := range entries {
require.NoError(t, repo.Create(ctx, &entries[i]))
}
current, err := repo.ListCurrentBySubject(ctx, "RelyingParty", "User:user-1")
require.NoError(t, err)
require.Len(t, current, 1)
require.Equal(t, "client-2", current[0].Object)
require.Equal(t, "config_editor", current[0].Relation)
}

View File

@@ -0,0 +1,74 @@
package repository
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/testsupport"
"context"
"log"
"os"
"testing"
"time"
"github.com/testcontainers/testcontainers-go"
postgres_module "github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
gorm_postgres "gorm.io/driver/postgres"
"gorm.io/gorm"
)
var testDB *gorm.DB
func TestMain(m *testing.M) {
if !testsupport.DockerAvailable() {
log.Printf("skipping repository tests: Docker provider is unavailable in this environment")
os.Exit(0)
}
ctx := context.Background()
// Start PostgreSQL container
dbName := "testdb"
dbUser := "user"
dbPassword := "password"
postgresContainer, err := postgres_module.Run(ctx,
"postgres:16-alpine",
postgres_module.WithDatabase(dbName),
postgres_module.WithUsername(dbUser),
postgres_module.WithPassword(dbPassword),
testcontainers.WithWaitStrategy(
wait.ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(30*time.Second)),
)
if err != nil {
log.Fatalf("failed to start container: %s", err)
}
defer func() {
if err := postgresContainer.Terminate(ctx); err != nil {
log.Fatalf("failed to terminate container: %s", err)
}
}()
connStr, err := postgresContainer.ConnectionString(ctx, "sslmode=disable")
if err != nil {
log.Fatalf("failed to get connection string: %s", err)
}
// Connect to test database
db, err := gorm.Open(gorm_postgres.Open(connStr), &gorm.Config{})
if err != nil {
log.Fatalf("failed to connect to database: %s", err)
}
// Auto-migrate
err = db.AutoMigrate(&domain.Tenant{}, &domain.TenantDomain{}, &domain.User{}, &domain.UserLoginID{}, &domain.UserProjectionState{}, &domain.ClientConsent{}, &domain.RPUserMetadata{}, &domain.RPUsageEvent{}, &domain.KetoOutbox{}, &domain.WorksmobileOutbox{})
if err != nil {
log.Fatalf("failed to migrate database: %s", err)
}
testDB = db
os.Exit(m.Run())
}

View File

@@ -0,0 +1,160 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"fmt"
"strings"
"github.com/ClickHouse/clickhouse-go/v2"
"github.com/ClickHouse/clickhouse-go/v2/lib/driver"
)
type OathkeeperClickHouseRepository struct {
conn driver.Conn
}
func NewOathkeeperClickHouseRepository(host string, port int, user, password, db string) (*OathkeeperClickHouseRepository, error) {
conn, err := clickhouse.Open(&clickhouse.Options{
Addr: []string{fmt.Sprintf("%s:%d", host, port)},
Auth: clickhouse.Auth{
Database: db,
Username: user,
Password: password,
},
Debug: false,
})
if err != nil {
return nil, fmt.Errorf("failed to open ory clickhouse connection: %w", err)
}
if err := conn.Ping(context.Background()); err != nil {
return nil, fmt.Errorf("failed to ping ory clickhouse: %w", err)
}
return &OathkeeperClickHouseRepository{conn: conn}, nil
}
func (r *OathkeeperClickHouseRepository) FindPageBySubject(ctx context.Context, subject string, limit int, cursor *domain.AuditCursor) ([]domain.OathkeeperAccessLog, error) {
if limit <= 0 {
limit = 50
}
query, args := buildOathkeeperQuery(subject, limit, cursor, true)
rows, err := r.conn.Query(ctx, query, args...)
if err != nil && isMissingColumnError(err, "client_id") {
query, args = buildOathkeeperQuery(subject, limit, cursor, false)
rows, err = r.conn.Query(ctx, query, args...)
}
if err != nil {
return nil, fmt.Errorf("failed to query oathkeeper logs: %w", err)
}
defer rows.Close()
withClientID := strings.Contains(query, "client_id")
var logs []domain.OathkeeperAccessLog
for rows.Next() {
var log domain.OathkeeperAccessLog
if withClientID {
if err := rows.Scan(
&log.Timestamp,
&log.RequestID,
&log.Method,
&log.Path,
&log.Status,
&log.LatencyMs,
&log.ClientID,
&log.RP,
&log.Action,
&log.Target,
&log.Subject,
&log.ClientIP,
&log.UserAgent,
&log.Decision,
&log.TraceID,
&log.SpanID,
&log.Raw,
); err != nil {
return nil, fmt.Errorf("failed to scan oathkeeper log: %w", err)
}
} else {
if err := rows.Scan(
&log.Timestamp,
&log.RequestID,
&log.Method,
&log.Path,
&log.Status,
&log.LatencyMs,
&log.RP,
&log.Action,
&log.Target,
&log.Subject,
&log.ClientIP,
&log.UserAgent,
&log.Decision,
&log.TraceID,
&log.SpanID,
&log.Raw,
); err != nil {
return nil, fmt.Errorf("failed to scan oathkeeper log: %w", err)
}
}
logs = append(logs, log)
}
return logs, nil
}
func buildOathkeeperQuery(subject string, limit int, cursor *domain.AuditCursor, withClientID bool) (string, []any) {
selectCols := "timestamp, request_id, method, path, status, latency_ms, rp, action, target, subject, client_ip, user_agent, decision, trace_id, span_id, raw"
if withClientID {
selectCols = "timestamp, request_id, method, path, status, latency_ms, client_id, rp, action, target, subject, client_ip, user_agent, decision, trace_id, span_id, raw"
}
query := fmt.Sprintf("SELECT %s FROM oathkeeper_access_logs", selectCols)
args := make([]any, 0, 5)
if subject != "" {
query += `
WHERE subject = ?
`
args = append(args, subject)
if cursor != nil {
query += `
AND ((timestamp < ?) OR (timestamp = ? AND request_id < ?))
`
args = append(args, cursor.Timestamp, cursor.Timestamp, cursor.EventID)
}
} else if cursor != nil {
query += `
WHERE (timestamp < ?) OR (timestamp = ? AND request_id < ?)
`
args = append(args, cursor.Timestamp, cursor.Timestamp, cursor.EventID)
}
query += `
ORDER BY timestamp DESC, request_id DESC
LIMIT ?
`
args = append(args, limit)
return query, args
}
func isMissingColumnError(err error, column string) bool {
if err == nil {
return false
}
msg := strings.ToLower(err.Error())
column = strings.ToLower(column)
if strings.Contains(msg, "unknown identifier") && strings.Contains(msg, column) {
return true
}
if strings.Contains(msg, "unknown expression identifier") && strings.Contains(msg, column) {
return true
}
if strings.Contains(msg, "missing columns") && strings.Contains(msg, column) {
return true
}
return false
}
func (r *OathkeeperClickHouseRepository) Ping(ctx context.Context) error {
if r == nil || r.conn == nil {
return fmt.Errorf("ory clickhouse connection is nil")
}
return r.conn.Ping(ctx)
}

View File

@@ -0,0 +1,61 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"gorm.io/gorm"
)
type RelyingPartyRepository interface {
Create(ctx context.Context, rp *domain.RelyingParty) error
Update(ctx context.Context, rp *domain.RelyingParty) error
Delete(ctx context.Context, clientID string) error
FindByID(ctx context.Context, clientID string) (*domain.RelyingParty, error)
ListByTenantID(ctx context.Context, tenantID string) ([]domain.RelyingParty, error)
ListAll(ctx context.Context) ([]domain.RelyingParty, error)
}
func (r *relyingPartyRepository) ListAll(ctx context.Context) ([]domain.RelyingParty, error) {
var rps []domain.RelyingParty
if err := r.db.WithContext(ctx).Find(&rps).Error; err != nil {
return nil, err
}
return rps, nil
}
type relyingPartyRepository struct {
db *gorm.DB
}
func NewRelyingPartyRepository(db *gorm.DB) RelyingPartyRepository {
return &relyingPartyRepository{db: db}
}
func (r *relyingPartyRepository) Create(ctx context.Context, rp *domain.RelyingParty) error {
return r.db.WithContext(ctx).Create(rp).Error
}
func (r *relyingPartyRepository) Update(ctx context.Context, rp *domain.RelyingParty) error {
return r.db.WithContext(ctx).Save(rp).Error
}
func (r *relyingPartyRepository) Delete(ctx context.Context, clientID string) error {
return r.db.WithContext(ctx).Delete(&domain.RelyingParty{}, "client_id = ?", clientID).Error
}
func (r *relyingPartyRepository) FindByID(ctx context.Context, clientID string) (*domain.RelyingParty, error) {
var rp domain.RelyingParty
if err := r.db.WithContext(ctx).First(&rp, "client_id = ?", clientID).Error; err != nil {
return nil, err
}
return &rp, nil
}
func (r *relyingPartyRepository) ListByTenantID(ctx context.Context, tenantID string) ([]domain.RelyingParty, error) {
var rps []domain.RelyingParty
if err := r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Find(&rps).Error; err != nil {
return nil, err
}
return rps, nil
}

View File

@@ -0,0 +1,91 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type RPUsageOutboxRepository interface {
Create(ctx context.Context, event *domain.RPUsageEvent) error
ListReady(ctx context.Context, limit int) ([]domain.RPUsageEvent, error)
MarkProcessing(ctx context.Context, id string) error
MarkProcessed(ctx context.Context, id string) error
MarkFailed(ctx context.Context, id string, message string, nextAttemptAt time.Time) error
}
type rpUsageOutboxRepository struct {
db *gorm.DB
}
func NewRPUsageOutboxRepository(db *gorm.DB) RPUsageOutboxRepository {
return &rpUsageOutboxRepository{db: db}
}
func (r *rpUsageOutboxRepository) Create(ctx context.Context, event *domain.RPUsageEvent) error {
if event.Payload == nil {
event.Payload = domain.JSONMap{}
}
if event.Status == "" {
event.Status = domain.RPUsageOutboxStatusPending
}
if event.OccurredAt.IsZero() {
event.OccurredAt = time.Now()
}
return r.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "dedupe_key"}},
DoNothing: true,
}).Create(event).Error
}
func (r *rpUsageOutboxRepository) ListReady(ctx context.Context, limit int) ([]domain.RPUsageEvent, error) {
if limit <= 0 || limit > 100 {
limit = 20
}
var rows []domain.RPUsageEvent
err := r.db.WithContext(ctx).
Where("status = ? AND (next_attempt_at IS NULL OR next_attempt_at <= ?)", domain.RPUsageOutboxStatusPending, time.Now()).
Order("occurred_at asc, created_at asc").
Limit(limit).
Find(&rows).Error
return rows, err
}
func (r *rpUsageOutboxRepository) MarkProcessing(ctx context.Context, id string) error {
return r.db.WithContext(ctx).
Model(&domain.RPUsageEvent{}).
Where("id = ? AND status = ?", id, domain.RPUsageOutboxStatusPending).
Updates(map[string]any{
"status": domain.RPUsageOutboxStatusProcessing,
"updated_at": time.Now(),
}).Error
}
func (r *rpUsageOutboxRepository) MarkProcessed(ctx context.Context, id string) error {
now := time.Now()
return r.db.WithContext(ctx).
Model(&domain.RPUsageEvent{}).
Where("id = ?", id).
Updates(map[string]any{
"status": domain.RPUsageOutboxStatusProcessed,
"last_error": "",
"processed_at": &now,
"updated_at": now,
}).Error
}
func (r *rpUsageOutboxRepository) MarkFailed(ctx context.Context, id string, message string, nextAttemptAt time.Time) error {
return r.db.WithContext(ctx).
Model(&domain.RPUsageEvent{}).
Where("id = ?", id).
Updates(map[string]any{
"status": domain.RPUsageOutboxStatusFailed,
"retry_count": gorm.Expr("retry_count + 1"),
"last_error": message,
"next_attempt_at": &nextAttemptAt,
"updated_at": time.Now(),
}).Error
}

View File

@@ -0,0 +1,40 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type RPUserMetadataRepository interface {
Get(ctx context.Context, clientID, userID string) (*domain.RPUserMetadata, error)
Upsert(ctx context.Context, metadata *domain.RPUserMetadata) error
}
type rpUserMetadataRepository struct {
db *gorm.DB
}
func NewRPUserMetadataRepository(db *gorm.DB) RPUserMetadataRepository {
return &rpUserMetadataRepository{db: db}
}
func (r *rpUserMetadataRepository) Get(ctx context.Context, clientID, userID string) (*domain.RPUserMetadata, error) {
var metadata domain.RPUserMetadata
if err := r.db.WithContext(ctx).First(&metadata, "client_id = ? AND user_id = ?", clientID, userID).Error; err != nil {
return nil, err
}
return &metadata, nil
}
func (r *rpUserMetadataRepository) Upsert(ctx context.Context, metadata *domain.RPUserMetadata) error {
return r.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{
{Name: "client_id"},
{Name: "user_id"},
},
DoUpdates: clause.AssignmentColumns([]string{"metadata", "updated_at"}),
}).Create(metadata).Error
}

View File

@@ -0,0 +1,51 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"gorm.io/gorm"
)
type SharedLinkRepository interface {
Create(ctx context.Context, link *domain.SharedLink) error
FindByToken(ctx context.Context, token string) (*domain.SharedLink, error)
FindByTenantID(ctx context.Context, tenantID string) ([]domain.SharedLink, error)
Delete(ctx context.Context, id string) error
Update(ctx context.Context, link *domain.SharedLink) error
}
type sharedLinkRepository struct {
db *gorm.DB
}
func NewSharedLinkRepository(db *gorm.DB) SharedLinkRepository {
return &sharedLinkRepository{db: db}
}
func (r *sharedLinkRepository) Create(ctx context.Context, link *domain.SharedLink) error {
return r.db.WithContext(ctx).Create(link).Error
}
func (r *sharedLinkRepository) FindByToken(ctx context.Context, token string) (*domain.SharedLink, error) {
var link domain.SharedLink
err := r.db.WithContext(ctx).Where("token = ? AND is_active = ?", token, true).First(&link).Error
if err != nil {
return nil, err
}
return &link, nil
}
func (r *sharedLinkRepository) FindByTenantID(ctx context.Context, tenantID string) ([]domain.SharedLink, error) {
var links []domain.SharedLink
err := r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Find(&links).Error
return links, err
}
func (r *sharedLinkRepository) Delete(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Delete(&domain.SharedLink{}, "id = ?", id).Error
}
func (r *sharedLinkRepository) Update(ctx context.Context, link *domain.SharedLink) error {
return r.db.WithContext(ctx).Save(link).Error
}

View File

@@ -0,0 +1,185 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"errors"
"strconv"
"strings"
"time"
"gorm.io/gorm"
)
type TenantRepository interface {
Create(ctx context.Context, tenant *domain.Tenant) error
Update(ctx context.Context, tenant *domain.Tenant) error
FindByID(ctx context.Context, id string) (*domain.Tenant, error)
FindBySlug(ctx context.Context, slug string) (*domain.Tenant, error)
FindByName(ctx context.Context, name string) (*domain.Tenant, error)
FindByDomain(ctx context.Context, domainName string) (*domain.Tenant, error)
FindByIDs(ctx context.Context, ids []string) ([]domain.Tenant, error)
AddDomain(ctx context.Context, tenantID string, domainName string, verified bool) error
List(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error)
ListByType(ctx context.Context, tenantType string) ([]domain.Tenant, error)
DeleteBulk(ctx context.Context, ids []string) error
}
type tenantRepository struct {
db *gorm.DB
}
func NewTenantRepository(db *gorm.DB) TenantRepository {
return &tenantRepository{db: db}
}
func (r *tenantRepository) Create(ctx context.Context, tenant *domain.Tenant) error {
tenant.Slug = strings.ToLower(strings.TrimSpace(tenant.Slug))
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if tenant.Slug != "" {
suffix := "-deleted-" + strconv.FormatInt(time.Now().UTC().UnixNano(), 10)
if err := tx.Unscoped().
Model(&domain.Tenant{}).
Where("slug = ? AND deleted_at IS NOT NULL", tenant.Slug).
Update("slug", gorm.Expr("slug || ?", suffix)).Error; err != nil {
return err
}
}
return tx.Create(tenant).Error
})
}
func (r *tenantRepository) Update(ctx context.Context, tenant *domain.Tenant) error {
return r.db.WithContext(ctx).Save(tenant).Error
}
func (r *tenantRepository) FindByID(ctx context.Context, id string) (*domain.Tenant, error) {
var tenant domain.Tenant
if err := r.db.WithContext(ctx).Preload("Domains").First(&tenant, "id = ?", id).Error; err != nil {
return nil, err
}
return &tenant, nil
}
func (r *tenantRepository) FindBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
var tenant domain.Tenant
if err := r.db.WithContext(ctx).Preload("Domains").Where("slug = ?", strings.ToLower(slug)).First(&tenant).Error; err != nil {
return nil, err
}
return &tenant, nil
}
func (r *tenantRepository) FindByName(ctx context.Context, name string) (*domain.Tenant, error) {
var tenant domain.Tenant
if err := r.db.WithContext(ctx).Preload("Domains").Where("name = ?", name).First(&tenant).Error; err != nil {
return nil, err
}
return &tenant, nil
}
func (r *tenantRepository) FindByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
var tenantDomain domain.TenantDomain
if err := r.db.WithContext(ctx).Where("domain = ?", domainName).First(&tenantDomain).Error; err != nil {
return nil, err
}
var tenant domain.Tenant
if err := r.db.WithContext(ctx).Preload("Domains").First(&tenant, "id = ?", tenantDomain.TenantID).Error; err != nil {
return nil, err
}
return &tenant, nil
}
func (r *tenantRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.Tenant, error) {
var tenants []domain.Tenant
if len(ids) == 0 {
return tenants, nil
}
if err := r.db.WithContext(ctx).Preload("Domains").Where("id IN ?", ids).Find(&tenants).Error; err != nil {
return nil, err
}
return tenants, nil
}
func (r *tenantRepository) AddDomain(ctx context.Context, tenantID string, domainName string, verified bool) error {
var existing domain.TenantDomain
err := r.db.WithContext(ctx).Unscoped().
Where("tenant_id = ? AND domain = ?", tenantID, domainName).
First(&existing).Error
if err == nil {
return r.db.WithContext(ctx).Unscoped().Model(&existing).Updates(map[string]any{
"verified": verified,
"deleted_at": nil,
}).Error
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
td := domain.TenantDomain{
TenantID: tenantID,
Domain: domainName,
Verified: verified,
}
return r.db.WithContext(ctx).Create(&td).Error
}
func (r *tenantRepository) List(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error) {
var tenants []domain.Tenant
var total int64
db := r.db.WithContext(ctx).Model(&domain.Tenant{})
if parentID != "" {
db = db.Where("parent_id = ?", parentID)
}
if search != "" {
searchTerm := "%" + strings.ToLower(search) + "%"
db = db.Where("LOWER(name) LIKE ? OR LOWER(slug) LIKE ? OR LOWER(description) LIKE ?", searchTerm, searchTerm, searchTerm)
}
if err := db.Count(&total).Error; err != nil {
return nil, 0, err
}
if err := db.Order("created_at desc, id desc").Limit(limit).Offset(offset).Preload("Domains").Find(&tenants).Error; err != nil {
return nil, 0, err
}
return tenants, total, nil
}
func (r *tenantRepository) ListByType(ctx context.Context, tenantType string) ([]domain.Tenant, error) {
var tenants []domain.Tenant
if err := r.db.WithContext(ctx).Where("type = ?", tenantType).Preload("Domains").Find(&tenants).Error; err != nil {
return nil, err
}
return tenants, nil
}
func (r *tenantRepository) DeleteBulk(ctx context.Context, ids []string) error {
if len(ids) == 0 {
return nil
}
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. Release slugs for all target tenants to allow reuse
suffix := "-deleted-" + time.Now().Format("20060102150405")
if err := tx.Model(&domain.Tenant{}).Where("id IN ?", ids).
Update("slug", gorm.Expr("slug || ?", suffix)).Error; err != nil {
return err
}
// 2. Soft delete tenants
if err := tx.Where("id IN ?", ids).Delete(&domain.Tenant{}).Error; err != nil {
return err
}
// 3. Also delete related UserGroups if any (Type USER_GROUP tenants have records in user_groups table)
if err := tx.Where("id IN ?", ids).Delete(&domain.UserGroup{}).Error; err != nil {
return err
}
return nil
})
}

View File

@@ -0,0 +1,192 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestTenantRepository(t *testing.T) {
repo := NewTenantRepository(testDB)
ctx := context.Background()
t.Run("Create and FindByID", func(t *testing.T) {
tenant := &domain.Tenant{
Name: "Test Tenant",
Slug: "test-tenant",
Type: domain.TenantTypeCompany,
}
err := repo.Create(ctx, tenant)
assert.NoError(t, err)
assert.NotEmpty(t, tenant.ID)
found, err := repo.FindByID(ctx, tenant.ID)
assert.NoError(t, err)
assert.Equal(t, tenant.Name, found.Name)
assert.Equal(t, tenant.Slug, found.Slug)
})
t.Run("FindBySlug", func(t *testing.T) {
tenant := &domain.Tenant{
Name: "Slug Test",
Slug: "slug-test",
Type: domain.TenantTypeCompany,
}
_ = repo.Create(ctx, tenant)
found, err := repo.FindBySlug(ctx, "slug-test")
assert.NoError(t, err)
assert.Equal(t, tenant.ID, found.ID)
})
t.Run("AddDomain and FindByDomain", func(t *testing.T) {
tenant := &domain.Tenant{
Name: "Domain Test",
Slug: "domain-test",
Type: domain.TenantTypeCompany,
}
_ = repo.Create(ctx, tenant)
err := repo.AddDomain(ctx, tenant.ID, "test-domain.com", true)
assert.NoError(t, err)
found, err := repo.FindByDomain(ctx, "test-domain.com")
assert.NoError(t, err)
assert.Equal(t, tenant.ID, found.ID)
assert.Len(t, found.Domains, 1)
assert.Equal(t, "test-domain.com", found.Domains[0].Domain)
})
t.Run("AddDomain allows same domain on multiple tenants", func(t *testing.T) {
first := &domain.Tenant{
Name: "Saman Existing",
Slug: "saman-existing",
Type: domain.TenantTypeCompany,
}
second := &domain.Tenant{
Name: "Saman Current",
Slug: "saman-current",
Type: domain.TenantTypeCompany,
}
assert.NoError(t, repo.Create(ctx, first))
assert.NoError(t, repo.Create(ctx, second))
assert.NoError(t, repo.AddDomain(ctx, first.ID, "samaneng.com", true))
assert.NoError(t, repo.AddDomain(ctx, second.ID, "samaneng.com", true))
var rows []domain.TenantDomain
err := testDB.Where("domain = ?", "samaneng.com").Find(&rows).Error
assert.NoError(t, err)
assert.Len(t, rows, 2)
})
t.Run("AddDomain restores deleted tenant domain", func(t *testing.T) {
tenant := &domain.Tenant{
Name: "Domain Restore",
Slug: "domain-restore",
Type: domain.TenantTypeCompany,
}
assert.NoError(t, repo.Create(ctx, tenant))
assert.NoError(t, repo.AddDomain(ctx, tenant.ID, "restore.samaneng.com", true))
assert.NoError(t, testDB.Where("tenant_id = ? AND domain = ?", tenant.ID, "restore.samaneng.com").Delete(&domain.TenantDomain{}).Error)
assert.NoError(t, repo.AddDomain(ctx, tenant.ID, "restore.samaneng.com", true))
var rows []domain.TenantDomain
err := testDB.Where("tenant_id = ? AND domain = ?", tenant.ID, "restore.samaneng.com").Find(&rows).Error
assert.NoError(t, err)
if assert.Len(t, rows, 1) {
assert.True(t, rows[0].Verified)
}
})
t.Run("Update", func(t *testing.T) {
tenant := &domain.Tenant{
Name: "Before Update",
Slug: "before-update",
Type: domain.TenantTypeCompany,
}
_ = repo.Create(ctx, tenant)
tenant.Name = "After Update"
err := repo.Update(ctx, tenant)
assert.NoError(t, err)
found, err := repo.FindByID(ctx, tenant.ID)
assert.NoError(t, err)
assert.Equal(t, "After Update", found.Name)
})
t.Run("Hierarchy", func(t *testing.T) {
parent := &domain.Tenant{
Name: "Parent Tenant",
Slug: "parent-hierarchy",
Type: domain.TenantTypeCompanyGroup,
}
err := repo.Create(ctx, parent)
assert.NoError(t, err)
child := &domain.Tenant{
Name: "Child Tenant",
Slug: "child-hierarchy",
Type: domain.TenantTypeCompany,
ParentID: &parent.ID,
}
err = repo.Create(ctx, child)
assert.NoError(t, err)
foundChild, err := repo.FindByID(ctx, child.ID)
assert.NoError(t, err)
assert.Equal(t, parent.ID, *foundChild.ParentID)
})
t.Run("Unique Constraint on Slug", func(t *testing.T) {
slug := "unique-slug-test"
tenant1 := &domain.Tenant{
Name: "First",
Slug: slug,
Type: domain.TenantTypeCompany,
}
err := repo.Create(ctx, tenant1)
assert.NoError(t, err)
tenant2 := &domain.Tenant{
Name: "Second",
Slug: slug,
Type: domain.TenantTypeCompany,
}
err = repo.Create(ctx, tenant2)
assert.Error(t, err) // Should fail due to UNIQUE constraint
})
t.Run("Create reuses slug held by legacy soft-deleted tenant", func(t *testing.T) {
slug := "legacy-soft-delete-reuse"
require.NoError(t, testDB.Unscoped().Where("slug = ?", slug).Delete(&domain.Tenant{}).Error)
legacy := &domain.Tenant{
Name: "Legacy Deleted",
Slug: slug,
Type: domain.TenantTypeCompany,
}
require.NoError(t, repo.Create(ctx, legacy))
require.NoError(t, testDB.Delete(&domain.Tenant{}, "id = ?", legacy.ID).Error)
_, err := repo.FindBySlug(ctx, slug)
require.Error(t, err)
replacement := &domain.Tenant{
Name: "Replacement",
Slug: slug,
Type: domain.TenantTypeCompany,
}
require.NoError(t, repo.Create(ctx, replacement))
found, err := repo.FindBySlug(ctx, slug)
require.NoError(t, err)
assert.Equal(t, replacement.ID, found.ID)
})
}

View File

@@ -0,0 +1,53 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"gorm.io/gorm"
)
type UserGroupRepository interface {
Create(ctx context.Context, group *domain.UserGroup) error
Update(ctx context.Context, group *domain.UserGroup) error
Delete(ctx context.Context, id string) error
FindByID(ctx context.Context, id string) (*domain.UserGroup, error)
ListByTenantID(ctx context.Context, tenantID string) ([]domain.UserGroup, error)
}
type userGroupRepository struct {
db *gorm.DB
}
func NewUserGroupRepository(db *gorm.DB) UserGroupRepository {
return &userGroupRepository{db: db}
}
func (r *userGroupRepository) Create(ctx context.Context, group *domain.UserGroup) error {
return r.db.WithContext(ctx).Create(group).Error
}
func (r *userGroupRepository) Update(ctx context.Context, group *domain.UserGroup) error {
return r.db.WithContext(ctx).Save(group).Error
}
func (r *userGroupRepository) Delete(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Delete(&domain.UserGroup{}, "id = ?", id).Error
}
func (r *userGroupRepository) FindByID(ctx context.Context, id string) (*domain.UserGroup, error) {
var group domain.UserGroup
// Using Where to be more explicit and avoid issues with GORM's default primary key handling if ID is string/uuid
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&group).Error; err != nil {
return nil, err
}
return &group, nil
}
func (r *userGroupRepository) ListByTenantID(ctx context.Context, tenantID string) ([]domain.UserGroup, error) {
var groups []domain.UserGroup
if err := r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Find(&groups).Error; err != nil {
return nil, err
}
return groups, nil
}

View File

@@ -0,0 +1,51 @@
package repository
import (
"context"
"gorm.io/gorm"
)
func CountOrphanUserTenantMemberships(ctx context.Context, db *gorm.DB) (int64, error) {
var count int64
err := db.WithContext(ctx).Raw(`
SELECT COUNT(*)
FROM users AS u
WHERE u.deleted_at IS NULL
AND (
u.tenant_id IS NOT NULL
AND NOT EXISTS (
SELECT 1
FROM tenants AS t
WHERE t.id = u.tenant_id
AND t.deleted_at IS NULL
)
)
`).Scan(&count).Error
return count, err
}
func ClearOrphanUserTenantMemberships(ctx context.Context, db *gorm.DB) (int64, error) {
result := db.WithContext(ctx).Exec(`
WITH orphan_users AS (
SELECT u.id
FROM users AS u
WHERE u.deleted_at IS NULL
AND (
u.tenant_id IS NOT NULL
AND NOT EXISTS (
SELECT 1
FROM tenants AS t
WHERE t.id = u.tenant_id
AND t.deleted_at IS NULL
)
)
)
UPDATE users AS u
SET tenant_id = NULL,
updated_at = NOW()
FROM orphan_users AS ou
WHERE u.id = ou.id
`)
return result.RowsAffected, result.Error
}

View File

@@ -0,0 +1,64 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestClearOrphanUserTenantMemberships(t *testing.T) {
ctx := context.Background()
repo := NewUserRepository(testDB)
tenantRepo := NewTenantRepository(testDB)
require.NoError(t, testDB.Exec("DELETE FROM user_login_ids").Error)
require.NoError(t, testDB.Exec("DELETE FROM users").Error)
require.NoError(t, testDB.Exec("DELETE FROM tenant_domains").Error)
require.NoError(t, testDB.Unscoped().Where("slug IN ?", []string{"orphan-active", "orphan-deleted"}).Delete(&domain.Tenant{}).Error)
activeTenant := &domain.Tenant{Name: "Active Tenant", Slug: "orphan-active", Type: domain.TenantTypeCompany}
deletedTenant := &domain.Tenant{Name: "Deleted Tenant", Slug: "orphan-deleted", Type: domain.TenantTypeCompany}
require.NoError(t, tenantRepo.Create(ctx, activeTenant))
require.NoError(t, tenantRepo.Create(ctx, deletedTenant))
require.NoError(t, testDB.Delete(&domain.Tenant{}, "id = ?", deletedTenant.ID).Error)
activeUser := &domain.User{
Email: "active-membership@example.com",
Name: "Active Membership",
Role: "user",
TenantID: &activeTenant.ID,
}
orphanUser := &domain.User{
Email: "orphan-membership@example.com",
Name: "Orphan Membership",
Role: "user",
TenantID: &deletedTenant.ID,
}
require.NoError(t, repo.Create(ctx, activeUser))
require.NoError(t, repo.Create(ctx, orphanUser))
count, err := CountOrphanUserTenantMemberships(ctx, testDB)
require.NoError(t, err)
assert.Equal(t, int64(1), count)
affected, err := ClearOrphanUserTenantMemberships(ctx, testDB)
require.NoError(t, err)
assert.Equal(t, int64(1), affected)
foundActive, err := repo.FindByEmail(ctx, activeUser.Email)
require.NoError(t, err)
require.NotNil(t, foundActive.TenantID)
require.NotNil(t, foundActive.Tenant)
assert.Equal(t, activeTenant.ID, *foundActive.TenantID)
foundOrphan, err := repo.FindByEmail(ctx, orphanUser.Email)
require.NoError(t, err)
assert.Nil(t, foundOrphan.TenantID)
count, err = CountOrphanUserTenantMemberships(ctx, testDB)
require.NoError(t, err)
assert.Equal(t, int64(0), count)
}

View File

@@ -0,0 +1,227 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"errors"
"fmt"
"strings"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type UserProjectionRepository interface {
IsReady(ctx context.Context) (bool, error)
GetStatus(ctx context.Context) (domain.UserProjectionStatus, error)
CountTenantMembers(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error)
CountTenantMembersRecursive(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error)
ReplaceAllFromKratos(ctx context.Context, users []domain.User) error
MarkFailed(ctx context.Context, syncErr error) error
}
type userProjectionRepository struct {
db *gorm.DB
}
func NewUserProjectionRepository(db *gorm.DB) UserProjectionRepository {
return &userProjectionRepository{db: db}
}
func (r *userProjectionRepository) IsReady(ctx context.Context) (bool, error) {
status, err := r.GetStatus(ctx)
if err != nil {
return false, err
}
return status.Ready, nil
}
func (r *userProjectionRepository) GetStatus(ctx context.Context) (domain.UserProjectionStatus, error) {
var projectedUsers int64
if err := r.db.WithContext(ctx).Model(&domain.User{}).Count(&projectedUsers).Error; err != nil {
return domain.UserProjectionStatus{}, err
}
var state domain.UserProjectionState
err := r.db.WithContext(ctx).First(&state, "name = ?", domain.UserProjectionNameKratos).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return domain.UserProjectionStatus{
Name: domain.UserProjectionNameKratos,
Status: domain.UserProjectionStatusFailed,
Ready: false,
ProjectedUsers: projectedUsers,
}, nil
}
if err != nil {
return domain.UserProjectionStatus{}, err
}
return domain.UserProjectionStatus{
Name: state.Name,
Status: state.Status,
Ready: state.Status == domain.UserProjectionStatusReady && state.LastSyncedAt != nil,
LastSyncedAt: state.LastSyncedAt,
LastError: state.LastError,
UpdatedAt: &state.UpdatedAt,
ProjectedUsers: projectedUsers,
}, nil
}
func (r *userProjectionRepository) CountTenantMembers(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
counts := make(map[string]int64, len(tenants))
for _, tenant := range tenants {
counts[tenant.ID] = 0
}
if len(tenants) == 0 {
return counts, nil
}
valuePlaceholders := make([]string, 0, len(tenants))
args := make([]any, 0, len(tenants)*2)
for _, tenant := range tenants {
valuePlaceholders = append(valuePlaceholders, "(?, ?)")
args = append(args, strings.TrimSpace(tenant.ID), strings.TrimSpace(tenant.Slug))
}
query := fmt.Sprintf(`
WITH requested(tenant_id, slug) AS (
VALUES %s
)
SELECT requested.tenant_id, COUNT(DISTINCT users.id) AS count
FROM requested
LEFT JOIN users ON users.deleted_at IS NULL AND (
users.tenant_id::text = requested.tenant_id
)
GROUP BY requested.tenant_id
`, strings.Join(valuePlaceholders, ","))
type result struct {
TenantID string
Count int64
}
var rows []result
if err := r.db.WithContext(ctx).Raw(query, args...).Scan(&rows).Error; err != nil {
return nil, err
}
for _, row := range rows {
counts[row.TenantID] = row.Count
}
return counts, nil
}
func (r *userProjectionRepository) CountTenantMembersRecursive(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
counts := make(map[string]int64, len(tenants))
for _, tenant := range tenants {
counts[tenant.ID] = 0
}
if len(tenants) == 0 {
return counts, nil
}
valuePlaceholders := make([]string, 0, len(tenants))
args := make([]any, 0, len(tenants))
for _, tenant := range tenants {
valuePlaceholders = append(valuePlaceholders, "(?)")
args = append(args, strings.TrimSpace(tenant.ID))
}
query := fmt.Sprintf(`
WITH RECURSIVE requested(tenant_id) AS (
VALUES %s
),
descendants(root_tenant_id, tenant_id) AS (
SELECT requested.tenant_id, requested.tenant_id
FROM requested
UNION ALL
SELECT descendants.root_tenant_id, child.id::text
FROM descendants
JOIN tenants child
ON child.parent_id::text = descendants.tenant_id
AND child.deleted_at IS NULL
)
SELECT requested.tenant_id, COUNT(DISTINCT users.id) AS count
FROM requested
LEFT JOIN descendants
ON descendants.root_tenant_id = requested.tenant_id
LEFT JOIN users
ON users.deleted_at IS NULL
AND users.tenant_id::text = descendants.tenant_id
GROUP BY requested.tenant_id
`, strings.Join(valuePlaceholders, ","))
type result struct {
TenantID string
Count int64
}
var rows []result
if err := r.db.WithContext(ctx).Raw(query, args...).Scan(&rows).Error; err != nil {
return nil, err
}
for _, row := range rows {
counts[row.TenantID] = row.Count
}
return counts, nil
}
func (r *userProjectionRepository) ReplaceAllFromKratos(ctx context.Context, users []domain.User) error {
now := time.Now()
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for i := range users {
users[i].DeletedAt = gorm.DeletedAt{}
if users[i].CreatedAt.IsZero() {
users[i].CreatedAt = now
}
if users[i].UpdatedAt.IsZero() {
users[i].UpdatedAt = now
}
}
if len(users) > 0 {
// [FIX] Handle email conflicts before bulk upsert
for _, u := range users {
if u.Email != "" {
// Hard-delete any record with same email but different ID to clear unique constraint
_ = tx.Unscoped().Where("email = ? AND id != ?", u.Email, u.ID).Delete(&domain.User{}).Error
}
}
if err := tx.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "id"}},
UpdateAll: true,
}).Create(&users).Error; err != nil {
return err
}
}
return upsertUserProjectionState(tx, domain.UserProjectionStatusReady, &now, "")
})
}
func (r *userProjectionRepository) MarkFailed(ctx context.Context, syncErr error) error {
message := ""
if syncErr != nil {
message = syncErr.Error()
}
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
return upsertUserProjectionState(tx, domain.UserProjectionStatusFailed, nil, message)
})
}
func upsertUserProjectionState(tx *gorm.DB, status string, syncedAt *time.Time, lastError string) error {
state := domain.UserProjectionState{
Name: domain.UserProjectionNameKratos,
Status: status,
LastSyncedAt: syncedAt,
LastError: lastError,
UpdatedAt: time.Now(),
}
return tx.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "name"}},
DoUpdates: clause.AssignmentColumns([]string{
"status",
"last_synced_at",
"last_error",
"updated_at",
}),
}).Create(&state).Error
}

View File

@@ -0,0 +1,168 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestUserProjectionRepository_ReplaceAllFromKratosMarksReadyWithoutDeletingUsersMissingFromPartialList(t *testing.T) {
ctx := context.Background()
repo := NewUserProjectionRepository(testDB)
require.NoError(t, testDB.Exec("DELETE FROM user_projection_states").Error)
require.NoError(t, testDB.Exec("DELETE FROM user_login_ids").Error)
require.NoError(t, testDB.Exec("DELETE FROM users").Error)
tenantID := "10000000-0000-0000-0000-000000000001"
tenantSlug := "projection-saman"
require.NoError(t, testDB.Create(&domain.Tenant{
ID: tenantID,
Name: "Projection Saman",
Slug: tenantSlug,
Type: domain.TenantTypeCompany,
Status: domain.TenantStatusActive,
}).Error)
existing := &domain.User{
ID: "00000000-0000-0000-0000-000000000099",
Email: "existing@example.com",
Name: "Existing",
CompanyCode: tenantSlug,
TenantID: &tenantID,
}
require.NoError(t, NewUserRepository(testDB).Create(ctx, existing))
users := []domain.User{
{
ID: "00000000-0000-0000-0000-000000000101",
Email: "one@example.com",
Name: "One",
CompanyCode: tenantSlug,
TenantID: &tenantID,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
},
{
ID: "00000000-0000-0000-0000-000000000102",
Email: "two@example.com",
Name: "Two",
TenantID: &tenantID,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
},
}
require.NoError(t, repo.ReplaceAllFromKratos(ctx, users))
ready, err := repo.IsReady(ctx)
require.NoError(t, err)
assert.True(t, ready)
counts, err := repo.CountTenantMembers(ctx, []domain.Tenant{
{ID: tenantID, Slug: tenantSlug},
})
require.NoError(t, err)
assert.Equal(t, int64(3), counts[tenantID])
var activeCount int64
require.NoError(t, testDB.Model(&domain.User{}).Count(&activeCount).Error)
assert.Equal(t, int64(3), activeCount)
var existingCount int64
require.NoError(t, testDB.Model(&domain.User{}).Where("id = ?", existing.ID).Count(&existingCount).Error)
assert.Equal(t, int64(1), existingCount)
var existingRow domain.User
require.NoError(t, testDB.Unscoped().First(&existingRow, "id = ?", existing.ID).Error)
assert.False(t, existingRow.DeletedAt.Valid)
}
func TestUserProjectionRepository_CountTenantMembersRecursiveIncludesDescendantsAndExcludesSoftDeletedUsers(t *testing.T) {
ctx := context.Background()
repo := NewUserProjectionRepository(testDB)
parentID := "20000000-0000-0000-0000-000000000001"
childID := "20000000-0000-0000-0000-000000000002"
grandchildID := "20000000-0000-0000-0000-000000000003"
siblingID := "20000000-0000-0000-0000-000000000004"
tenantIDs := []string{parentID, childID, grandchildID, siblingID}
require.NoError(t, testDB.Exec("DELETE FROM user_login_ids").Error)
require.NoError(t, testDB.Exec("DELETE FROM users").Error)
require.NoError(t, testDB.Unscoped().Where("id IN ?", tenantIDs).Delete(&domain.Tenant{}).Error)
require.NoError(t, testDB.Create(&domain.Tenant{
ID: parentID,
Name: "Recursive Parent",
Slug: "recursive-parent",
Type: domain.TenantTypeCompany,
Status: domain.TenantStatusActive,
}).Error)
require.NoError(t, testDB.Create(&domain.Tenant{
ID: childID,
Name: "Recursive Child",
Slug: "recursive-child",
Type: domain.TenantTypeOrganization,
Status: domain.TenantStatusActive,
ParentID: &parentID,
}).Error)
require.NoError(t, testDB.Create(&domain.Tenant{
ID: grandchildID,
Name: "Recursive Grandchild",
Slug: "recursive-grandchild",
Type: domain.TenantTypeUserGroup,
Status: domain.TenantStatusActive,
ParentID: &childID,
}).Error)
require.NoError(t, testDB.Create(&domain.Tenant{
ID: siblingID,
Name: "Recursive Sibling",
Slug: "recursive-sibling",
Type: domain.TenantTypeCompany,
Status: domain.TenantStatusActive,
}).Error)
users := []domain.User{
{ID: "21000000-0000-0000-0000-000000000001", Email: "parent@example.com", Name: "Parent", TenantID: &parentID},
{ID: "21000000-0000-0000-0000-000000000002", Email: "child@example.com", Name: "Child", TenantID: &childID},
{ID: "21000000-0000-0000-0000-000000000003", Email: "grandchild@example.com", Name: "Grandchild", TenantID: &grandchildID},
{ID: "21000000-0000-0000-0000-000000000004", Email: "deleted-grandchild@example.com", Name: "Deleted Grandchild", TenantID: &grandchildID},
{ID: "21000000-0000-0000-0000-000000000005", Email: "sibling@example.com", Name: "Sibling", TenantID: &siblingID},
}
for i := range users {
require.NoError(t, testDB.Create(&users[i]).Error)
}
require.NoError(t, testDB.Delete(&domain.User{}, "id = ?", users[3].ID).Error)
directCounts, err := repo.CountTenantMembers(ctx, []domain.Tenant{{ID: parentID}, {ID: childID}, {ID: grandchildID}, {ID: siblingID}})
require.NoError(t, err)
assert.Equal(t, int64(1), directCounts[parentID])
assert.Equal(t, int64(1), directCounts[childID])
assert.Equal(t, int64(1), directCounts[grandchildID])
assert.Equal(t, int64(1), directCounts[siblingID])
recursiveCounts, err := repo.CountTenantMembersRecursive(ctx, []domain.Tenant{{ID: parentID}, {ID: childID}, {ID: grandchildID}, {ID: siblingID}})
require.NoError(t, err)
assert.Equal(t, int64(3), recursiveCounts[parentID])
assert.Equal(t, int64(2), recursiveCounts[childID])
assert.Equal(t, int64(1), recursiveCounts[grandchildID])
assert.Equal(t, int64(1), recursiveCounts[siblingID])
}
func TestUserProjectionRepository_MarkFailedMakesProjectionNotReady(t *testing.T) {
ctx := context.Background()
repo := NewUserProjectionRepository(testDB)
require.NoError(t, testDB.Exec("DELETE FROM user_projection_states").Error)
require.NoError(t, repo.MarkFailed(ctx, errors.New("kratos down")))
ready, err := repo.IsReady(ctx)
require.NoError(t, err)
assert.False(t, ready)
}

View File

@@ -0,0 +1,335 @@
package repository
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/pagination"
"context"
"encoding/json"
"fmt"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type UserRepository interface {
Create(ctx context.Context, user *domain.User) error
Update(ctx context.Context, user *domain.User) error
FindByEmail(ctx context.Context, email string) (*domain.User, error)
FindByID(ctx context.Context, id string) (*domain.User, error)
FindByIDs(ctx context.Context, ids []string) ([]domain.User, error)
ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error)
List(ctx context.Context, offset, limit int, search string, tenantIDs []string, cursor string) ([]domain.User, int64, string, error)
CountByTenant(ctx context.Context, tenantID string) (int64, error)
CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error)
CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error)
FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error)
FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error)
Delete(ctx context.Context, id string) error
DB() *gorm.DB
// Multiple identifiers support
UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error
GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error)
IsLoginIDTaken(ctx context.Context, loginID string) (bool, error)
FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error)
}
type userRepository struct {
db *gorm.DB
}
func NewUserRepository(db *gorm.DB) UserRepository {
return &userRepository{db: db}
}
func (r *userRepository) DB() *gorm.DB {
return r.db
}
func (r *userRepository) withTenantMembershipFilter(db *gorm.DB, tenantIDs []string) *gorm.DB {
if len(tenantIDs) == 0 {
return db
}
clauses := []string{"tenant_id IN ?"}
args := []any{tenantIDs}
for _, tenantID := range tenantIDs {
tenantID = strings.TrimSpace(tenantID)
if tenantID == "" {
continue
}
payload, err := json.Marshal(map[string]any{
"additionalAppointments": []map[string]string{
{"tenantId": tenantID},
},
})
if err != nil {
continue
}
clauses = append(clauses, "metadata @> ?::jsonb")
args = append(args, string(payload))
}
return db.Where("("+strings.Join(clauses, " OR ")+")", args...)
}
func (r *userRepository) Create(ctx context.Context, user *domain.User) error {
return r.db.WithContext(ctx).Create(user).Error
}
func (r *userRepository) Update(ctx context.Context, user *domain.User) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. Check for email collision (including soft-deleted)
var existing domain.User
if err := tx.Unscoped().Where("email = ?", user.Email).First(&existing).Error; err == nil {
// If email exists but ID is different, we MUST clear the old one to avoid unique constraint violation
if existing.ID != user.ID {
// [Restored] Check if the existing user is archived
if strings.EqualFold(strings.TrimSpace(existing.Status), domain.UserStatusArchived) {
return fmt.Errorf("email is reserved by archived user: %s", user.Email)
}
// HARD DELETE the old record and its associated login IDs to free up the email and identifiers
if err := tx.Unscoped().Where("user_id = ?", existing.ID).Delete(&domain.UserLoginID{}).Error; err != nil {
return err
}
if err := tx.Unscoped().Delete(&domain.User{}, "id = ?", existing.ID).Error; err != nil {
return err
}
}
}
// 2. Perform Upsert on the new/target ID
return tx.Unscoped().Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "id"}},
DoUpdates: clause.Assignments(map[string]any{
"email": user.Email,
"name": user.Name,
"phone": user.Phone,
"role": user.Role,
"status": user.Status,
"department": user.Department,
"grade": user.Grade,
"position": user.Position,
"job_title": user.JobTitle,
"metadata": user.Metadata,
"tenant_id": user.TenantID,
"affiliation_type": user.AffiliationType,
"updated_at": user.UpdatedAt,
"deleted_at": nil, // Ensure it's active
}),
}).Create(user).Error
})
}
func (r *userRepository) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
var user domain.User
if err := r.db.WithContext(ctx).Preload("Tenant").Where("email = ?", email).First(&user).Error; err != nil {
return nil, err
}
return &user, nil
}
func (r *userRepository) FindByID(ctx context.Context, id string) (*domain.User, error) {
var user domain.User
if err := r.db.WithContext(ctx).Preload("Tenant").Where("id = ?", id).First(&user).Error; err != nil {
return nil, err
}
return &user, nil
}
func (r *userRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
var users []domain.User
if len(ids) == 0 {
return users, nil
}
if err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&users).Error; err != nil {
return nil, err
}
return users, nil
}
func (r *userRepository) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
var users []domain.User
if err := r.withTenantMembershipFilter(r.db.WithContext(ctx), []string{tenantID}).Find(&users).Error; err != nil {
return nil, err
}
return users, nil
}
func (r *userRepository) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
var count int64
err := r.withTenantMembershipFilter(r.db.WithContext(ctx).Model(&domain.User{}), []string{tenantID}).Count(&count).Error
return count, err
}
func (r *userRepository) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
counts := make(map[string]int64)
if len(tenantIDs) == 0 {
return counts, nil
}
for _, tenantID := range tenantIDs {
var count int64
if err := r.withTenantMembershipFilter(r.db.WithContext(ctx).Model(&domain.User{}), []string{tenantID}).Count(&count).Error; err != nil {
return nil, err
}
counts[tenantID] = count
}
return counts, nil
}
func (r *userRepository) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
if len(codes) == 0 {
return make(map[string]int64), nil
}
type result struct {
TenantSlug string
Count int64
}
var results []result
lowerCodes := lowerStrings(codes)
if err := r.db.WithContext(ctx).Table("users").
Select("LOWER(tenants.slug) AS tenant_slug, count(DISTINCT users.id) AS count").
Joins("JOIN tenants ON users.tenant_id = tenants.id").
Where("users.deleted_at IS NULL AND LOWER(tenants.slug) IN ?", lowerCodes).
Group("LOWER(tenants.slug)").
Scan(&results).Error; err != nil {
return nil, err
}
counts := make(map[string]int64)
for _, res := range results {
counts[strings.ToLower(res.TenantSlug)] = res.Count
}
// Ensure all requested codes are present in results (even if count is 0)
for _, code := range codes {
lower := strings.ToLower(strings.TrimSpace(code))
if _, ok := counts[lower]; !ok {
counts[lower] = 0
}
}
return counts, nil
}
func lowerStrings(arr []string) []string {
res := make([]string, len(arr))
for i, s := range arr {
res[i] = strings.ToLower(strings.TrimSpace(s))
}
return res
}
func (r *userRepository) List(ctx context.Context, offset, limit int, search string, tenantIDs []string, cursorRaw string) ([]domain.User, int64, string, error) {
var users []domain.User
var total int64
db := r.db.WithContext(ctx).Model(&domain.User{})
if len(tenantIDs) > 0 {
db = r.withTenantMembershipFilter(db, tenantIDs)
}
if search != "" {
searchTerm := "%" + search + "%"
db = db.Where("(users.email LIKE ? OR users.name LIKE ? OR users.metadata::text LIKE ?)",
searchTerm, searchTerm, searchTerm)
}
if err := db.Count(&total).Error; err != nil {
return nil, 0, "", err
}
if cursorRaw != "" {
cursor, err := pagination.Decode(cursorRaw)
if err != nil {
return nil, 0, "", err
}
db = pagination.ApplyCreatedAtIDCursor(db, cursor, "created_at", "id")
} else {
db = db.Offset(offset)
}
if err := db.Order("created_at desc, id desc").Limit(limit + 1).Preload("Tenant").Find(&users).Error; err != nil {
return nil, 0, "", err
}
var items []domain.User
var nextCursor string
if len(users) > limit {
items = users[:limit]
last := items[limit-1]
nextCursor = pagination.Encode(last.CreatedAt, last.ID)
} else {
items = users
}
return items, total, nextCursor, nil
}
func (r *userRepository) Delete(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Delete(&domain.User{}, "id = ?", id).Error
}
func (r *userRepository) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// [FIX] Use Unscoped to permanently delete existing login IDs for this user
// This prevents unique constraint violations with soft-deleted records
if err := tx.Unscoped().Where("user_id = ?", userID).Delete(&domain.UserLoginID{}).Error; err != nil {
return err
}
// Insert new login IDs if any
if len(loginIDs) > 0 {
if err := tx.Create(&loginIDs).Error; err != nil {
return err
}
}
return nil
})
}
func (r *userRepository) GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error) {
var results []domain.UserLoginID
if err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&results).Error; err != nil {
return nil, err
}
return results, nil
}
func (r *userRepository) IsLoginIDTaken(ctx context.Context, loginID string) (bool, error) {
var count int64
if err := r.db.WithContext(ctx).Model(&domain.UserLoginID{}).Where("login_id = ?", loginID).Count(&count).Error; err != nil {
return false, err
}
return count > 0, nil
}
func (r *userRepository) FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error) {
var record domain.UserLoginID
if err := r.db.WithContext(ctx).Where("login_id = ?", loginID).First(&record).Error; err != nil {
return "", err
}
return record.TenantID, nil
}
func (r *userRepository) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
var users []domain.User
err := r.withTenantMembershipFilter(r.db.WithContext(ctx), tenantIDs).Find(&users).Error
return users, err
}
func (r *userRepository) FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error) {
var users []domain.User
err := r.db.WithContext(ctx).
Joins("JOIN tenants ON users.tenant_id = tenants.id").
Where("LOWER(tenants.slug) IN ?", lowerStrings(codes)).
Preload("Tenant").
Find(&users).Error
return users, err
}

View File

@@ -0,0 +1,273 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestUserRepository(t *testing.T) {
repo := NewUserRepository(testDB)
ctx := context.Background()
// Ensure User table exists and clean for tests
_ = testDB.AutoMigrate(&domain.User{})
t.Run("Create and FindByEmail", func(t *testing.T) {
user := &domain.User{
Email: "test@example.com",
Name: "Test User",
Role: "user",
}
err := repo.Create(ctx, user)
assert.NoError(t, err)
assert.NotEmpty(t, user.ID)
found, err := repo.FindByEmail(ctx, "test@example.com")
assert.NoError(t, err)
assert.Equal(t, user.ID, found.ID)
assert.Equal(t, "Test User", found.Name)
})
t.Run("Update User Info", func(t *testing.T) {
user := &domain.User{
Email: "update@example.com",
Name: "Before Update",
Role: "user",
}
_ = repo.Create(ctx, user)
user.Name = "After Update"
user.Phone = "010-1234-5678"
err := repo.Update(ctx, user)
assert.NoError(t, err)
found, err := repo.FindByEmail(ctx, "update@example.com")
assert.NoError(t, err)
assert.Equal(t, "After Update", found.Name)
assert.Equal(t, "010-1234-5678", found.Phone)
})
t.Run("Update preserves archived email reservation", func(t *testing.T) {
testDB.Exec("DELETE FROM user_login_ids")
testDB.Exec("DELETE FROM users")
archived := &domain.User{
ID: "00000000-0000-0000-0000-00000000a001",
Email: "reserved@example.com",
Name: "Archived User",
Role: domain.RoleUser,
Status: domain.UserStatusArchived,
}
replacement := &domain.User{
ID: "00000000-0000-0000-0000-00000000a002",
Email: "reserved@example.com",
Name: "Replacement User",
Role: domain.RoleUser,
Status: domain.UserStatusActive,
}
require.NoError(t, repo.Create(ctx, archived))
err := repo.Update(ctx, replacement)
require.Error(t, err)
require.Contains(t, err.Error(), "archived user")
found, err := repo.FindByEmail(ctx, archived.Email)
require.NoError(t, err)
require.Equal(t, archived.ID, found.ID)
require.Equal(t, domain.UserStatusArchived, found.Status)
})
t.Run("List Users with Search", func(t *testing.T) {
// Add some users
_ = repo.Create(ctx, &domain.User{Email: "alice@test.com", Name: "Alice", Role: "user"})
_ = repo.Create(ctx, &domain.User{Email: "bob@test.com", Name: "Bob", Role: "user"})
users, total, _, err := repo.List(ctx, 0, 10, "Alice", []string{}, "")
assert.NoError(t, err)
assert.True(t, total >= 1)
assert.Equal(t, "Alice", users[0].Name)
})
t.Run("Delete User", func(t *testing.T) {
user := &domain.User{Email: "delete@example.com", Name: "To Delete"}
_ = repo.Create(ctx, user)
err := repo.Delete(ctx, user.ID)
assert.NoError(t, err)
found, err := repo.FindByEmail(ctx, "delete@example.com")
assert.Error(t, err) // Should not be found
assert.Nil(t, found)
})
t.Run("CountByCompanyCodes", func(t *testing.T) {
// Clean start for this subtest
testDB.Exec("DELETE FROM user_login_ids")
testDB.Exec("DELETE FROM users")
testDB.Exec("DELETE FROM tenant_domains")
tenantA := createUserRepositoryTestTenant(t, "tenant-a")
tenantB := createUserRepositoryTestTenant(t, "tenant-b")
users := []domain.User{
{Email: "u1@a.com", Name: "U1", TenantID: &tenantA.ID},
{Email: "u2@a.com", Name: "U2", TenantID: &tenantA.ID},
{Email: "u3@b.com", Name: "U3", TenantID: &tenantB.ID},
{Email: "u4@none.com", Name: "U4"},
}
for _, u := range users {
_ = repo.Create(ctx, &u)
}
counts, err := repo.CountByCompanyCodes(ctx, []string{"tenant-a", "tenant-b", "tenant-c"})
assert.NoError(t, err)
assert.Equal(t, int64(2), counts["tenant-a"])
assert.Equal(t, int64(1), counts["tenant-b"])
assert.Equal(t, int64(0), counts["tenant-c"])
})
t.Run("CountByCompanyCodes excludes soft deleted cache rows", func(t *testing.T) {
testDB.Exec("DELETE FROM user_login_ids")
testDB.Exec("DELETE FROM users")
testDB.Exec("DELETE FROM tenant_domains")
tenantA := createUserRepositoryTestTenant(t, "tenant-a")
active := &domain.User{Email: "active@a.com", Name: "Active", TenantID: &tenantA.ID}
deleted := &domain.User{Email: "deleted@a.com", Name: "Deleted", TenantID: &tenantA.ID}
secondDeleted := &domain.User{Email: "second-deleted@a.com", Name: "Second Deleted", TenantID: &tenantA.ID}
assert.NoError(t, repo.Create(ctx, active))
assert.NoError(t, repo.Create(ctx, deleted))
assert.NoError(t, repo.Create(ctx, secondDeleted))
assert.NoError(t, repo.Delete(ctx, deleted.ID))
assert.NoError(t, repo.Delete(ctx, secondDeleted.ID))
counts, err := repo.CountByCompanyCodes(ctx, []string{"tenant-a"})
assert.NoError(t, err)
assert.Equal(t, int64(1), counts["tenant-a"])
})
t.Run("Multi-Identifier Support", func(t *testing.T) {
_ = testDB.AutoMigrate(&domain.UserLoginID{})
testDB.Exec("DELETE FROM user_login_ids")
testDB.Exec("DELETE FROM users")
user := &domain.User{Email: "multi@test.com", Name: "Multi"}
_ = repo.Create(ctx, user)
t1 := "00000000-0000-0000-0000-000000000001"
t2 := "00000000-0000-0000-0000-000000000002"
loginIDs := []domain.UserLoginID{
{UserID: user.ID, TenantID: t1, FieldKey: "emp_id", LoginID: "E001"},
{UserID: user.ID, TenantID: t2, FieldKey: "student_id", LoginID: "S001"},
}
err := repo.UpdateUserLoginIDs(ctx, user.ID, loginIDs)
assert.NoError(t, err)
// Get and Verify
saved, err := repo.GetUserLoginIDs(ctx, user.ID)
assert.NoError(t, err)
assert.Len(t, saved, 2)
// IsLoginIDTaken
taken, err := repo.IsLoginIDTaken(ctx, "E001")
assert.NoError(t, err)
assert.True(t, taken)
taken, err = repo.IsLoginIDTaken(ctx, "UNKNOWN")
assert.NoError(t, err)
assert.False(t, taken)
// FindTenantIDByLoginID
tid, err := repo.FindTenantIDByLoginID(ctx, "S001")
assert.NoError(t, err)
assert.Equal(t, t2, tid)
// Update (Replace)
newList := []domain.UserLoginID{
{UserID: user.ID, TenantID: t1, FieldKey: "emp_id", LoginID: "E002"},
}
err = repo.UpdateUserLoginIDs(ctx, user.ID, newList)
assert.NoError(t, err)
saved, _ = repo.GetUserLoginIDs(ctx, user.ID)
assert.Len(t, saved, 1)
assert.Equal(t, "E002", saved[0].LoginID)
})
}
func TestUserRepository_ListIncludesAdditionalTenantAppointments(t *testing.T) {
repo := NewUserRepository(testDB)
ctx := context.Background()
require.NoError(t, testDB.Exec("DELETE FROM user_login_ids").Error)
require.NoError(t, testDB.Exec("DELETE FROM users").Error)
primaryTenant := createUserRepositoryTestTenant(t, "repo-primary-tenant")
additionalTenant := createUserRepositoryTestTenant(t, "repo-additional-tenant")
primaryTenantID := primaryTenant.ID
additionalTenantID := additionalTenant.ID
users := []domain.User{
{
ID: uuid.NewString(),
Email: "primary-member@example.com",
Name: "Primary Member",
Role: domain.RoleUser,
TenantID: &additionalTenantID,
},
{
ID: uuid.NewString(),
Email: "additional-member@example.com",
Name: "Additional Member",
Role: domain.RoleUser,
TenantID: &primaryTenantID,
Metadata: domain.JSONMap{
"additionalAppointments": []any{
map[string]any{
"tenantId": additionalTenant.ID,
"tenantSlug": additionalTenant.Slug,
"tenantName": additionalTenant.Name,
"isPrimary": false,
},
},
},
},
}
for i := range users {
require.NoError(t, repo.Create(ctx, &users[i]))
}
listed, total, _, err := repo.List(ctx, 0, 20, "", []string{additionalTenant.ID}, "")
require.NoError(t, err)
require.Equal(t, int64(2), total)
require.Len(t, listed, 2)
emails := []string{listed[0].Email, listed[1].Email}
assert.Contains(t, emails, "primary-member@example.com")
assert.Contains(t, emails, "additional-member@example.com")
counts, err := repo.CountByTenantIDs(ctx, []string{additionalTenant.ID})
require.NoError(t, err)
assert.Equal(t, int64(2), counts[additionalTenant.ID])
}
func createUserRepositoryTestTenant(t *testing.T, slug string) domain.Tenant {
t.Helper()
require.NoError(t, testDB.Unscoped().Where("slug = ?", slug).Delete(&domain.Tenant{}).Error)
tenant := domain.Tenant{
ID: uuid.NewString(),
Name: "Tenant " + slug,
Slug: slug,
Type: domain.TenantTypeCompany,
Status: domain.TenantStatusActive,
}
require.NoError(t, testDB.Create(&tenant).Error)
return tenant
}

View File

@@ -0,0 +1,208 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type WorksmobileOutboxRepository interface {
Create(ctx context.Context, item *domain.WorksmobileOutbox) error
ListRecent(ctx context.Context, limit int) ([]domain.WorksmobileOutbox, error)
ListCredentialBatchJobs(ctx context.Context, tenantRootID, credentialBatchID string) ([]domain.WorksmobileOutbox, error)
UpdatePayload(ctx context.Context, id string, payload domain.JSONMap) error
DeletePendingByTenantRoot(ctx context.Context, tenantRootID string) (int64, error)
ListReady(ctx context.Context, limit int) ([]domain.WorksmobileOutbox, error)
FindByID(ctx context.Context, id string) (*domain.WorksmobileOutbox, error)
MarkRetry(ctx context.Context, id string) error
MarkProcessing(ctx context.Context, id string) (bool, error)
MarkProcessed(ctx context.Context, id string) error
MarkFailed(ctx context.Context, id string, message string, nextAttemptAt time.Time) error
}
type worksmobileOutboxRepository struct {
db *gorm.DB
}
func NewWorksmobileOutboxRepository(db *gorm.DB) WorksmobileOutboxRepository {
return &worksmobileOutboxRepository{db: db}
}
func (r *worksmobileOutboxRepository) Create(ctx context.Context, item *domain.WorksmobileOutbox) error {
if item.Payload == nil {
item.Payload = domain.JSONMap{}
}
if item.Status == "" {
item.Status = domain.WorksmobileOutboxStatusPending
}
return r.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "dedupe_key"}},
DoUpdates: clause.Assignments(map[string]any{
"payload": item.Payload,
"status": domain.WorksmobileOutboxStatusPending,
"last_error": "",
"next_attempt_at": nil,
"updated_at": time.Now(),
}),
}).Create(item).Error
}
func (r *worksmobileOutboxRepository) ListRecent(ctx context.Context, limit int) ([]domain.WorksmobileOutbox, error) {
if limit <= 0 || limit > 1000 {
limit = 50
}
var rows []domain.WorksmobileOutbox
err := r.db.WithContext(ctx).Order("created_at desc").Limit(limit).Find(&rows).Error
return rows, err
}
func (r *worksmobileOutboxRepository) ListCredentialBatchJobs(ctx context.Context, tenantRootID, credentialBatchID string) ([]domain.WorksmobileOutbox, error) {
query := r.db.WithContext(ctx).
Where("resource_type = ? AND payload ->> 'tenantRootId' = ? AND coalesce(payload ->> 'credentialBatchId', '') <> ?", domain.WorksmobileResourceUser, tenantRootID, "")
if credentialBatchID != "" {
query = query.Where("payload ->> 'credentialBatchId' = ?", credentialBatchID)
}
var rows []domain.WorksmobileOutbox
err := query.Order("created_at desc").Find(&rows).Error
return rows, err
}
func (r *worksmobileOutboxRepository) UpdatePayload(ctx context.Context, id string, payload domain.JSONMap) error {
return r.db.WithContext(ctx).Model(&domain.WorksmobileOutbox{}).Where("id = ?", id).Updates(map[string]any{
"payload": payload,
"updated_at": time.Now(),
}).Error
}
func (r *worksmobileOutboxRepository) DeletePendingByTenantRoot(ctx context.Context, tenantRootID string) (int64, error) {
result := r.db.WithContext(ctx).
Where("status = ? AND payload ->> 'tenantRootId' = ?", domain.WorksmobileOutboxStatusPending, tenantRootID).
Delete(&domain.WorksmobileOutbox{})
return result.RowsAffected, result.Error
}
func (r *worksmobileOutboxRepository) ListReady(ctx context.Context, limit int) ([]domain.WorksmobileOutbox, error) {
if limit <= 0 || limit > 100 {
limit = 20
}
var rows []domain.WorksmobileOutbox
err := r.db.WithContext(ctx).Raw(`
WITH RECURSIVE candidates AS (
SELECT
*,
NULLIF(payload #>> '{request,orgUnitExternalKey}', '') AS org_external_key,
CASE
WHEN payload #>> '{request,parentOrgUnitId}' LIKE 'externalKey:%'
THEN NULLIF(substr(payload #>> '{request,parentOrgUnitId}', length('externalKey:') + 1), '')
ELSE ''
END AS parent_external_key
FROM worksmobile_outboxes
WHERE status = ? AND (next_attempt_at IS NULL OR next_attempt_at <= ?)
),
ready AS (
SELECT candidates.*
FROM candidates
WHERE NOT (
candidates.resource_type = ?
AND candidates.action = ?
AND candidates.parent_external_key <> ''
AND EXISTS (
SELECT 1
FROM worksmobile_outboxes parent_job
WHERE parent_job.resource_type = ?
AND parent_job.action = ?
AND parent_job.status <> ?
AND NULLIF(parent_job.payload #>> '{request,orgUnitExternalKey}', '') = candidates.parent_external_key
)
)
),
org_depth AS (
SELECT id, org_external_key, parent_external_key, 0 AS depth
FROM ready
UNION ALL
SELECT child.id, child.org_external_key, child.parent_external_key, parent.depth + 1
FROM ready child
JOIN org_depth parent ON child.parent_external_key = parent.org_external_key
WHERE child.resource_type = ? AND child.action = ? AND parent.depth < 64
)
SELECT ready.*
FROM ready
LEFT JOIN LATERAL (
SELECT max(depth) AS dependency_depth
FROM org_depth
WHERE org_depth.id = ready.id
) AS depth_rank ON true
ORDER BY
CASE
WHEN ready.resource_type = ? AND ready.action = ? THEN 0
WHEN ready.resource_type = ? THEN 1
ELSE 2
END ASC,
COALESCE(depth_rank.dependency_depth, 0) ASC,
ready.created_at ASC
LIMIT ?
`,
domain.WorksmobileOutboxStatusPending,
time.Now(),
domain.WorksmobileResourceOrgUnit,
domain.WorksmobileActionUpsert,
domain.WorksmobileResourceOrgUnit,
domain.WorksmobileActionUpsert,
domain.WorksmobileOutboxStatusProcessed,
domain.WorksmobileResourceOrgUnit,
domain.WorksmobileActionUpsert,
domain.WorksmobileResourceOrgUnit,
domain.WorksmobileActionUpsert,
domain.WorksmobileResourceUser,
limit,
).Scan(&rows).Error
return rows, err
}
func (r *worksmobileOutboxRepository) FindByID(ctx context.Context, id string) (*domain.WorksmobileOutbox, error) {
var row domain.WorksmobileOutbox
if err := r.db.WithContext(ctx).First(&row, "id = ?", id).Error; err != nil {
return nil, err
}
return &row, nil
}
func (r *worksmobileOutboxRepository) MarkRetry(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Model(&domain.WorksmobileOutbox{}).Where("id = ?", id).Updates(map[string]any{
"status": domain.WorksmobileOutboxStatusPending,
"last_error": "",
"next_attempt_at": nil,
"updated_at": time.Now(),
}).Error
}
func (r *worksmobileOutboxRepository) MarkProcessing(ctx context.Context, id string) (bool, error) {
result := r.db.WithContext(ctx).Model(&domain.WorksmobileOutbox{}).Where("id = ? AND status = ?", id, domain.WorksmobileOutboxStatusPending).Updates(map[string]any{
"status": domain.WorksmobileOutboxStatusProcessing,
"updated_at": time.Now(),
})
return result.RowsAffected > 0, result.Error
}
func (r *worksmobileOutboxRepository) MarkProcessed(ctx context.Context, id string) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&domain.WorksmobileOutbox{}).Where("id = ?", id).Updates(map[string]any{
"status": domain.WorksmobileOutboxStatusProcessed,
"last_error": "",
"processed_at": &now,
"updated_at": now,
}).Error
}
func (r *worksmobileOutboxRepository) MarkFailed(ctx context.Context, id string, message string, nextAttemptAt time.Time) error {
return r.db.WithContext(ctx).Model(&domain.WorksmobileOutbox{}).Where("id = ?", id).Updates(map[string]any{
"status": domain.WorksmobileOutboxStatusFailed,
"retry_count": gorm.Expr("retry_count + 1"),
"last_error": message,
"next_attempt_at": &nextAttemptAt,
"updated_at": time.Now(),
}).Error
}

View File

@@ -0,0 +1,125 @@
package repository
import (
"baron-sso-backend/internal/domain"
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestWorksmobileOutboxRepositoryDeletePendingByTenantRoot(t *testing.T) {
repo := NewWorksmobileOutboxRepository(testDB)
ctx := context.Background()
require.NoError(t, testDB.Exec("DELETE FROM worksmobile_outboxes").Error)
rows := []domain.WorksmobileOutbox{
{
ID: "00000000-0000-0000-0000-000000000101",
ResourceType: domain.WorksmobileResourceUser,
ResourceID: "user-pending",
Action: domain.WorksmobileActionUpsert,
Status: domain.WorksmobileOutboxStatusPending,
DedupeKey: "pending-root",
Payload: domain.JSONMap{"tenantRootId": "root-1"},
},
{
ID: "00000000-0000-0000-0000-000000000102",
ResourceType: domain.WorksmobileResourceUser,
ResourceID: "user-other-root",
Action: domain.WorksmobileActionUpsert,
Status: domain.WorksmobileOutboxStatusPending,
DedupeKey: "pending-other-root",
Payload: domain.JSONMap{"tenantRootId": "root-2"},
},
{
ID: "00000000-0000-0000-0000-000000000103",
ResourceType: domain.WorksmobileResourceUser,
ResourceID: "user-failed",
Action: domain.WorksmobileActionUpsert,
Status: domain.WorksmobileOutboxStatusFailed,
DedupeKey: "failed-root",
Payload: domain.JSONMap{"tenantRootId": "root-1"},
},
{
ID: "00000000-0000-0000-0000-000000000104",
ResourceType: domain.WorksmobileResourceOrgUnit,
ResourceID: "org-processed",
Action: domain.WorksmobileActionUpsert,
Status: domain.WorksmobileOutboxStatusProcessed,
DedupeKey: "processed-root",
Payload: domain.JSONMap{"tenantRootId": "root-1"},
},
}
for i := range rows {
require.NoError(t, repo.Create(ctx, &rows[i]))
}
deleted, err := repo.DeletePendingByTenantRoot(ctx, "root-1")
require.NoError(t, err)
require.Equal(t, int64(1), deleted)
var remaining []domain.WorksmobileOutbox
require.NoError(t, testDB.Order("id asc").Find(&remaining).Error)
require.Len(t, remaining, 3)
require.Equal(t, "00000000-0000-0000-0000-000000000102", remaining[0].ID)
require.Equal(t, "00000000-0000-0000-0000-000000000103", remaining[1].ID)
require.Equal(t, "00000000-0000-0000-0000-000000000104", remaining[2].ID)
}
func TestWorksmobileOutboxRepositoryListReadyWaitsForPendingOrgUnitParent(t *testing.T) {
repo := NewWorksmobileOutboxRepository(testDB)
ctx := context.Background()
require.NoError(t, testDB.Exec("DELETE FROM worksmobile_outboxes").Error)
baseTime := time.Date(2026, 6, 2, 15, 21, 0, 0, time.UTC)
child := domain.WorksmobileOutbox{
ID: "00000000-0000-0000-0000-000000000201",
ResourceType: domain.WorksmobileResourceOrgUnit,
ResourceID: "child-tenant",
Action: domain.WorksmobileActionUpsert,
Status: domain.WorksmobileOutboxStatusPending,
DedupeKey: "orgunit:upsert:child-tenant",
Payload: domain.JSONMap{
"request": map[string]any{
"orgUnitExternalKey": "child-tenant",
"parentOrgUnitId": "externalKey:parent-tenant",
},
},
CreatedAt: baseTime,
UpdatedAt: baseTime,
}
parent := domain.WorksmobileOutbox{
ID: "00000000-0000-0000-0000-000000000202",
ResourceType: domain.WorksmobileResourceOrgUnit,
ResourceID: "parent-tenant",
Action: domain.WorksmobileActionUpsert,
Status: domain.WorksmobileOutboxStatusPending,
DedupeKey: "orgunit:upsert:parent-tenant",
Payload: domain.JSONMap{
"request": map[string]any{
"orgUnitExternalKey": "parent-tenant",
},
},
CreatedAt: baseTime.Add(time.Second),
UpdatedAt: baseTime.Add(time.Second),
}
require.NoError(t, testDB.Create(&child).Error)
require.NoError(t, testDB.Create(&parent).Error)
rows, err := repo.ListReady(ctx, 10)
require.NoError(t, err)
require.Len(t, rows, 1)
require.Equal(t, "parent-tenant", rows[0].ResourceID)
require.NoError(t, repo.MarkProcessed(ctx, parent.ID))
rows, err = repo.ListReady(ctx, 10)
require.NoError(t, err)
require.Len(t, rows, 1)
require.Equal(t, "child-tenant", rows[0].ResourceID)
}