forked from baron/baron-sso
- Added support for fixed UUIDs during bulk registration (Search-first + ExternalID mapping) - Implemented idempotency and visibility restoration for soft-deleted users - Enhanced bulk upload UI to show 'New/Updated/Unchanged' status and modified fields - Added logic to reclaim identifiers (login_id) from colliding records - Added frontend E2E and backend unit tests for UUID integrity and conflict handling - Fixed i18n, formatting, and mock tests to satisfy code-check - Applied 'go fix' for 'omitzero' tags and general Go standards
225 lines
5.8 KiB
Go
225 lines
5.8 KiB
Go
package middleware
|
|
|
|
import (
|
|
"baron-sso-backend/internal/domain"
|
|
"baron-sso-backend/internal/utils"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log/slog"
|
|
"maps"
|
|
"reflect"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
type AuditConfig struct {
|
|
Repo domain.AuditRepository
|
|
ExcludePaths map[string]struct{}
|
|
ReadPaths map[string]struct{}
|
|
BodyDump bool
|
|
WorkerCount int
|
|
QueueSize int
|
|
}
|
|
|
|
func isNil(i any) bool {
|
|
if i == nil {
|
|
return true
|
|
}
|
|
v := reflect.ValueOf(i)
|
|
return v.Kind() == reflect.Pointer && v.IsNil()
|
|
}
|
|
|
|
// AuditMiddleware provides comprehensive audit logging for write requests by default.
|
|
// Read requests are skipped unless they are explicitly allowlisted in ReadPaths.
|
|
func AuditMiddleware(config AuditConfig) fiber.Handler {
|
|
// 0. Initialize Worker Pool for Async Logging
|
|
if config.WorkerCount <= 0 {
|
|
config.WorkerCount = 5 // Default workers
|
|
}
|
|
if config.QueueSize <= 0 {
|
|
config.QueueSize = 1000 // Default queue size
|
|
}
|
|
|
|
auditQueue := make(chan *domain.AuditLog, config.QueueSize)
|
|
var once sync.Once
|
|
|
|
// Start workers only once
|
|
once.Do(func() {
|
|
for i := 0; i < config.WorkerCount; i++ {
|
|
go func(workerID int) {
|
|
slog.Debug("Audit worker started", "id", workerID)
|
|
for log := range auditQueue {
|
|
func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
slog.Error("Audit worker panic recovery", "reason", r, "req_id", log.EventID)
|
|
}
|
|
}()
|
|
if err := config.Repo.Create(log); err != nil {
|
|
slog.Warn("Failed to write async audit log", "error", err, "req_id", log.EventID)
|
|
}
|
|
}()
|
|
}
|
|
}(i)
|
|
}
|
|
})
|
|
|
|
// Default methods classification
|
|
writeMethods := map[string]struct{}{
|
|
fiber.MethodPost: {},
|
|
fiber.MethodPut: {},
|
|
fiber.MethodPatch: {},
|
|
fiber.MethodDelete: {},
|
|
}
|
|
|
|
if config.ExcludePaths == nil {
|
|
config.ExcludePaths = map[string]struct{}{}
|
|
}
|
|
if config.ReadPaths == nil {
|
|
config.ReadPaths = map[string]struct{}{}
|
|
}
|
|
|
|
return func(c *fiber.Ctx) error {
|
|
// 1. Check exclusions
|
|
if _, excluded := config.ExcludePaths[c.Path()]; excluded {
|
|
return c.Next()
|
|
}
|
|
|
|
// 2. Setup context variables
|
|
start := time.Now()
|
|
reqID := c.Get("X-Request-Id")
|
|
if reqID == "" {
|
|
reqID = uuid.New().String()
|
|
c.Set("X-Request-Id", reqID)
|
|
}
|
|
|
|
// 3. Process Request
|
|
err := c.Next()
|
|
|
|
// 4. Gather Metrics & Context
|
|
latency := time.Since(start)
|
|
status := c.Response().StatusCode()
|
|
|
|
// If Fiber handler returned an error, status might default to 500 or be in the error
|
|
if err != nil {
|
|
if fiberErr, ok := err.(*fiber.Error); ok {
|
|
status = fiberErr.Code
|
|
} else {
|
|
status = fiber.StatusInternalServerError
|
|
}
|
|
}
|
|
|
|
statusText := "success"
|
|
if status >= fiber.StatusBadRequest {
|
|
statusText = "failure"
|
|
}
|
|
|
|
// 5. Extract User Context (populated by AuthMiddleware/TenantGuard)
|
|
userID, _ := c.Locals("user_id").(string)
|
|
loginID, _ := c.Locals("login_id").(string)
|
|
tenantID, _ := c.Locals("tenant_id").(string)
|
|
sessionID, _ := c.Locals("session_id").(string)
|
|
clientIP := extractClientIP(c)
|
|
|
|
// 6. Capture & Mask Body
|
|
var maskedBody string
|
|
if config.BodyDump {
|
|
if c.Method() != fiber.MethodGet && c.Method() != fiber.MethodHead {
|
|
bodyBytes := c.Body()
|
|
if len(bodyBytes) > 0 {
|
|
maskedBytes := utils.MaskSensitiveJSON(bodyBytes)
|
|
maskedBody = string(maskedBytes)
|
|
}
|
|
}
|
|
}
|
|
|
|
// 7. Construct Details JSON
|
|
details := map[string]any{
|
|
"request_id": reqID,
|
|
"method": c.Method(),
|
|
"path": c.Path(),
|
|
"status": status,
|
|
"latency_ms": latency.Milliseconds(),
|
|
"login_id": loginID,
|
|
"tenant_id": tenantID,
|
|
"request_body": maskedBody,
|
|
}
|
|
// 핸들러에서 추가한 상세 정보를 병합합니다.
|
|
if extra := c.Locals("audit_details_extra"); extra != nil {
|
|
switch v := extra.(type) {
|
|
case map[string]string:
|
|
for key, value := range v {
|
|
details[key] = value
|
|
}
|
|
case map[string]any:
|
|
maps.Copy(details, v)
|
|
}
|
|
}
|
|
if skipTimeline, ok := c.Locals("auth_timeline_skip").(bool); ok && skipTimeline {
|
|
details["auth_timeline_skip"] = true
|
|
}
|
|
if sessionID != "" {
|
|
details["session_id"] = sessionID
|
|
}
|
|
if approvedSessionID, ok := c.Locals("approved_session_id").(string); ok && approvedSessionID != "" {
|
|
details["approved_session_id"] = approvedSessionID
|
|
}
|
|
if err != nil {
|
|
details["error"] = err.Error()
|
|
}
|
|
|
|
detailsJSON, _ := json.Marshal(details)
|
|
|
|
// 8. Create Audit Log Object
|
|
auditLog := &domain.AuditLog{
|
|
EventID: reqID,
|
|
Timestamp: start,
|
|
UserID: userID,
|
|
TenantID: tenantID,
|
|
SessionID: sessionID,
|
|
EventType: fmt.Sprintf("%s %s", c.Method(), c.Path()),
|
|
Status: statusText,
|
|
IPAddress: clientIP,
|
|
UserAgent: c.Get("User-Agent"),
|
|
Details: string(detailsJSON),
|
|
}
|
|
|
|
// 9. Store Log (Policy Enforcement)
|
|
_, isWrite := writeMethods[c.Method()]
|
|
_, allowRead := config.ReadPaths[c.Path()]
|
|
|
|
if isNil(config.Repo) {
|
|
if isWrite {
|
|
slog.Warn("Audit repository missing for command, proceeding without audit log", "req_id", reqID, "method", c.Method(), "path", c.Path())
|
|
}
|
|
return err
|
|
}
|
|
|
|
if isWrite {
|
|
// Strict Mode: Synchronous write
|
|
if createErr := config.Repo.Create(auditLog); createErr != nil {
|
|
slog.Error("Failed to write audit log (sync)", "error", createErr, "req_id", reqID)
|
|
return fiber.NewError(fiber.StatusServiceUnavailable, "Audit logging failed")
|
|
}
|
|
} else if allowRead {
|
|
// Best Effort: Load Shedding via Buffered Channel
|
|
select {
|
|
case auditQueue <- auditLog:
|
|
// Successfully queued
|
|
default:
|
|
// Queue full -> DROP (Load Shedding)
|
|
slog.Warn("Audit queue full, dropping log (load shedding)", "req_id", reqID, "path", c.Path())
|
|
}
|
|
}
|
|
|
|
return err
|
|
}
|
|
}
|
|
|
|
func extractClientIP(c *fiber.Ctx) string {
|
|
return utils.ResolveClientIP(c.Get("X-Forwarded-For"), c.Get("X-Real-IP"), c.IP())
|
|
}
|