package middleware import ( "baron-sso-backend/internal/domain" "context" "encoding/json" "errors" "net/http/httptest" "strings" "sync" "testing" "time" "github.com/gofiber/fiber/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) // MockAuditRepository is a mock implementation of AuditRepository type MockAuditRepository struct { mock.Mock } func (m *MockAuditRepository) Create(log *domain.AuditLog) error { args := m.Called(log) return args.Error(0) } func (m *MockAuditRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor, tenantID string) ([]domain.AuditLog, error) { args := m.Called(ctx, limit, cursor, tenantID) return args.Get(0).([]domain.AuditLog), args.Error(1) } func (m *MockAuditRepository) FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]domain.AuditLog, error) { args := m.Called(ctx, userID, eventTypes, limit) return args.Get(0).([]domain.AuditLog), args.Error(1) } func (m *MockAuditRepository) CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error) { return 0, nil } func (m *MockAuditRepository) CountEventsSince(ctx context.Context, since time.Time) (int64, error) { return 0, nil } func (m *MockAuditRepository) CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error) { return 0, nil } func (m *MockAuditRepository) Ping(ctx context.Context) error { args := m.Called(ctx) return args.Error(0) } type recordingAuditRepository struct { mu sync.Mutex count int } func (r *recordingAuditRepository) Create(log *domain.AuditLog) error { r.mu.Lock() defer r.mu.Unlock() r.count++ return nil } func (r *recordingAuditRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor, tenantID string) ([]domain.AuditLog, error) { return nil, nil } func (r *recordingAuditRepository) FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]domain.AuditLog, error) { return nil, nil } func (r *recordingAuditRepository) CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error) { return 0, nil } func (r *recordingAuditRepository) CountEventsSince(ctx context.Context, since time.Time) (int64, error) { return 0, nil } func (r *recordingAuditRepository) CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error) { return 0, nil } func (r *recordingAuditRepository) Ping(ctx context.Context) error { return nil } func (r *recordingAuditRepository) Calls() int { r.mu.Lock() defer r.mu.Unlock() return r.count } func TestAuditMiddleware(t *testing.T) { t.Run("POST request - Sync Success", func(t *testing.T) { app := fiber.New() mockRepo := new(MockAuditRepository) app.Use(AuditMiddleware(AuditConfig{ Repo: mockRepo, BodyDump: true, })) app.Post("/test", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) mockRepo.On("Create", mock.MatchedBy(func(log *domain.AuditLog) bool { var details map[string]any json.Unmarshal([]byte(log.Details), &details) return log.Status == "success" && details["method"] == "POST" && details["request_body"] == `{"password":"*****","user":"test"}` })).Return(nil) req := httptest.NewRequest("POST", "/test", strings.NewReader(`{"user": "test", "password": "mypassword"}`)) req.Header.Set("Content-Type", "application/json") resp, _ := app.Test(req) assert.Equal(t, fiber.StatusOK, resp.StatusCode) mockRepo.AssertExpectations(t) }) t.Run("POST request - Merge extra audit details", func(t *testing.T) { app := fiber.New() mockRepo := new(MockAuditRepository) app.Use(AuditMiddleware(AuditConfig{ Repo: mockRepo, })) app.Post("/test", func(c *fiber.Ctx) error { c.Locals("tenant_id", "tenant-a") c.Locals("audit_details_extra", map[string]any{ "client_id": "rp-1", "client_name": "Demo App", }) c.Locals("auth_timeline_skip", true) return c.SendStatus(fiber.StatusOK) }) mockRepo.On("Create", mock.MatchedBy(func(log *domain.AuditLog) bool { var details map[string]any if err := json.Unmarshal([]byte(log.Details), &details); err != nil { return false } if details["client_id"] != "rp-1" { return false } if details["client_name"] != "Demo App" { return false } if log.TenantID != "tenant-a" || details["tenant_id"] != "tenant-a" { return false } skip, ok := details["auth_timeline_skip"].(bool) return ok && skip })).Return(nil) req := httptest.NewRequest("POST", "/test", nil) resp, _ := app.Test(req) assert.Equal(t, fiber.StatusOK, resp.StatusCode) mockRepo.AssertExpectations(t) }) t.Run("POST request - Prefer public forwarded IP", func(t *testing.T) { app := fiber.New() mockRepo := new(MockAuditRepository) app.Use(AuditMiddleware(AuditConfig{ Repo: mockRepo, })) app.Post("/test", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) mockRepo.On("Create", mock.MatchedBy(func(log *domain.AuditLog) bool { return log.IPAddress == "203.0.113.25" })).Return(nil) req := httptest.NewRequest("POST", "/test", nil) req.Header.Set("X-Forwarded-For", "100.100.100.1, 203.0.113.25") resp, _ := app.Test(req) assert.Equal(t, fiber.StatusOK, resp.StatusCode) mockRepo.AssertExpectations(t) }) t.Run("POST request - Sync Failure (Strict Mode)", func(t *testing.T) { app := fiber.New() mockRepo := new(MockAuditRepository) app.Use(AuditMiddleware(AuditConfig{ Repo: mockRepo, })) app.Post("/test", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) mockRepo.On("Create", mock.Anything).Return(errors.New("db error")) req := httptest.NewRequest("POST", "/test", nil) resp, _ := app.Test(req) // Should return 503 because Audit failed on a Write method assert.Equal(t, fiber.StatusServiceUnavailable, resp.StatusCode) }) t.Run("GET request - Async Load Shedding", func(t *testing.T) { app := fiber.New() mockRepo := new(MockAuditRepository) // Set very small queue and no workers to force load shedding app.Use(AuditMiddleware(AuditConfig{ Repo: mockRepo, QueueSize: 1, WorkerCount: 0, // This will be defaulted to 5 by the code, so let's use another way or just small queue })) app.Get("/test", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) // 1. First request fills the queue mockRepo.On("Create", mock.Anything).Return(nil) req1 := httptest.NewRequest("GET", "/test", nil) resp1, _ := app.Test(req1) assert.Equal(t, fiber.StatusOK, resp1.StatusCode) // 2. Second request should be dropped (load shedding) if workers are slow // Since we can't easily pause workers without modifying code, // this test mostly ensures the non-blocking send doesn't hang. req2 := httptest.NewRequest("GET", "/test", nil) resp2, _ := app.Test(req2) assert.Equal(t, fiber.StatusOK, resp2.StatusCode) }) t.Run("GET request - Read audit is skipped by default", func(t *testing.T) { app := fiber.New() repo := &recordingAuditRepository{} app.Use(AuditMiddleware(AuditConfig{ Repo: repo, })) app.Get("/test", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) req := httptest.NewRequest("GET", "/test", nil) resp, _ := app.Test(req) assert.Equal(t, fiber.StatusOK, resp.StatusCode) time.Sleep(50 * time.Millisecond) assert.Equal(t, 0, repo.Calls()) }) }