1
0
forked from baron/baron-sso
Files
baron-sso/backend/internal/middleware/audit_middleware_test.go

117 lines
3.3 KiB
Go

package middleware
import (
"baron-sso-backend/internal/domain"
"context"
"encoding/json"
"errors"
"net/http/httptest"
"strings"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// MockAuditRepository is a testify-backed test double standing in for the
// AuditRepository. Each of its methods records the call on the embedded
// mock.Mock and returns whatever the test configured via On(...).Return(...).
type MockAuditRepository struct {
	mock.Mock
}
// Create registers a call with the given audit log entry and hands back the
// error stubbed as the first return value for "Create".
func (m *MockAuditRepository) Create(log *domain.AuditLog) error {
	return m.Called(log).Error(0)
}
// FindPage registers a call with ctx, limit and cursor and returns the page
// and error stubbed for "FindPage".
//
// A stubbed nil page (e.g. .Return(nil, someErr)) yields a nil slice instead
// of panicking on the type assertion, so error-path tests only need to
// supply the error.
func (m *MockAuditRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor) ([]domain.AuditLog, error) {
	args := m.Called(ctx, limit, cursor)

	var page []domain.AuditLog
	if v := args.Get(0); v != nil {
		page = v.([]domain.AuditLog)
	}
	return page, args.Error(1)
}
// Ping registers a call with ctx and hands back the error stubbed as the
// first return value for "Ping".
func (m *MockAuditRepository) Ping(ctx context.Context) error {
	return m.Called(ctx).Error(0)
}
// TestAuditMiddleware exercises AuditMiddleware against a mocked repository.
// It covers synchronous auditing of a POST (success and repository failure)
// and the non-blocking path taken for GET requests.
func TestAuditMiddleware(t *testing.T) {
	t.Run("POST request - Sync Success", func(t *testing.T) {
		app := fiber.New()
		mockRepo := new(MockAuditRepository)
		app.Use(AuditMiddleware(AuditConfig{
			Repo:     mockRepo,
			BodyDump: true,
		}))
		app.Post("/test", func(c *fiber.Ctx) error {
			return c.SendStatus(fiber.StatusOK)
		})

		// The stored details must record the method and the request body with
		// the "password" field masked.
		mockRepo.On("Create", mock.MatchedBy(func(log *domain.AuditLog) bool {
			var details map[string]any
			if err := json.Unmarshal([]byte(log.Details), &details); err != nil {
				return false
			}
			return log.Status == "success" &&
				details["method"] == "POST" &&
				details["request_body"] == `{"password":"*****","user":"test"}`
		})).Return(nil)

		req := httptest.NewRequest("POST", "/test", strings.NewReader(`{"user": "test", "password": "mypassword"}`))
		req.Header.Set("Content-Type", "application/json")
		resp, err := app.Test(req)
		if err != nil {
			t.Fatalf("app.Test: %v", err)
		}
		defer resp.Body.Close()

		assert.Equal(t, fiber.StatusOK, resp.StatusCode)
		mockRepo.AssertExpectations(t)
	})

	t.Run("POST request - Sync Failure (Strict Mode)", func(t *testing.T) {
		app := fiber.New()
		mockRepo := new(MockAuditRepository)
		app.Use(AuditMiddleware(AuditConfig{
			Repo: mockRepo,
		}))
		app.Post("/test", func(c *fiber.Ctx) error {
			return c.SendStatus(fiber.StatusOK)
		})
		mockRepo.On("Create", mock.Anything).Return(errors.New("db error"))

		resp, err := app.Test(httptest.NewRequest("POST", "/test", nil))
		if err != nil {
			t.Fatalf("app.Test: %v", err)
		}
		defer resp.Body.Close()

		// Should return 503 because Audit failed on a Write method
		assert.Equal(t, fiber.StatusServiceUnavailable, resp.StatusCode)
		// The 503 must come from the failed Create, not from skipping the audit.
		mockRepo.AssertExpectations(t)
	})

	t.Run("GET request - Async Load Shedding", func(t *testing.T) {
		app := fiber.New()
		mockRepo := new(MockAuditRepository)
		// A one-entry queue makes load shedding likely but not guaranteed.
		// NOTE(review): per the original note, WorkerCount 0 is replaced by a
		// default (5) inside AuditMiddleware, so the workers cannot be stalled
		// from here — confirm. Those worker goroutines also outlive this subtest.
		app.Use(AuditMiddleware(AuditConfig{
			Repo:        mockRepo,
			QueueSize:   1,
			WorkerCount: 0,
		}))
		app.Get("/test", func(c *fiber.Ctx) error {
			return c.SendStatus(fiber.StatusOK)
		})
		// Async writes may or may not reach the repository before the subtest
		// ends, so Create is stubbed but its calls are not asserted.
		mockRepo.On("Create", mock.Anything).Return(nil)

		// Two back-to-back requests: whether the second entry is queued or
		// dropped, the response must neither block nor fail. This mainly
		// guards the enqueue against becoming a blocking send.
		for i := 1; i <= 2; i++ {
			resp, err := app.Test(httptest.NewRequest("GET", "/test", nil))
			if err != nil {
				t.Fatalf("request %d: app.Test: %v", i, err)
			}
			resp.Body.Close()
			assert.Equalf(t, fiber.StatusOK, resp.StatusCode, "request %d", i)
		}
	})
}