package middleware import ( "baron-sso-backend/internal/domain" "encoding/json" "fmt" "log/slog" "reflect" "time" "github.com/gofiber/fiber/v2" "github.com/google/uuid" ) type AuditRequiredConfig struct { Repo domain.AuditRepository ExcludePaths map[string]struct{} CommandMethods map[string]struct{} } func isNil(i any) bool { if i == nil { return true } v := reflect.ValueOf(i) return v.Kind() == reflect.Ptr && v.IsNil() } func RequireAudit(config AuditRequiredConfig) fiber.Handler { commandMethods := config.CommandMethods if len(commandMethods) == 0 { commandMethods = map[string]struct{}{ fiber.MethodPost: {}, fiber.MethodPut: {}, fiber.MethodPatch: {}, fiber.MethodDelete: {}, } } excludePaths := config.ExcludePaths if excludePaths == nil { excludePaths = map[string]struct{}{} } return func(c *fiber.Ctx) error { if _, ok := commandMethods[c.Method()]; !ok { return c.Next() } if _, excluded := excludePaths[c.Path()]; excluded { return c.Next() } if isNil(config.Repo) { slog.Warn("audit repository is nil, skipping audit log creation", "path", c.Path()) return c.Next() // Don't block the request, just skip audit } start := time.Now() reqID := c.Get("X-Request-Id") if reqID == "" { reqID = uuid.New().String() c.Set("X-Request-Id", reqID) } err := c.Next() latency := time.Since(start) status := c.Response().StatusCode() if err != nil { if fiberErr, ok := err.(*fiber.Error); ok { status = fiberErr.Code } else { status = fiber.StatusInternalServerError } } statusText := "success" if status >= fiber.StatusBadRequest { statusText = "failure" } details := map[string]any{ "request_id": reqID, "method": c.Method(), "path": c.Path(), "status": status, "latency_ms": latency.Milliseconds(), } if err != nil { details["error"] = err.Error() } detailsJSON, jsonErr := json.Marshal(details) if jsonErr != nil { slog.Warn("failed to marshal audit details", "error", jsonErr, "req_id", reqID) } auditLog := &domain.AuditLog{ EventID: reqID, Timestamp: time.Now(), UserID: "", EventType: fmt.Sprintf("%s %s", c.Method(), c.Path()), Status: statusText, IPAddress: c.IP(), UserAgent: c.Get("User-Agent"), DeviceID: "", Details: string(detailsJSON), } if createErr := config.Repo.Create(auditLog); createErr != nil { slog.Error("audit log write failed", "error", createErr, "req_id", reqID, "path", c.Path()) return fiber.NewError(fiber.StatusServiceUnavailable, "audit logging unavailable") } return err } }