1
0
forked from baron/baron-sso
Files
baron-sso/backend/internal/repository/clickhouse_repo.go

450 lines
12 KiB
Go

package repository
import (
"baron-sso-backend/internal/domain"
"context"
"encoding/json"
"fmt"
"time"
"github.com/ClickHouse/clickhouse-go/v2"
"github.com/ClickHouse/clickhouse-go/v2/lib/driver"
)
// ClickHouseRepository reads and writes audit logs and relying-party usage
// events stored in ClickHouse.
type ClickHouseRepository struct {
	// conn is the ClickHouse connection used by every repository method.
	conn driver.Conn
}
// NewClickHouseRepository connects to ClickHouse and makes sure the tables this
// repository relies on exist.
//
// It first connects to the "default" database and issues a best-effort
// CREATE DATABASE IF NOT EXISTS for db. It then opens the connection the
// repository keeps (bound to db), pings it, and creates or extends the
// audit_logs table and the RP usage tables. If any step after that connection
// is opened fails, the connection is closed before the error is returned, so
// it is never leaked.
//
// Note: in production the schema should be managed by migrations.
func NewClickHouseRepository(host string, port int, user, password, db string) (*ClickHouseRepository, error) {
	ctx := context.Background()
	addr := fmt.Sprintf("%s:%d", host, port)

	// 1. Connect to the 'default' database first so the target database can be
	// created if it is missing. Errors are deliberately ignored here; if the
	// database is still missing, the Ping/DDL on the target connection below
	// is expected to fail instead.
	tmpConn, err := clickhouse.Open(&clickhouse.Options{
		Addr: []string{addr},
		Auth: clickhouse.Auth{
			Database: "default",
			Username: user,
			Password: password,
		},
	})
	if err == nil {
		_ = tmpConn.Exec(ctx, fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", db))
		_ = tmpConn.Close()
	}

	// 2. Now connect to the target database.
	conn, err := clickhouse.Open(&clickhouse.Options{
		Addr: []string{addr},
		Auth: clickhouse.Auth{
			Database: db,
			Username: user,
			Password: password,
		},
		Debug: false,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to open clickhouse connection: %w", err)
	}

	// closeWith releases conn before reporting a setup failure so the
	// connection does not leak on any error path below.
	closeWith := func(err error) (*ClickHouseRepository, error) {
		_ = conn.Close()
		return nil, err
	}

	if err := conn.Ping(ctx); err != nil {
		return closeWith(fmt.Errorf("failed to ping clickhouse: %w", err))
	}

	// Ensure Table Exists
	// Note: In production, use migrations.
	query := `
CREATE TABLE IF NOT EXISTS audit_logs (
event_id String,
timestamp DateTime DEFAULT now(),
user_id String,
tenant_id String,
event_type String,
status String,
ip_address String,
user_agent String,
device_id String,
details String
) ENGINE = MergeTree()
ORDER BY timestamp
`
	if err := conn.Exec(ctx, query); err != nil {
		return closeWith(fmt.Errorf("failed to create table: %w", err))
	}

	// Bring tables created by older versions of this DDL up to date.
	// session_id is read by CountActiveSessionsSince; Create's INSERT does not
	// list it, so rows from Create get the String default ''.
	alterQuery := `
ALTER TABLE audit_logs
ADD COLUMN IF NOT EXISTS tenant_id String,
ADD COLUMN IF NOT EXISTS event_id String,
ADD COLUMN IF NOT EXISTS session_id String
`
	if err := conn.Exec(ctx, alterQuery); err != nil {
		return closeWith(fmt.Errorf("failed to alter table: %w", err))
	}

	if err := ensureRPUsageTables(ctx, conn); err != nil {
		return closeWith(fmt.Errorf("failed to create rp usage tables: %w", err))
	}

	return &ClickHouseRepository{conn: conn}, nil
}
// ensureRPUsageTables creates, when missing, the rp_usage_events table, the
// rp_usage_daily_aggregate AggregatingMergeTree table, and the materialized
// view that fills the aggregate from rp_usage_events (COMPANY and ORGANIZATION
// tenants only).
//
// The statements run in that order. The first failing statement's error is
// returned as-is, and later statements are not attempted.
func ensureRPUsageTables(ctx context.Context, conn driver.Conn) error {
	ddl := []string{
		// Raw usage events.
		`
CREATE TABLE IF NOT EXISTS rp_usage_events (
event_id String,
occurred_at DateTime64(3) DEFAULT now64(3),
event_type String,
subject String,
tenant_id String,
tenant_type String,
client_id String,
client_name String,
session_id String,
scopes Array(String),
source String,
correlation_id String,
payload String
) ENGINE = MergeTree()
ORDER BY (occurred_at, event_id)
`,
		// Per-day partial aggregates (count and exact distinct subjects).
		`
CREATE TABLE IF NOT EXISTS rp_usage_daily_aggregate (
event_date Date,
tenant_id String,
tenant_type String,
client_id String,
client_name String,
event_type String,
events_count AggregateFunction(count),
unique_subjects AggregateFunction(uniqExact, String)
) ENGINE = AggregatingMergeTree()
ORDER BY (event_date, tenant_id, client_id, event_type)
`,
		// Keeps the aggregate table in sync with inserts into rp_usage_events.
		`
CREATE MATERIALIZED VIEW IF NOT EXISTS rp_usage_daily_aggregate_mv
TO rp_usage_daily_aggregate
AS
SELECT
toDate(occurred_at) AS event_date,
tenant_id,
tenant_type,
client_id,
any(client_name) AS client_name,
event_type,
countState() AS events_count,
uniqExactState(subject) AS unique_subjects
FROM rp_usage_events
WHERE tenant_type IN ('COMPANY', 'ORGANIZATION')
GROUP BY event_date, tenant_id, tenant_type, client_id, event_type
`,
	}

	for _, stmt := range ddl {
		if err := conn.Exec(ctx, stmt); err != nil {
			return err
		}
	}
	return nil
}
// Create inserts log into audit_logs.
//
// If log.Timestamp is zero it is set to time.Now() on the caller's struct
// before the insert. The method takes no context, so the insert runs under its
// own 5-second timeout. It returns an error, rather than panicking, when the
// repository, its connection, or log is nil.
func (r *ClickHouseRepository) Create(log *domain.AuditLog) error {
	if r == nil || r.conn == nil {
		return fmt.Errorf("clickhouse connection is nil")
	}
	if log == nil {
		return fmt.Errorf("audit log is nil")
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	if log.Timestamp.IsZero() {
		log.Timestamp = time.Now()
	}

	query := `
INSERT INTO audit_logs (event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
	return r.conn.Exec(ctx, query,
		log.EventID,
		log.Timestamp,
		log.UserID,
		log.TenantID,
		log.EventType,
		log.Status,
		log.IPAddress,
		log.UserAgent,
		log.DeviceID,
		log.Details,
	)
}
// CreateRPUsageEvent inserts one relying-party usage event into
// rp_usage_events.
//
// A zero event.OccurredAt is replaced with time.Now(). event is passed by
// value, so the caller's copy is untouched. event.Payload is stored as its
// JSON encoding. It returns an error if the repository or its connection is
// nil, if the payload cannot be marshalled, or if the insert fails.
func (r *ClickHouseRepository) CreateRPUsageEvent(ctx context.Context, event domain.RPUsageEvent) error {
	if r == nil || r.conn == nil {
		return fmt.Errorf("clickhouse connection is nil")
	}

	occurredAt := event.OccurredAt
	if occurredAt.IsZero() {
		occurredAt = time.Now()
	}

	encodedPayload, marshalErr := json.Marshal(event.Payload)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal rp usage payload: %w", marshalErr)
	}

	const insertSQL = `
INSERT INTO rp_usage_events (
event_id, occurred_at, event_type, subject, tenant_id, tenant_type,
client_id, client_name, session_id, scopes, source, correlation_id, payload
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
	// Positional values, in the same order as the column list above.
	values := []any{
		event.ID,
		occurredAt,
		event.EventType,
		event.Subject,
		event.TenantID,
		event.TenantType,
		event.ClientID,
		event.ClientName,
		event.SessionID,
		[]string(event.Scopes),
		event.Source,
		event.CorrelationID,
		string(encodedPayload),
	}
	return r.conn.Exec(ctx, insertSQL, values...)
}
// FindRPUsage returns relying-party usage metrics read from the
// rp_usage_daily_aggregate table, one row per period, tenant and client.
//
// rpQuery.Days selects the look-back window ending today; values outside
// 1..90 fall back to 14. rpQuery.Period picks the bucket:
//   - "week" buckets by the Monday of each week;
//   - "month" buckets by the first day of the month;
//   - anything else (including "" and "day") buckets by calendar day.
//
// A non-empty rpQuery.TenantID restricts results to that tenant. Only COMPANY
// and ORGANIZATION tenants are included.
//
// Output fields:
//   - LoginRequests counts RPUsageEventTypeAuthorizationGranted events.
//   - OtherRequests counts every other event type.
//   - UniqueSubjects is the largest per-event-type distinct-subject count in
//     the bucket, i.e. a lower bound on distinct subjects across all event
//     types.
//   - ClientName falls back to ClientID when empty.
//
// An error from iterating the result set is returned instead of a truncated
// result.
func (r *ClickHouseRepository) FindRPUsage(ctx context.Context, rpQuery domain.RPUsageQuery) ([]domain.RPUsageDailyMetric, error) {
	if r == nil || r.conn == nil {
		return nil, fmt.Errorf("clickhouse connection is nil")
	}

	days := rpQuery.Days
	if days <= 0 || days > 90 {
		days = 14
	}

	// periodExpr is interpolated into the SQL text, so it must only ever come
	// from this fixed set of expressions, never from the request itself.
	periodExpr := "event_date"
	switch rpQuery.Period {
	case "week":
		periodExpr = "toMonday(event_date)"
	case "month":
		periodExpr = "toStartOfMonth(event_date)"
	}

	query := fmt.Sprintf(`
SELECT
date,
tenant_id,
tenant_type,
client_id,
any(client_name) AS client_name,
sumIf(events, event_type = ?) AS login_requests,
sumIf(events, event_type != ?) AS other_requests,
max(unique_subjects) AS unique_subjects
FROM (
SELECT
toString(%s) AS date,
tenant_id,
tenant_type,
client_id,
any(client_name) AS client_name,
event_type,
countMerge(events_count) AS events,
uniqExactMerge(unique_subjects) AS unique_subjects
FROM rp_usage_daily_aggregate
WHERE event_date >= today() - ?
AND tenant_type IN ('COMPANY', 'ORGANIZATION')
`, periodExpr)
	// Placeholder order: the two sumIf event types, then the window offset
	// (days-1 so that "today" counts as the first day).
	args := []any{domain.RPUsageEventTypeAuthorizationGranted, domain.RPUsageEventTypeAuthorizationGranted, days - 1}
	if rpQuery.TenantID != "" {
		query += " AND tenant_id = ?\n"
		args = append(args, rpQuery.TenantID)
	}
	query += fmt.Sprintf(`
GROUP BY %s, tenant_id, tenant_type, client_id, event_type
)
GROUP BY date, tenant_id, tenant_type, client_id
ORDER BY date ASC, tenant_id ASC, client_id ASC
`, periodExpr)

	rows, err := r.conn.Query(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to query rp usage daily aggregate: %w", err)
	}
	defer rows.Close()

	metrics := make([]domain.RPUsageDailyMetric, 0)
	for rows.Next() {
		var metric domain.RPUsageDailyMetric
		if err := rows.Scan(
			&metric.Date,
			&metric.TenantID,
			&metric.TenantType,
			&metric.ClientID,
			&metric.ClientName,
			&metric.LoginRequests,
			&metric.OtherRequests,
			&metric.UniqueSubjects,
		); err != nil {
			return nil, fmt.Errorf("failed to scan rp usage daily aggregate: %w", err)
		}
		if metric.ClientName == "" {
			metric.ClientName = metric.ClientID
		}
		metrics = append(metrics, metric)
	}
	// Next returns false both at the end of the result and on a streaming
	// error; without this check a failure mid-read would look like success.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate rp usage daily aggregate: %w", err)
	}
	return metrics, nil
}
// FindPage returns up to limit audit logs, newest first, using keyset
// pagination over (timestamp, event_id).
//
// limit <= 0 defaults to 50. When cursor is non-nil, only rows after the entry
// it identifies are returned: rows strictly older than cursor.Timestamp, or
// rows with the same timestamp and a smaller event_id. This matches the
// ORDER BY timestamp DESC, event_id DESC ordering.
//
// When tenantID is non-empty, a row matches if its tenant_id column equals
// tenantID. If that column is empty (presumably rows written before the column
// was added), the row matches on the "tenant_id" field inside details.
//
// The result is nil when no rows match. An error from iterating the result
// set is returned instead of a truncated page.
func (r *ClickHouseRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor, tenantID string) ([]domain.AuditLog, error) {
	if r == nil || r.conn == nil {
		return nil, fmt.Errorf("clickhouse connection is nil")
	}
	if limit <= 0 {
		limit = 50
	}

	query := `
SELECT event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details
FROM audit_logs
WHERE 1=1
`
	// At most 2 tenant args + 3 cursor args + the limit.
	args := make([]any, 0, 6)
	if tenantID != "" {
		query += " AND (tenant_id = ? OR (tenant_id = '' AND JSONExtractString(details, 'tenant_id') = ?))"
		args = append(args, tenantID, tenantID)
	}
	if cursor != nil {
		query += `
AND ((timestamp < ?) OR (timestamp = ? AND event_id < ?))
`
		args = append(args, cursor.Timestamp, cursor.Timestamp, cursor.EventID)
	}
	query += `
ORDER BY timestamp DESC, event_id DESC
LIMIT ?
`
	args = append(args, limit)

	rows, err := r.conn.Query(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to query audit logs: %w", err)
	}
	defer rows.Close()

	var logs []domain.AuditLog
	for rows.Next() {
		var log domain.AuditLog
		if err := rows.Scan(
			&log.EventID,
			&log.Timestamp,
			&log.UserID,
			&log.TenantID,
			&log.EventType,
			&log.Status,
			&log.IPAddress,
			&log.UserAgent,
			&log.DeviceID,
			&log.Details,
		); err != nil {
			return nil, fmt.Errorf("failed to scan audit log: %w", err)
		}
		logs = append(logs, log)
	}
	// Next returns false both at the end of the result and on a streaming
	// error; surface the latter instead of returning a short page.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate audit logs: %w", err)
	}
	return logs, nil
}
// FindByUserAndEvents returns up to limit audit logs for userID whose
// event_type is one of eventTypes, newest first.
//
// limit <= 0 defaults to 100. An empty eventTypes cannot match any row, so it
// returns (nil, nil) without querying. Rows with equal timestamps are ordered
// by event_id descending so the result order is deterministic. An error from
// iterating the result set is returned instead of a truncated result.
func (r *ClickHouseRepository) FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]domain.AuditLog, error) {
	if r == nil || r.conn == nil {
		return nil, fmt.Errorf("clickhouse connection is nil")
	}
	if len(eventTypes) == 0 {
		return nil, nil
	}
	if limit <= 0 {
		limit = 100
	}

	// NOTE: relies on clickhouse-go expanding the eventTypes slice into the
	// IN (...) list.
	query := `
SELECT event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details
FROM audit_logs
WHERE user_id = ? AND event_type IN (?)
ORDER BY timestamp DESC, event_id DESC
LIMIT ?
`
	rows, err := r.conn.Query(ctx, query, userID, eventTypes, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to query audit logs by user/events: %w", err)
	}
	defer rows.Close()

	var logs []domain.AuditLog
	for rows.Next() {
		var log domain.AuditLog
		if err := rows.Scan(
			&log.EventID,
			&log.Timestamp,
			&log.UserID,
			&log.TenantID,
			&log.EventType,
			&log.Status,
			&log.IPAddress,
			&log.UserAgent,
			&log.DeviceID,
			&log.Details,
		); err != nil {
			return nil, fmt.Errorf("failed to scan audit log: %w", err)
		}
		logs = append(logs, log)
	}
	// Next returns false both at the end of the result and on a streaming
	// error; surface the latter instead of returning partial results.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate audit logs by user/events: %w", err)
	}
	return logs, nil
}
// Ping checks that ClickHouse is reachable over the repository's connection.
// It returns an error, rather than panicking, when the repository itself or
// its connection is nil.
func (r *ClickHouseRepository) Ping(ctx context.Context) error {
	if r == nil || r.conn == nil {
		return fmt.Errorf("clickhouse connection is nil")
	}
	return r.conn.Ping(ctx)
}
// CountFailuresSince returns the number of audit_logs rows with status
// 'failure' recorded at or after since.
//
// When tenantID is non-empty, only that tenant's rows are counted, using the
// same tenant rule as FindPage: the tenant_id column must match, or, when that
// column is empty, the "tenant_id" field inside details must match.
func (r *ClickHouseRepository) CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
	if r == nil || r.conn == nil {
		return 0, fmt.Errorf("clickhouse connection is nil")
	}

	query := `
SELECT count()
FROM audit_logs
WHERE status = 'failure' AND timestamp >= ?
`
	args := []any{since}
	if tenantID != "" {
		// Prefer the tenant_id column; fall back to details only for rows
		// where the column is empty (presumably written before it existed).
		query += " AND (tenant_id = ? OR (tenant_id = '' AND JSONExtractString(details, 'tenant_id') = ?))"
		args = append(args, tenantID, tenantID)
	}

	// count() yields UInt64; scan into the matching Go type, since the
	// native driver does not convert UInt64 into *int64.
	var count uint64
	if err := r.conn.QueryRow(ctx, query, args...).Scan(&count); err != nil {
		return 0, fmt.Errorf("failed to count failures: %w", err)
	}
	return int64(count), nil
}
// CountEventsSince returns the number of audit_logs rows, across all tenants,
// recorded at or after since.
//
// since is formatted in UTC at one-second precision and passed as a bound
// parameter to toDateTime with an explicit 'UTC' zone. The cutoff therefore
// does not shift when the ClickHouse server's default time zone is not UTC.
// Without the zone argument, toDateTime parses the literal in the server's
// zone.
func (r *ClickHouseRepository) CountEventsSince(ctx context.Context, since time.Time) (int64, error) {
	if r == nil || r.conn == nil {
		return 0, fmt.Errorf("clickhouse connection is nil")
	}

	sinceUTC := since.UTC().Format("2006-01-02 15:04:05")
	query := `
SELECT count()
FROM audit_logs
WHERE timestamp >= toDateTime(?, 'UTC')
`
	// count() yields UInt64; scan into the matching Go type, since the
	// native driver does not convert UInt64 into *int64.
	var count uint64
	if err := r.conn.QueryRow(ctx, query, sinceUTC).Scan(&count); err != nil {
		return 0, fmt.Errorf("failed to count audit events: %w", err)
	}
	return int64(count), nil
}
// CountActiveSessionsSince returns the number of distinct, non-empty
// session_id values among audit_logs rows with status 'success' recorded at or
// after since.
//
// When tenantID is non-empty, only that tenant's rows are considered, using the
// same tenant rule as FindPage: the tenant_id column must match, or, when that
// column is empty, the "tenant_id" field inside details must match.
//
// NOTE(review): this query needs a session_id column on audit_logs, and Create
// does not write one, so rows inserted by Create never contribute here.
// Confirm where session IDs are populated.
func (r *ClickHouseRepository) CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
	if r == nil || r.conn == nil {
		return 0, fmt.Errorf("clickhouse connection is nil")
	}

	query := `
SELECT uniqExact(session_id)
FROM audit_logs
WHERE status = 'success' AND timestamp >= ? AND session_id != ''
`
	args := []any{since}
	if tenantID != "" {
		// Prefer the tenant_id column; fall back to details only for rows
		// where the column is empty (presumably written before it existed).
		query += " AND (tenant_id = ? OR (tenant_id = '' AND JSONExtractString(details, 'tenant_id') = ?))"
		args = append(args, tenantID, tenantID)
	}

	// uniqExact() yields UInt64; scan into the matching Go type, since the
	// native driver does not convert UInt64 into *int64.
	var count uint64
	if err := r.conn.QueryRow(ctx, query, args...).Scan(&count); err != nil {
		return 0, fmt.Errorf("failed to count active sessions: %w", err)
	}
	return int64(count), nil
}