forked from baron/baron-sso
Resolve merge conflicts with main
This commit is contained in:
192
backend/internal/middleware/audit_middleware.go
Normal file
192
backend/internal/middleware/audit_middleware.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/utils"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditConfig struct {
|
||||
Repo domain.AuditRepository
|
||||
ExcludePaths map[string]struct{}
|
||||
BodyDump bool
|
||||
WorkerCount int
|
||||
QueueSize int
|
||||
}
|
||||
|
||||
func isNil(i any) bool {
|
||||
if i == nil {
|
||||
return true
|
||||
}
|
||||
v := reflect.ValueOf(i)
|
||||
return v.Kind() == reflect.Ptr && v.IsNil()
|
||||
}
|
||||
|
||||
// AuditMiddleware provides comprehensive audit logging for all requests.
|
||||
// It enforces strict logging for state-changing commands (POST, PUT, DELETE, PATCH)
|
||||
// and best-effort logging for queries (GET, HEAD, OPTIONS).
|
||||
func AuditMiddleware(config AuditConfig) fiber.Handler {
|
||||
// 0. Initialize Worker Pool for Async Logging
|
||||
if config.WorkerCount <= 0 {
|
||||
config.WorkerCount = 5 // Default workers
|
||||
}
|
||||
if config.QueueSize <= 0 {
|
||||
config.QueueSize = 1000 // Default queue size
|
||||
}
|
||||
|
||||
auditQueue := make(chan *domain.AuditLog, config.QueueSize)
|
||||
var once sync.Once
|
||||
|
||||
// Start workers only once
|
||||
once.Do(func() {
|
||||
for i := 0; i < config.WorkerCount; i++ {
|
||||
go func(workerID int) {
|
||||
slog.Debug("Audit worker started", "id", workerID)
|
||||
for log := range auditQueue {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("Audit worker panic recovery", "reason", r, "req_id", log.EventID)
|
||||
}
|
||||
}()
|
||||
if err := config.Repo.Create(log); err != nil {
|
||||
slog.Warn("Failed to write async audit log", "error", err, "req_id", log.EventID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
})
|
||||
|
||||
// Default methods classification
|
||||
writeMethods := map[string]struct{}{
|
||||
fiber.MethodPost: {},
|
||||
fiber.MethodPut: {},
|
||||
fiber.MethodPatch: {},
|
||||
fiber.MethodDelete: {},
|
||||
}
|
||||
|
||||
if config.ExcludePaths == nil {
|
||||
config.ExcludePaths = map[string]struct{}{}
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
// 1. Check exclusions
|
||||
if _, excluded := config.ExcludePaths[c.Path()]; excluded {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// 2. Setup context variables
|
||||
start := time.Now()
|
||||
reqID := c.Get("X-Request-Id")
|
||||
if reqID == "" {
|
||||
reqID = uuid.New().String()
|
||||
c.Set("X-Request-Id", reqID)
|
||||
}
|
||||
|
||||
// 3. Process Request
|
||||
err := c.Next()
|
||||
|
||||
// 4. Gather Metrics & Context
|
||||
latency := time.Since(start)
|
||||
status := c.Response().StatusCode()
|
||||
|
||||
// If Fiber handler returned an error, status might default to 500 or be in the error
|
||||
if err != nil {
|
||||
if fiberErr, ok := err.(*fiber.Error); ok {
|
||||
status = fiberErr.Code
|
||||
} else {
|
||||
status = fiber.StatusInternalServerError
|
||||
}
|
||||
}
|
||||
|
||||
statusText := "success"
|
||||
if status >= fiber.StatusBadRequest {
|
||||
statusText = "failure"
|
||||
}
|
||||
|
||||
// 5. Extract User Context (populated by AuthMiddleware/TenantGuard)
|
||||
userID, _ := c.Locals("user_id").(string)
|
||||
loginID, _ := c.Locals("login_id").(string)
|
||||
tenantID, _ := c.Locals("tenant_id").(string)
|
||||
|
||||
// 6. Capture & Mask Body
|
||||
var maskedBody string
|
||||
if config.BodyDump {
|
||||
if c.Method() != fiber.MethodGet && c.Method() != fiber.MethodHead {
|
||||
bodyBytes := c.Body()
|
||||
if len(bodyBytes) > 0 {
|
||||
maskedBytes := utils.MaskSensitiveJSON(bodyBytes)
|
||||
maskedBody = string(maskedBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 7. Construct Details JSON
|
||||
details := map[string]any{
|
||||
"request_id": reqID,
|
||||
"method": c.Method(),
|
||||
"path": c.Path(),
|
||||
"status": status,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
"login_id": loginID,
|
||||
"tenant_id": tenantID,
|
||||
"request_body": maskedBody,
|
||||
}
|
||||
if err != nil {
|
||||
details["error"] = err.Error()
|
||||
}
|
||||
|
||||
detailsJSON, _ := json.Marshal(details)
|
||||
|
||||
// 8. Create Audit Log Object
|
||||
auditLog := &domain.AuditLog{
|
||||
EventID: reqID,
|
||||
Timestamp: start,
|
||||
UserID: userID,
|
||||
EventType: fmt.Sprintf("%s %s", c.Method(), c.Path()),
|
||||
Status: statusText,
|
||||
IPAddress: c.IP(),
|
||||
UserAgent: c.Get("User-Agent"),
|
||||
Details: string(detailsJSON),
|
||||
}
|
||||
|
||||
// 9. Store Log (Policy Enforcement)
|
||||
_, isWrite := writeMethods[c.Method()]
|
||||
|
||||
if isNil(config.Repo) {
|
||||
if isWrite {
|
||||
slog.Error("Audit repository missing for command", "req_id", reqID)
|
||||
return fiber.NewError(fiber.StatusServiceUnavailable, "Audit system unavailable")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if isWrite {
|
||||
// Strict Mode: Synchronous write
|
||||
if createErr := config.Repo.Create(auditLog); createErr != nil {
|
||||
slog.Error("Failed to write audit log (sync)", "error", createErr, "req_id", reqID)
|
||||
return fiber.NewError(fiber.StatusServiceUnavailable, "Audit logging failed")
|
||||
}
|
||||
} else {
|
||||
// Best Effort: Load Shedding via Buffered Channel
|
||||
select {
|
||||
case auditQueue <- auditLog:
|
||||
// Successfully queued
|
||||
default:
|
||||
// Queue full -> DROP (Load Shedding)
|
||||
slog.Warn("Audit queue full, dropping log (load shedding)", "req_id", reqID, "path", c.Path())
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
117
backend/internal/middleware/audit_middleware_test.go
Normal file
117
backend/internal/middleware/audit_middleware_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// MockAuditRepository is a mock implementation of AuditRepository
|
||||
type MockAuditRepository struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockAuditRepository) Create(log *domain.AuditLog) error {
|
||||
args := m.Called(log)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAuditRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor) ([]domain.AuditLog, error) {
|
||||
args := m.Called(ctx, limit, cursor)
|
||||
return args.Get(0).([]domain.AuditLog), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAuditRepository) Ping(ctx context.Context) error {
|
||||
args := m.Called(ctx)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestAuditMiddleware(t *testing.T) {
|
||||
t.Run("POST request - Sync Success", func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockRepo := new(MockAuditRepository)
|
||||
|
||||
app.Use(AuditMiddleware(AuditConfig{
|
||||
Repo: mockRepo,
|
||||
BodyDump: true,
|
||||
}))
|
||||
|
||||
app.Post("/test", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
mockRepo.On("Create", mock.MatchedBy(func(log *domain.AuditLog) bool {
|
||||
var details map[string]any
|
||||
json.Unmarshal([]byte(log.Details), &details)
|
||||
return log.Status == "success" &&
|
||||
details["method"] == "POST" &&
|
||||
details["request_body"] == `{"password":"*****","user":"test"}`
|
||||
})).Return(nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", strings.NewReader(`{"user": "test", "password": "mypassword"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
mockRepo.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("POST request - Sync Failure (Strict Mode)", func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockRepo := new(MockAuditRepository)
|
||||
|
||||
app.Use(AuditMiddleware(AuditConfig{
|
||||
Repo: mockRepo,
|
||||
}))
|
||||
|
||||
app.Post("/test", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
mockRepo.On("Create", mock.Anything).Return(errors.New("db error"))
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
// Should return 503 because Audit failed on a Write method
|
||||
assert.Equal(t, fiber.StatusServiceUnavailable, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("GET request - Async Load Shedding", func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockRepo := new(MockAuditRepository)
|
||||
|
||||
// Set very small queue and no workers to force load shedding
|
||||
app.Use(AuditMiddleware(AuditConfig{
|
||||
Repo: mockRepo,
|
||||
QueueSize: 1,
|
||||
WorkerCount: 0, // This will be defaulted to 5 by the code, so let's use another way or just small queue
|
||||
}))
|
||||
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
// 1. First request fills the queue
|
||||
mockRepo.On("Create", mock.Anything).Return(nil)
|
||||
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
resp1, _ := app.Test(req1)
|
||||
assert.Equal(t, fiber.StatusOK, resp1.StatusCode)
|
||||
|
||||
// 2. Second request should be dropped (load shedding) if workers are slow
|
||||
// Since we can't easily pause workers without modifying code,
|
||||
// this test mostly ensures the non-blocking send doesn't hang.
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
resp2, _ := app.Test(req2)
|
||||
assert.Equal(t, fiber.StatusOK, resp2.StatusCode)
|
||||
})
|
||||
}
|
||||
@@ -1,117 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditRequiredConfig struct {
|
||||
Repo domain.AuditRepository
|
||||
ExcludePaths map[string]struct{}
|
||||
CommandMethods map[string]struct{}
|
||||
}
|
||||
|
||||
func isNil(i any) bool {
|
||||
if i == nil {
|
||||
return true
|
||||
}
|
||||
v := reflect.ValueOf(i)
|
||||
return v.Kind() == reflect.Ptr && v.IsNil()
|
||||
}
|
||||
|
||||
func RequireAudit(config AuditRequiredConfig) fiber.Handler {
|
||||
commandMethods := config.CommandMethods
|
||||
if len(commandMethods) == 0 {
|
||||
commandMethods = map[string]struct{}{
|
||||
fiber.MethodPost: {},
|
||||
fiber.MethodPut: {},
|
||||
fiber.MethodPatch: {},
|
||||
fiber.MethodDelete: {},
|
||||
}
|
||||
}
|
||||
|
||||
excludePaths := config.ExcludePaths
|
||||
if excludePaths == nil {
|
||||
excludePaths = map[string]struct{}{}
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
if _, ok := commandMethods[c.Method()]; !ok {
|
||||
return c.Next()
|
||||
}
|
||||
if _, excluded := excludePaths[c.Path()]; excluded {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
if isNil(config.Repo) {
|
||||
slog.Warn("audit repository is nil, skipping audit log creation", "path", c.Path())
|
||||
return c.Next() // Don't block the request, just skip audit
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
reqID := c.Get("X-Request-Id")
|
||||
if reqID == "" {
|
||||
reqID = uuid.New().String()
|
||||
c.Set("X-Request-Id", reqID)
|
||||
}
|
||||
|
||||
err := c.Next()
|
||||
latency := time.Since(start)
|
||||
|
||||
status := c.Response().StatusCode()
|
||||
if err != nil {
|
||||
if fiberErr, ok := err.(*fiber.Error); ok {
|
||||
status = fiberErr.Code
|
||||
} else {
|
||||
status = fiber.StatusInternalServerError
|
||||
}
|
||||
}
|
||||
|
||||
statusText := "success"
|
||||
if status >= fiber.StatusBadRequest {
|
||||
statusText = "failure"
|
||||
}
|
||||
|
||||
details := map[string]any{
|
||||
"request_id": reqID,
|
||||
"method": c.Method(),
|
||||
"path": c.Path(),
|
||||
"status": status,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
}
|
||||
if err != nil {
|
||||
details["error"] = err.Error()
|
||||
}
|
||||
|
||||
detailsJSON, jsonErr := json.Marshal(details)
|
||||
if jsonErr != nil {
|
||||
slog.Warn("failed to marshal audit details", "error", jsonErr, "req_id", reqID)
|
||||
}
|
||||
|
||||
auditLog := &domain.AuditLog{
|
||||
EventID: reqID,
|
||||
Timestamp: time.Now(),
|
||||
UserID: "",
|
||||
EventType: fmt.Sprintf("%s %s", c.Method(), c.Path()),
|
||||
Status: statusText,
|
||||
IPAddress: c.IP(),
|
||||
UserAgent: c.Get("User-Agent"),
|
||||
DeviceID: "",
|
||||
Details: string(detailsJSON),
|
||||
}
|
||||
|
||||
if createErr := config.Repo.Create(auditLog); createErr != nil {
|
||||
slog.Error("audit log write failed", "error", createErr, "req_id", reqID, "path", c.Path())
|
||||
return fiber.NewError(fiber.StatusServiceUnavailable, "audit logging unavailable")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user