1
0
forked from baron/baron-sso
Files
baron-sso/backend/internal/repository/clickhouse_repo.go

251 lines
6.2 KiB
Go

package repository
import (
"baron-sso-backend/internal/domain"
"context"
"fmt"
"time"
"github.com/ClickHouse/clickhouse-go/v2"
"github.com/ClickHouse/clickhouse-go/v2/lib/driver"
)
// ClickHouseRepository reads and writes audit_logs rows through a
// clickhouse-go driver.Conn.
type ClickHouseRepository struct {
// conn is the ClickHouse connection shared by the repository's methods.
conn driver.Conn
}
// NewClickHouseRepository opens a ClickHouse connection to host:port, makes
// sure the target database and the audit_logs table exist, and returns a
// repository backed by that connection.
//
// Bootstrap order:
//  1. Best-effort CREATE DATABASE IF NOT EXISTS issued through the "default"
//     database; failures here are ignored.
//  2. Open and ping a connection bound to db.
//  3. CREATE TABLE IF NOT EXISTS audit_logs, followed by an ALTER TABLE that
//     adds tenant_id and event_id when they are missing (presumably to upgrade
//     tables created with an older schema).
//
// db is interpolated into the CREATE DATABASE statement without quoting or
// escaping, so it must come from trusted configuration.
//
// Once the main connection has been opened, every failure path closes it
// before returning so a failed bootstrap does not leak the connection.
func NewClickHouseRepository(host string, port int, user, password, db string) (*ClickHouseRepository, error) {
	ctx := context.Background()
	address := fmt.Sprintf("%s:%d", host, port)

	// 1. Connect to 'default' database first to ensure target DB exists
	tmpConn, err := clickhouse.Open(&clickhouse.Options{
		Addr: []string{address},
		Auth: clickhouse.Auth{
			Database: "default",
			Username: user,
			Password: password,
		},
	})
	if err == nil {
		// Deliberately best-effort: if this step fails, the ping on the
		// target database below reports the real problem.
		_ = tmpConn.Exec(ctx, fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", db))
		_ = tmpConn.Close()
	}

	// 2. Now connect to the target database
	conn, err := clickhouse.Open(&clickhouse.Options{
		Addr: []string{address},
		Auth: clickhouse.Auth{
			Database: db,
			Username: user,
			Password: password,
		},
		Debug: false,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to open clickhouse connection: %w", err)
	}

	if err := conn.Ping(ctx); err != nil {
		// The ping error is what the caller needs; a close error adds nothing.
		_ = conn.Close()
		return nil, fmt.Errorf("failed to ping clickhouse: %w", err)
	}

	// Ensure Table Exists
	// Note: In production, use migrations.
	query := `
CREATE TABLE IF NOT EXISTS audit_logs (
event_id String,
timestamp DateTime DEFAULT now(),
user_id String,
tenant_id String,
event_type String,
status String,
ip_address String,
user_agent String,
device_id String,
details String
) ENGINE = MergeTree()
ORDER BY timestamp
`
	if err := conn.Exec(ctx, query); err != nil {
		_ = conn.Close()
		return nil, fmt.Errorf("failed to create table: %w", err)
	}

	alterQuery := `
ALTER TABLE audit_logs
ADD COLUMN IF NOT EXISTS tenant_id String,
ADD COLUMN IF NOT EXISTS event_id String
`
	if err := conn.Exec(ctx, alterQuery); err != nil {
		_ = conn.Close()
		return nil, fmt.Errorf("failed to alter table: %w", err)
	}

	return &ClickHouseRepository{conn: conn}, nil
}
// Create inserts one audit log row into audit_logs.
//
// When log.Timestamp is the zero time it is set to time.Now() on the caller's
// struct before the insert. The insert runs under its own 5-second timeout
// derived from context.Background(). The driver's Exec error is returned
// unwrapped.
func (r *ClickHouseRepository) Create(log *domain.AuditLog) error {
	// Default the timestamp in place so the caller sees the stored value.
	if log.Timestamp.IsZero() {
		log.Timestamp = time.Now()
	}

	const insertTimeout = 5 * time.Second
	insertCtx, release := context.WithTimeout(context.Background(), insertTimeout)
	defer release()

	stmt := `
INSERT INTO audit_logs (event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
	// Values are listed in the same order as the column list above.
	values := []any{
		log.EventID,
		log.Timestamp,
		log.UserID,
		log.TenantID,
		log.EventType,
		log.Status,
		log.IPAddress,
		log.UserAgent,
		log.DeviceID,
		log.Details,
	}
	return r.conn.Exec(insertCtx, stmt, values...)
}
// FindPage returns up to limit audit logs ordered by timestamp DESC,
// event_id DESC.
//
// Parameters:
//   - limit: maximum number of rows; values <= 0 fall back to 50.
//   - cursor: when non-nil, only rows that sort strictly after the cursor in
//     the (timestamp DESC, event_id DESC) order are returned, i.e. rows with an
//     older timestamp, or the same timestamp and a smaller event_id.
//   - tenantID: when non-empty, restricts results to rows whose tenant_id
//     column equals it.
//
// The returned slice is nil when no rows match. Query, scan, and row
// iteration errors are all reported; a failure part-way through the result
// set is returned as an error rather than as a silently truncated page.
func (r *ClickHouseRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor, tenantID string) ([]domain.AuditLog, error) {
	if limit <= 0 {
		limit = 50
	}

	// "WHERE 1=1" lets every optional filter be appended as " AND ...".
	query := `
SELECT event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details
FROM audit_logs
WHERE 1=1
`
	args := make([]any, 0, 5)
	if tenantID != "" {
		query += " AND tenant_id = ?"
		args = append(args, tenantID)
	}
	if cursor != nil {
		query += `
AND ((timestamp < ?) OR (timestamp = ? AND event_id < ?))
`
		args = append(args, cursor.Timestamp, cursor.Timestamp, cursor.EventID)
	}
	query += `
ORDER BY timestamp DESC, event_id DESC
LIMIT ?
`
	args = append(args, limit)

	rows, err := r.conn.Query(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to query audit logs: %w", err)
	}
	defer rows.Close()

	var logs []domain.AuditLog
	for rows.Next() {
		var log domain.AuditLog
		if err := rows.Scan(
			&log.EventID,
			&log.Timestamp,
			&log.UserID,
			&log.TenantID,
			&log.EventType,
			&log.Status,
			&log.IPAddress,
			&log.UserAgent,
			&log.DeviceID,
			&log.Details,
		); err != nil {
			return nil, fmt.Errorf("failed to scan audit log: %w", err)
		}
		logs = append(logs, log)
	}
	// Next returning false can mean "no more rows" or "stream failed";
	// Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate audit logs: %w", err)
	}
	return logs, nil
}
// FindByUserAndEvents returns up to limit audit logs for userID whose
// event_type is one of eventTypes, ordered by timestamp DESC.
//
// Parameters:
//   - userID: matched against the user_id column.
//   - eventTypes: bound as a single parameter to "event_type IN (?)".
//     NOTE(review): this relies on the driver expanding a slice argument into
//     a list; behavior for an empty slice is not established here — confirm.
//   - limit: maximum number of rows; values <= 0 fall back to 100.
//
// The returned slice is nil when no rows match. Query, scan, and row
// iteration errors are all reported; a failure part-way through the result
// set is returned as an error rather than as a silently truncated list.
func (r *ClickHouseRepository) FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]domain.AuditLog, error) {
	if limit <= 0 {
		limit = 100
	}

	query := `
SELECT event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details
FROM audit_logs
WHERE user_id = ? AND event_type IN (?)
ORDER BY timestamp DESC
LIMIT ?
`
	rows, err := r.conn.Query(ctx, query, userID, eventTypes, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to query audit logs by user/events: %w", err)
	}
	defer rows.Close()

	var logs []domain.AuditLog
	for rows.Next() {
		var log domain.AuditLog
		if err := rows.Scan(
			&log.EventID,
			&log.Timestamp,
			&log.UserID,
			&log.TenantID,
			&log.EventType,
			&log.Status,
			&log.IPAddress,
			&log.UserAgent,
			&log.DeviceID,
			&log.Details,
		); err != nil {
			return nil, fmt.Errorf("failed to scan audit log: %w", err)
		}
		logs = append(logs, log)
	}
	// Next returning false can mean "no more rows" or "stream failed";
	// Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate audit logs by user/events: %w", err)
	}
	return logs, nil
}
// Ping checks the ClickHouse connection by delegating to the driver's Ping.
// It returns an error without contacting the server when the repository has
// no connection set.
func (r *ClickHouseRepository) Ping(ctx context.Context) error {
	if conn := r.conn; conn != nil {
		return conn.Ping(ctx)
	}
	return fmt.Errorf("clickhouse connection is nil")
}
// CountFailuresSince returns the number of audit_logs rows with
// status = 'failure' and timestamp >= since.
//
// When tenantID is non-empty the count is further restricted to rows whose
// details JSON carries a matching 'tenant_id' string.
// NOTE(review): this filters on the details JSON rather than the tenant_id
// column that FindPage filters on — confirm which is intended.
//
// Returns 0 and a wrapped error when the query or scan fails.
func (r *ClickHouseRepository) CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
	query := `
SELECT count()
FROM audit_logs
WHERE status = 'failure' AND timestamp >= ?
`
	args := []any{since}
	if tenantID != "" {
		query += " AND JSONExtractString(details, 'tenant_id') = ?"
		args = append(args, tenantID)
	}

	// count() yields UInt64, and the driver does not convert UInt64 into
	// *int64 on Scan, so read it as uint64 and convert afterwards.
	var count uint64
	if err := r.conn.QueryRow(ctx, query, args...).Scan(&count); err != nil {
		return 0, fmt.Errorf("failed to count failures: %w", err)
	}
	return int64(count), nil
}
// CountActiveSessionsSince returns the number of distinct non-empty
// session_id values among audit_logs rows with status = 'success' and
// timestamp >= since.
//
// When tenantID is non-empty the count is further restricted to rows whose
// details JSON carries a matching 'tenant_id' string.
// NOTE(review): session_id is not among the columns created by
// NewClickHouseRepository in this file — confirm the deployed schema has it.
// NOTE(review): the tenant filter reads the details JSON rather than the
// tenant_id column that FindPage filters on — confirm which is intended.
//
// Returns 0 and a wrapped error when the query or scan fails.
func (r *ClickHouseRepository) CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
	// We use uniqExact(session_id) to count unique sessions that had success events recently.
	query := `
SELECT uniqExact(session_id)
FROM audit_logs
WHERE status = 'success' AND timestamp >= ? AND session_id != ''
`
	args := []any{since}
	if tenantID != "" {
		query += " AND JSONExtractString(details, 'tenant_id') = ?"
		args = append(args, tenantID)
	}

	// uniqExact() yields UInt64, and the driver does not convert UInt64 into
	// *int64 on Scan, so read it as uint64 and convert afterwards.
	var count uint64
	if err := r.conn.QueryRow(ctx, query, args...).Scan(&count); err != nil {
		return 0, fmt.Errorf("failed to count active sessions: %w", err)
	}
	return int64(count), nil
}