package middleware

import (
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"maps"
	"reflect"
	"strings"
	"sync"
	"time"

	"baron-sso-backend/internal/domain"
	"baron-sso-backend/internal/utils"

	"github.com/gofiber/fiber/v2"
	"github.com/google/uuid"
)

// AuditConfig configures AuditMiddleware.
type AuditConfig struct {
	// Repo persists audit records. When it is nil (including an interface
	// holding a typed nil), requests are served without audit logging and a
	// warning is emitted for write requests.
	Repo domain.AuditRepository
	// ExcludePaths lists exact request paths (as returned by fiber.Ctx.Path)
	// that are never audited.
	ExcludePaths map[string]struct{}
	// ReadPaths lists exact request paths whose non-write requests are audited
	// asynchronously on a best-effort basis. Other non-write requests are not
	// audited at all.
	ReadPaths map[string]struct{}
	// BodyDump, when true, records the request body (after
	// utils.MaskSensitiveJSON) for methods other than GET and HEAD.
	BodyDump bool
	// WorkerCount is the number of async writer goroutines; <= 0 means 5.
	WorkerCount int
	// QueueSize is the capacity of the async queue; <= 0 means 1000.
	QueueSize int
}

// isNil reports whether i is nil, or is an interface wrapping a nil value of a
// nilable kind (pointer, map, slice, func, chan, unsafe pointer). A plain
// `i == nil` check misses the latter "typed nil" case.
func isNil(i any) bool {
	if i == nil {
		return true
	}
	v := reflect.ValueOf(i)
	switch v.Kind() {
	case reflect.Pointer, reflect.Map, reflect.Slice, reflect.Func, reflect.Chan, reflect.UnsafePointer:
		return v.IsNil()
	}
	return false
}

// AuditMiddleware provides comprehensive audit logging for write requests by default.
// Read requests are skipped unless they are explicitly allowlisted in ReadPaths.
//
// Policy, applied after the downstream handlers have run:
//   - Paths in ExcludePaths are never audited.
//   - Write requests (POST, PUT, PATCH, DELETE) are audited synchronously. If the
//     repository write fails, the handler result is replaced by a 503 error.
//     Downstream handlers have already executed by then.
//   - Other requests are audited only when their path is in ReadPaths. Those
//     records go into a bounded queue drained by background workers, and are
//     dropped with a warning when the queue is full.
//
// The background workers start lazily, the first time a record is queued, so
// they never start for instances without ReadPaths or without a repository.
// Once started they run for the rest of the process; there is no shutdown hook.
func AuditMiddleware(config AuditConfig) fiber.Handler {
	if config.WorkerCount <= 0 {
		config.WorkerCount = 5 // Default workers
	}
	if config.QueueSize <= 0 {
		config.QueueSize = 1000 // Default queue size
	}

	// Snapshot the path sets so a later mutation by the caller cannot race with
	// the concurrent per-request lookups. Lookups on a nil map are safe.
	excludePaths := maps.Clone(config.ExcludePaths)
	readPaths := maps.Clone(config.ReadPaths)
	// config is a private copy, so the repository cannot change after this point.
	repoMissing := isNil(config.Repo)

	auditQueue := make(chan *domain.AuditLog, config.QueueSize)
	var startWorkers sync.Once

	return func(c *fiber.Ctx) error {
		// 1. Check exclusions.
		if _, excluded := excludePaths[c.Path()]; excluded {
			return c.Next()
		}

		// 2. Set up the request ID and timing.
		// NOTE(review): when the header is present, reqID is client-controlled and
		// becomes the audit EventID. If EventID must be unique in storage, a reused
		// value could make the synchronous write fail after the handler already
		// ran. Confirm this against the repository schema.
		start := time.Now()
		reqID := c.Get("X-Request-Id")
		if reqID == "" {
			reqID = uuid.New().String()
			c.Set("X-Request-Id", reqID)
		}

		// 3. Process the request.
		err := c.Next()
		latency := time.Since(start)

		// 4. Decide the policy before doing any expensive work.
		method := c.Method()
		path := c.Path()
		isWrite := isAuditWriteMethod(method)
		_, allowRead := readPaths[path]

		if repoMissing {
			if isWrite {
				slog.Warn("Audit repository missing for command, proceeding without audit log", "req_id", reqID, "method", method, "path", path)
			}
			return err
		}
		if !isWrite && !allowRead {
			return err
		}

		// 5. Build the record.
		auditLog := buildRequestAuditLog(c, reqID, start, latency, err, config.BodyDump)

		// 6. Store it.
		if isWrite {
			// Strict mode: synchronous write; a failure is surfaced to the client.
			if createErr := config.Repo.Create(auditLog); createErr != nil {
				slog.Error("Failed to write audit log (sync)", "error", createErr, "req_id", reqID)
				return fiber.NewError(fiber.StatusServiceUnavailable, "Audit logging failed")
			}
			return err
		}

		// Best effort: load shedding via the buffered channel.
		startWorkers.Do(func() {
			startAuditWorkers(config.Repo, auditQueue, config.WorkerCount)
		})
		select {
		case auditQueue <- auditLog:
			// Successfully queued.
		default:
			// Queue full: drop the record rather than block the request.
			slog.Warn("Audit queue full, dropping log (load shedding)", "req_id", reqID, "path", path)
		}
		return err
	}
}

// buildRequestAuditLog assembles the audit record for the request in c.
// handlerErr is the error returned by the downstream handler chain. Every
// string stored on the record is either cloned or freshly built, because
// fiber.Ctx values (headers, path, IP) are only valid until the handler
// returns, while queued records are persisted later by a worker.
func buildRequestAuditLog(c *fiber.Ctx, reqID string, start time.Time, latency time.Duration, handlerErr error, bodyDump bool) *domain.AuditLog {
	status := auditResponseStatus(c.Response().StatusCode(), handlerErr)
	statusText := "success"
	if status >= fiber.StatusBadRequest {
		statusText = "failure"
	}

	// User context populated by upstream middleware (e.g. AuthMiddleware/TenantGuard).
	userID, _ := c.Locals("user_id").(string)
	loginID, _ := c.Locals("login_id").(string)
	tenantID, _ := c.Locals("tenant_id").(string)
	sessionID, _ := c.Locals("session_id").(string)

	method := c.Method()
	path := c.Path()

	var maskedBody string
	if bodyDump && method != fiber.MethodGet && method != fiber.MethodHead {
		if body := c.Body(); len(body) > 0 {
			maskedBody = string(utils.MaskSensitiveJSON(body))
		}
	}

	details := map[string]any{
		"request_id":   reqID,
		"method":       method,
		"path":         path,
		"status":       status,
		"latency_ms":   latency.Milliseconds(),
		"login_id":     loginID,
		"tenant_id":    tenantID,
		"request_body": maskedBody,
	}
	// Merge additional details supplied by the handler.
	mergeAuditDetailsExtra(details, c.Locals("audit_details_extra"))
	if skipTimeline, ok := c.Locals("auth_timeline_skip").(bool); ok && skipTimeline {
		details["auth_timeline_skip"] = true
	}
	if sessionID != "" {
		details["session_id"] = sessionID
	}
	if approvedSessionID, ok := c.Locals("approved_session_id").(string); ok && approvedSessionID != "" {
		details["approved_session_id"] = approvedSessionID
	}
	if handlerErr != nil {
		details["error"] = handlerErr.Error()
	}

	detailsJSON, marshalErr := json.Marshal(details)
	if marshalErr != nil {
		// Handler-supplied extras can hold unmarshalable values (NaN, funcs,
		// chans). Keep the core request context instead of storing nothing.
		slog.Warn("Failed to marshal audit details, recording core fields only", "error", marshalErr, "req_id", reqID)
		// The fallback holds only strings and ints, which always marshal.
		detailsJSON, _ = json.Marshal(map[string]any{
			"request_id":    reqID,
			"method":        method,
			"path":          path,
			"status":        status,
			"details_error": marshalErr.Error(),
		})
	}

	return &domain.AuditLog{
		EventID:   strings.Clone(reqID),
		Timestamp: start,
		UserID:    strings.Clone(userID),
		TenantID:  strings.Clone(tenantID),
		SessionID: strings.Clone(sessionID),
		EventType: fmt.Sprintf("%s %s", method, path), // Sprintf allocates a fresh string.
		Status:    statusText,
		IPAddress: strings.Clone(extractClientIP(c)),
		UserAgent: strings.Clone(c.Get("User-Agent")),
		Details:   string(detailsJSON),
	}
}

// auditResponseStatus returns the status code to record. When the handler
// chain returned an error, the response status may not reflect it yet. The
// code is then taken from a *fiber.Error anywhere in the error chain, and
// defaults to 500 otherwise.
func auditResponseStatus(responseStatus int, err error) int {
	if err == nil {
		return responseStatus
	}
	var fiberErr *fiber.Error
	if errors.As(err, &fiberErr) {
		return fiberErr.Code
	}
	return fiber.StatusInternalServerError
}

// mergeAuditDetailsExtra copies handler-supplied details into details,
// overwriting existing keys. extra is the "audit_details_extra" local, and is
// used when it is a map[string]string or map[string]any. Any other value,
// including nil, is ignored.
func mergeAuditDetailsExtra(details map[string]any, extra any) {
	switch v := extra.(type) {
	case map[string]string:
		for key, value := range v {
			details[key] = value
		}
	case map[string]any:
		maps.Copy(details, v)
	}
}

// isAuditWriteMethod reports whether method is a state-changing HTTP method
// that must be audited synchronously.
func isAuditWriteMethod(method string) bool {
	switch method {
	case fiber.MethodPost, fiber.MethodPut, fiber.MethodPatch, fiber.MethodDelete:
		return true
	}
	return false
}

// startAuditWorkers launches count goroutines that drain queue into repo.
// The workers exit only if queue is closed.
func startAuditWorkers(repo domain.AuditRepository, queue <-chan *domain.AuditLog, count int) {
	for workerID := 0; workerID < count; workerID++ {
		go func(id int) {
			slog.Debug("Audit worker started", "id", id)
			for entry := range queue {
				writeQueuedAuditLog(repo, entry)
			}
		}(workerID)
	}
}

// writeQueuedAuditLog persists one queued record. An error or panic is logged
// so that a single bad record cannot kill its worker.
func writeQueuedAuditLog(repo domain.AuditRepository, entry *domain.AuditLog) {
	defer func() {
		if r := recover(); r != nil {
			slog.Error("Audit worker panic recovery", "reason", r, "req_id", entry.EventID)
		}
	}()
	if err := repo.Create(entry); err != nil {
		slog.Warn("Failed to write async audit log", "error", err, "req_id", entry.EventID)
	}
}

// extractClientIP resolves the client address from the X-Forwarded-For and
// X-Real-IP headers and the connection IP, via utils.ResolveClientIP.
func extractClientIP(c *fiber.Ctx) string {
	return utils.ResolveClientIP(c.Get("X-Forwarded-For"), c.Get("X-Real-IP"), c.IP())
}