// Forked from baron/baron-sso (upstream file: 266 lines, 7.4 KiB, Go).

package middleware
|
|
|
|
import (
|
|
"baron-sso-backend/internal/domain"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
)
|
|
|
|
// MockAuditRepository is a mock implementation of AuditRepository
|
|
type MockAuditRepository struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockAuditRepository) Create(log *domain.AuditLog) error {
|
|
args := m.Called(log)
|
|
return args.Error(0)
|
|
}
|
|
|
|
func (m *MockAuditRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor, tenantID string) ([]domain.AuditLog, error) {
|
|
args := m.Called(ctx, limit, cursor, tenantID)
|
|
return args.Get(0).([]domain.AuditLog), args.Error(1)
|
|
}
|
|
|
|
func (m *MockAuditRepository) FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]domain.AuditLog, error) {
|
|
args := m.Called(ctx, userID, eventTypes, limit)
|
|
return args.Get(0).([]domain.AuditLog), args.Error(1)
|
|
}
|
|
|
|
func (m *MockAuditRepository) CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
|
|
return 0, nil
|
|
}
|
|
|
|
func (m *MockAuditRepository) CountEventsSince(ctx context.Context, since time.Time) (int64, error) {
|
|
return 0, nil
|
|
}
|
|
|
|
func (m *MockAuditRepository) CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
|
|
return 0, nil
|
|
}
|
|
|
|
func (m *MockAuditRepository) Ping(ctx context.Context) error {
|
|
args := m.Called(ctx)
|
|
return args.Error(0)
|
|
}
|
|
|
|
type recordingAuditRepository struct {
|
|
mu sync.Mutex
|
|
count int
|
|
}
|
|
|
|
func (r *recordingAuditRepository) Create(log *domain.AuditLog) error {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.count++
|
|
return nil
|
|
}
|
|
|
|
func (r *recordingAuditRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor, tenantID string) ([]domain.AuditLog, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (r *recordingAuditRepository) FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]domain.AuditLog, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (r *recordingAuditRepository) CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
|
|
return 0, nil
|
|
}
|
|
|
|
func (r *recordingAuditRepository) CountEventsSince(ctx context.Context, since time.Time) (int64, error) {
|
|
return 0, nil
|
|
}
|
|
|
|
func (r *recordingAuditRepository) CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
|
|
return 0, nil
|
|
}
|
|
|
|
func (r *recordingAuditRepository) Ping(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
func (r *recordingAuditRepository) Calls() int {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
return r.count
|
|
}
|
|
|
|
func TestAuditMiddleware(t *testing.T) {
|
|
t.Run("POST request - Sync Success", func(t *testing.T) {
|
|
app := fiber.New()
|
|
mockRepo := new(MockAuditRepository)
|
|
|
|
app.Use(AuditMiddleware(AuditConfig{
|
|
Repo: mockRepo,
|
|
BodyDump: true,
|
|
}))
|
|
|
|
app.Post("/test", func(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusOK)
|
|
})
|
|
|
|
mockRepo.On("Create", mock.MatchedBy(func(log *domain.AuditLog) bool {
|
|
var details map[string]any
|
|
json.Unmarshal([]byte(log.Details), &details)
|
|
return log.Status == "success" &&
|
|
details["method"] == "POST" &&
|
|
details["request_body"] == `{"password":"*****","user":"test"}`
|
|
})).Return(nil)
|
|
|
|
req := httptest.NewRequest("POST", "/test", strings.NewReader(`{"user": "test", "password": "mypassword"}`))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, _ := app.Test(req)
|
|
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
|
|
mockRepo.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("POST request - Merge extra audit details", func(t *testing.T) {
|
|
app := fiber.New()
|
|
mockRepo := new(MockAuditRepository)
|
|
|
|
app.Use(AuditMiddleware(AuditConfig{
|
|
Repo: mockRepo,
|
|
}))
|
|
|
|
app.Post("/test", func(c *fiber.Ctx) error {
|
|
c.Locals("tenant_id", "tenant-a")
|
|
c.Locals("audit_details_extra", map[string]any{
|
|
"client_id": "rp-1",
|
|
"client_name": "Demo App",
|
|
})
|
|
c.Locals("auth_timeline_skip", true)
|
|
return c.SendStatus(fiber.StatusOK)
|
|
})
|
|
|
|
mockRepo.On("Create", mock.MatchedBy(func(log *domain.AuditLog) bool {
|
|
var details map[string]any
|
|
if err := json.Unmarshal([]byte(log.Details), &details); err != nil {
|
|
return false
|
|
}
|
|
if details["client_id"] != "rp-1" {
|
|
return false
|
|
}
|
|
if details["client_name"] != "Demo App" {
|
|
return false
|
|
}
|
|
if log.TenantID != "tenant-a" || details["tenant_id"] != "tenant-a" {
|
|
return false
|
|
}
|
|
skip, ok := details["auth_timeline_skip"].(bool)
|
|
return ok && skip
|
|
})).Return(nil)
|
|
|
|
req := httptest.NewRequest("POST", "/test", nil)
|
|
resp, _ := app.Test(req)
|
|
|
|
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
|
|
mockRepo.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("POST request - Prefer public forwarded IP", func(t *testing.T) {
|
|
app := fiber.New()
|
|
mockRepo := new(MockAuditRepository)
|
|
|
|
app.Use(AuditMiddleware(AuditConfig{
|
|
Repo: mockRepo,
|
|
}))
|
|
|
|
app.Post("/test", func(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusOK)
|
|
})
|
|
|
|
mockRepo.On("Create", mock.MatchedBy(func(log *domain.AuditLog) bool {
|
|
return log.IPAddress == "203.0.113.25"
|
|
})).Return(nil)
|
|
|
|
req := httptest.NewRequest("POST", "/test", nil)
|
|
req.Header.Set("X-Forwarded-For", "100.100.100.1, 203.0.113.25")
|
|
|
|
resp, _ := app.Test(req)
|
|
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
|
|
mockRepo.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("POST request - Sync Failure (Strict Mode)", func(t *testing.T) {
|
|
app := fiber.New()
|
|
mockRepo := new(MockAuditRepository)
|
|
|
|
app.Use(AuditMiddleware(AuditConfig{
|
|
Repo: mockRepo,
|
|
}))
|
|
|
|
app.Post("/test", func(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusOK)
|
|
})
|
|
|
|
mockRepo.On("Create", mock.Anything).Return(errors.New("db error"))
|
|
|
|
req := httptest.NewRequest("POST", "/test", nil)
|
|
resp, _ := app.Test(req)
|
|
|
|
// Should return 503 because Audit failed on a Write method
|
|
assert.Equal(t, fiber.StatusServiceUnavailable, resp.StatusCode)
|
|
})
|
|
|
|
t.Run("GET request - Async Load Shedding", func(t *testing.T) {
|
|
app := fiber.New()
|
|
mockRepo := new(MockAuditRepository)
|
|
|
|
// Set very small queue and no workers to force load shedding
|
|
app.Use(AuditMiddleware(AuditConfig{
|
|
Repo: mockRepo,
|
|
QueueSize: 1,
|
|
WorkerCount: 0, // This will be defaulted to 5 by the code, so let's use another way or just small queue
|
|
}))
|
|
|
|
app.Get("/test", func(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusOK)
|
|
})
|
|
|
|
// 1. First request fills the queue
|
|
mockRepo.On("Create", mock.Anything).Return(nil)
|
|
|
|
req1 := httptest.NewRequest("GET", "/test", nil)
|
|
resp1, _ := app.Test(req1)
|
|
assert.Equal(t, fiber.StatusOK, resp1.StatusCode)
|
|
|
|
// 2. Second request should be dropped (load shedding) if workers are slow
|
|
// Since we can't easily pause workers without modifying code,
|
|
// this test mostly ensures the non-blocking send doesn't hang.
|
|
req2 := httptest.NewRequest("GET", "/test", nil)
|
|
resp2, _ := app.Test(req2)
|
|
assert.Equal(t, fiber.StatusOK, resp2.StatusCode)
|
|
})
|
|
|
|
t.Run("GET request - Read audit is skipped by default", func(t *testing.T) {
|
|
app := fiber.New()
|
|
repo := &recordingAuditRepository{}
|
|
|
|
app.Use(AuditMiddleware(AuditConfig{
|
|
Repo: repo,
|
|
}))
|
|
|
|
app.Get("/test", func(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusOK)
|
|
})
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
resp, _ := app.Test(req)
|
|
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
assert.Equal(t, 0, repo.Calls())
|
|
})
|
|
}
|