1
0
forked from baron/baron-sso
Files
baron-sso/backend/internal/middleware/audit_middleware.go

193 lines
4.8 KiB
Go

package middleware
import (
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"reflect"
	"strings"
	"sync"
	"time"

	"baron-sso-backend/internal/domain"
	"baron-sso-backend/internal/utils"

	"github.com/gofiber/fiber/v2"
	"github.com/google/uuid"
)
// AuditConfig configures AuditMiddleware.
type AuditConfig struct {
// Repo persists audit records. If it is nil (including an interface holding
// a typed-nil pointer), write commands fail with 503 Service Unavailable and
// other requests proceed without an audit record.
Repo domain.AuditRepository
// ExcludePaths lists exact request paths (as returned by fiber.Ctx.Path)
// that bypass auditing entirely. A nil map means no exclusions.
ExcludePaths map[string]struct{}
// BodyDump enables recording the request body (passed through
// utils.MaskSensitiveJSON) for every method except GET and HEAD.
BodyDump bool
// WorkerCount is the number of goroutines draining the best-effort
// (non-command) audit queue. Values <= 0 fall back to 5.
WorkerCount int
// QueueSize is the capacity of the best-effort audit queue. Values <= 0
// fall back to 1000. When the queue is full, records are dropped.
QueueSize int
}
// isNil reports whether i is nil, including the "typed nil" case: an
// interface value that holds a nil pointer, map, slice, channel, function,
// interface or unsafe pointer. A plain `i == nil` comparison misses those,
// because an interface carrying a typed nil is itself non-nil.
func isNil(i any) bool {
	if i == nil {
		return true
	}
	v := reflect.ValueOf(i)
	switch v.Kind() {
	case reflect.Ptr, reflect.Map, reflect.Slice, reflect.Chan,
		reflect.Func, reflect.Interface, reflect.UnsafePointer:
		// These are exactly the kinds for which reflect.Value.IsNil is defined.
		return v.IsNil()
	default:
		// Value kinds (ints, strings, structs, ...) can never be nil.
		return false
	}
}
// AuditMiddleware provides comprehensive audit logging for all requests.
//
// It enforces strict logging for state-changing commands (POST, PUT, DELETE,
// PATCH) and best-effort logging for every other method (GET, HEAD, OPTIONS, ...):
//
//   - Commands are audited synchronously after the handler returns. If the
//     record cannot be written, the request fails with 503. If no repository
//     is configured, commands are rejected with 503 before the handler runs,
//     so no command executes without an audit trail.
//   - Other requests are pushed onto a bounded queue (config.QueueSize,
//     default 1000) drained by config.WorkerCount goroutines (default 5).
//     When the queue is full the record is dropped and a warning is logged
//     (load shedding). With no repository configured they run unaudited.
//
// Requests whose path is in config.ExcludePaths bypass auditing entirely. The
// X-Request-Id request header, if present, becomes the event ID; otherwise a
// UUID is generated and returned in the X-Request-Id response header.
//
// NOTE: a strict-mode audit failure is detected after the handler has run, so
// the command's effects may already be applied even though the client gets 503.
//
// NOTE: the worker goroutines have no shutdown hook and live for the rest of
// the process; build this middleware once at startup, not per request.
func AuditMiddleware(config AuditConfig) fiber.Handler {
	if config.WorkerCount <= 0 {
		config.WorkerCount = 5 // Default workers
	}
	if config.QueueSize <= 0 {
		config.QueueSize = 1000 // Default queue size
	}
	if config.ExcludePaths == nil {
		config.ExcludePaths = map[string]struct{}{}
	}

	// config is a private copy, so Repo cannot change after construction: do
	// the reflection-based nil check once instead of on every request.
	repoMissing := isNil(config.Repo)

	// Best-effort queue plus a lazily started worker pool. Goroutines are only
	// spawned the first time a record is queued, so a handler without a
	// repository (or one that never serves non-command requests) starts none.
	auditQueue := make(chan *domain.AuditLog, config.QueueSize)
	var workersStarted sync.Once
	startWorkers := func() {
		for i := 0; i < config.WorkerCount; i++ {
			go runAuditWorker(i, config.Repo, auditQueue)
		}
	}

	// Methods audited in strict (synchronous) mode.
	writeMethods := map[string]struct{}{
		fiber.MethodPost:   {},
		fiber.MethodPut:    {},
		fiber.MethodPatch:  {},
		fiber.MethodDelete: {},
	}

	return func(c *fiber.Ctx) error {
		// 1. Check exclusions.
		if _, excluded := config.ExcludePaths[c.Path()]; excluded {
			return c.Next()
		}

		// 2. Set up request identity. Fiber/fasthttp may reuse request buffers
		// once the handler returns, so any string that can reach an async
		// worker is cloned.
		start := time.Now()
		reqID := strings.Clone(c.Get("X-Request-Id"))
		if reqID == "" {
			reqID = uuid.New().String()
			c.Set("X-Request-Id", reqID)
		}
		_, isWrite := writeMethods[c.Method()]

		// 3. Without a repository, refuse commands up front rather than
		// letting them change state unaudited; other requests run unaudited.
		if repoMissing {
			if isWrite {
				slog.Error("Audit repository missing for command", "req_id", reqID)
				return fiber.NewError(fiber.StatusServiceUnavailable, "Audit system unavailable")
			}
			return c.Next()
		}

		// 4. Process the request.
		err := c.Next()

		// 5. Gather metrics and outcome.
		latency := time.Since(start)
		status := resolveAuditStatus(c, err)
		statusText := "success"
		if status >= fiber.StatusBadRequest {
			statusText = "failure"
		}

		// 6. Extract user context (populated by AuthMiddleware/TenantGuard).
		userID, _ := c.Locals("user_id").(string)
		loginID, _ := c.Locals("login_id").(string)
		tenantID, _ := c.Locals("tenant_id").(string)

		// 7. Capture and mask the body for methods that carry one.
		var maskedBody string
		if config.BodyDump && c.Method() != fiber.MethodGet && c.Method() != fiber.MethodHead {
			if bodyBytes := c.Body(); len(bodyBytes) > 0 {
				// The string conversion copies, so the result is safe to keep.
				maskedBody = string(utils.MaskSensitiveJSON(bodyBytes))
			}
		}

		// 8. Build the details JSON (serialized now, while ctx values are valid).
		details := map[string]any{
			"request_id":   reqID,
			"method":       c.Method(),
			"path":         c.Path(),
			"status":       status,
			"latency_ms":   latency.Milliseconds(),
			"login_id":     loginID,
			"tenant_id":    tenantID,
			"request_body": maskedBody,
		}
		if err != nil {
			details["error"] = err.Error()
		}
		// Every value is a string or an integer, which json.Marshal always
		// encodes, so the error cannot occur here.
		detailsJSON, _ := json.Marshal(details)

		// 9. Create the audit record. Strings taken straight from the context
		// are cloned because the record may outlive this handler call.
		auditLog := &domain.AuditLog{
			EventID:   reqID,
			Timestamp: start,
			UserID:    strings.Clone(userID),
			EventType: fmt.Sprintf("%s %s", c.Method(), c.Path()),
			Status:    statusText,
			IPAddress: strings.Clone(c.IP()),
			UserAgent: strings.Clone(c.Get("User-Agent")),
			Details:   string(detailsJSON),
		}

		// 10a. Strict mode: the command only succeeds if its record is stored.
		if isWrite {
			if createErr := config.Repo.Create(auditLog); createErr != nil {
				slog.Error("Failed to write audit log (sync)", "error", createErr, "req_id", reqID)
				return fiber.NewError(fiber.StatusServiceUnavailable, "Audit logging failed")
			}
			return err
		}

		// 10b. Best effort: non-blocking enqueue, drop when the queue is full.
		workersStarted.Do(startWorkers)
		select {
		case auditQueue <- auditLog:
			// Queued; a worker will persist it.
		default:
			slog.Warn("Audit queue full, dropping log (load shedding)", "req_id", reqID, "path", c.Path())
		}
		return err
	}
}

// runAuditWorker persists records from queue via repo for as long as the
// queue is open (AuditMiddleware never closes it). Write failures are logged
// and dropped; a panic while writing one record is recovered so the worker
// keeps serving the queue.
func runAuditWorker(workerID int, repo domain.AuditRepository, queue <-chan *domain.AuditLog) {
	slog.Debug("Audit worker started", "id", workerID)
	for entry := range queue {
		func() {
			defer func() {
				if r := recover(); r != nil {
					slog.Error("Audit worker panic recovery", "reason", r, "req_id", entry.EventID)
				}
			}()
			if err := repo.Create(entry); err != nil {
				slog.Warn("Failed to write async audit log", "error", err, "req_id", entry.EventID)
			}
		}()
	}
}

// resolveAuditStatus returns the HTTP status to record for a request. When
// the handler returned an error, the response status may not reflect it yet,
// so the status is derived from the error instead: the code of a (possibly
// wrapped) *fiber.Error, otherwise 500.
func resolveAuditStatus(c *fiber.Ctx, err error) int {
	if err == nil {
		return c.Response().StatusCode()
	}
	var fiberErr *fiber.Error
	if errors.As(err, &fiberErr) {
		return fiberErr.Code
	}
	return fiber.StatusInternalServerError
}