첫 커밋: 로컬 프로젝트 업로드
This commit is contained in:
121
baron-sso/backend/internal/middleware/api_key_auth.go
Normal file
121
baron-sso/backend/internal/middleware/api_key_auth.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ApiKeyAuthConfig struct {
|
||||
DB *gorm.DB
|
||||
}
|
||||
|
||||
func ApiKeyAuth(config ApiKeyAuthConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// 1. 헤더에서 ID와 Secret 추출
|
||||
clientID := c.Get("X-Baron-Key-ID")
|
||||
plainSecret := c.Get("X-Baron-Key-Secret")
|
||||
|
||||
// 헤더가 둘 다 없으면 API Key 인증 시도가 아닌 것으로 간주하고 다음으로 넘김 (UI 세션 등을 위해)
|
||||
if clientID == "" && plainSecret == "" {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
if clientID == "" || plainSecret == "" {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
||||
"error": "API Key ID or Secret is missing",
|
||||
})
|
||||
}
|
||||
|
||||
// 2. DB에서 ClientID로 키 정보 조회
|
||||
var apiKey domain.ApiKey
|
||||
if err := config.DB.Where("client_id = ? AND status = ?", clientID, "active").First(&apiKey).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
||||
"error": "Invalid or inactive API Key",
|
||||
})
|
||||
}
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
"error": "Database error during authentication",
|
||||
})
|
||||
}
|
||||
|
||||
// 3. Secret 해시 검증
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(apiKey.ClientSecretHash), []byte(plainSecret)); err != nil {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
||||
"error": "Invalid API Secret",
|
||||
})
|
||||
}
|
||||
|
||||
// 4. (비동기) 마지막 사용 시간 업데이트
|
||||
go func(id string) {
|
||||
now := time.Now()
|
||||
config.DB.Model(&domain.ApiKey{}).Where("id = ?", id).Update("last_used_at", &now)
|
||||
}(apiKey.ID)
|
||||
|
||||
// 5. 컨텍스트에 권한 정보 저장
|
||||
c.Locals("apiKeyName", apiKey.Name)
|
||||
c.Locals("apiScopes", apiKey.Scopes)
|
||||
|
||||
// 6. Scope 기반 권한 검증 (RBAC)
|
||||
if !validateScope(c.Method(), c.Path(), apiKey.Scopes) {
|
||||
slog.Warn("API Key scope insufficient", "name", apiKey.Name, "method", c.Method(), "path", c.Path(), "has_scopes", apiKey.Scopes)
|
||||
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
|
||||
"error": "Insufficient permissions (Scope mismatch)",
|
||||
})
|
||||
}
|
||||
|
||||
slog.Debug("API Key authenticated and authorized", "name", apiKey.Name, "path", c.Path())
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// validateScope - 요청된 메서드와 경로가 허용된 Scopes에 포함되는지 검사합니다.
|
||||
func validateScope(method, path string, rawScopes string) bool {
|
||||
scopes := strings.Fields(rawScopes)
|
||||
scopeMap := make(map[string]bool)
|
||||
for _, s := range scopes {
|
||||
scopeMap[s] = true
|
||||
}
|
||||
|
||||
if strings.Contains(path, "/integrations/org-context") {
|
||||
return method == fiber.MethodGet && scopeMap["org-context:read"]
|
||||
}
|
||||
|
||||
// 1. 감사 로그 관련 (audit:*)
|
||||
if strings.Contains(path, "/admin/audit") || strings.Contains(path, "/v1/audit") {
|
||||
if method == fiber.MethodGet {
|
||||
return scopeMap["audit:read"]
|
||||
}
|
||||
return scopeMap["audit:write"]
|
||||
}
|
||||
|
||||
// 2. 사용자 관리 관련 (user:*)
|
||||
if strings.Contains(path, "/admin/users") {
|
||||
if method == fiber.MethodGet {
|
||||
return scopeMap["user:read"]
|
||||
}
|
||||
return scopeMap["user:write"]
|
||||
}
|
||||
|
||||
// 3. 테넌트 관리 관련 (tenant:*)
|
||||
if strings.Contains(path, "/admin/tenants") || strings.Contains(path, "/admin/relying-parties") {
|
||||
if method == fiber.MethodGet {
|
||||
return scopeMap["tenant:read"]
|
||||
}
|
||||
return scopeMap["tenant:write"]
|
||||
}
|
||||
|
||||
// 4. API 체크 등 공통 (기본 허용)
|
||||
if strings.HasSuffix(path, "/admin/check") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
15
baron-sso/backend/internal/middleware/api_key_auth_test.go
Normal file
15
baron-sso/backend/internal/middleware/api_key_auth_test.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestValidateScope_AllowsOrgContextReadOnly(t *testing.T) {
|
||||
require.True(t, validateScope(fiber.MethodGet, "/api/v1/integrations/org-context", "org-context:read"))
|
||||
require.False(t, validateScope(fiber.MethodPost, "/api/v1/integrations/org-context", "org-context:read"))
|
||||
require.False(t, validateScope(fiber.MethodGet, "/api/v1/integrations/org-context", "tenant:read"))
|
||||
require.False(t, validateScope(fiber.MethodGet, "/api/v1/orgfront/org-context", "org-context:read"))
|
||||
}
|
||||
224
baron-sso/backend/internal/middleware/audit_middleware.go
Normal file
224
baron-sso/backend/internal/middleware/audit_middleware.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/utils"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditConfig struct {
|
||||
Repo domain.AuditRepository
|
||||
ExcludePaths map[string]struct{}
|
||||
ReadPaths map[string]struct{}
|
||||
BodyDump bool
|
||||
WorkerCount int
|
||||
QueueSize int
|
||||
}
|
||||
|
||||
func isNil(i any) bool {
|
||||
if i == nil {
|
||||
return true
|
||||
}
|
||||
v := reflect.ValueOf(i)
|
||||
return v.Kind() == reflect.Pointer && v.IsNil()
|
||||
}
|
||||
|
||||
// AuditMiddleware provides comprehensive audit logging for write requests by default.
|
||||
// Read requests are skipped unless they are explicitly allowlisted in ReadPaths.
|
||||
func AuditMiddleware(config AuditConfig) fiber.Handler {
|
||||
// 0. Initialize Worker Pool for Async Logging
|
||||
if config.WorkerCount <= 0 {
|
||||
config.WorkerCount = 5 // Default workers
|
||||
}
|
||||
if config.QueueSize <= 0 {
|
||||
config.QueueSize = 1000 // Default queue size
|
||||
}
|
||||
|
||||
auditQueue := make(chan *domain.AuditLog, config.QueueSize)
|
||||
var once sync.Once
|
||||
|
||||
// Start workers only once
|
||||
once.Do(func() {
|
||||
for i := 0; i < config.WorkerCount; i++ {
|
||||
go func(workerID int) {
|
||||
slog.Debug("Audit worker started", "id", workerID)
|
||||
for log := range auditQueue {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("Audit worker panic recovery", "reason", r, "req_id", log.EventID)
|
||||
}
|
||||
}()
|
||||
if err := config.Repo.Create(log); err != nil {
|
||||
slog.Warn("Failed to write async audit log", "error", err, "req_id", log.EventID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
})
|
||||
|
||||
// Default methods classification
|
||||
writeMethods := map[string]struct{}{
|
||||
fiber.MethodPost: {},
|
||||
fiber.MethodPut: {},
|
||||
fiber.MethodPatch: {},
|
||||
fiber.MethodDelete: {},
|
||||
}
|
||||
|
||||
if config.ExcludePaths == nil {
|
||||
config.ExcludePaths = map[string]struct{}{}
|
||||
}
|
||||
if config.ReadPaths == nil {
|
||||
config.ReadPaths = map[string]struct{}{}
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
// 1. Check exclusions
|
||||
if _, excluded := config.ExcludePaths[c.Path()]; excluded {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// 2. Setup context variables
|
||||
start := time.Now()
|
||||
reqID := c.Get("X-Request-Id")
|
||||
if reqID == "" {
|
||||
reqID = uuid.New().String()
|
||||
c.Set("X-Request-Id", reqID)
|
||||
}
|
||||
|
||||
// 3. Process Request
|
||||
err := c.Next()
|
||||
|
||||
// 4. Gather Metrics & Context
|
||||
latency := time.Since(start)
|
||||
status := c.Response().StatusCode()
|
||||
|
||||
// If Fiber handler returned an error, status might default to 500 or be in the error
|
||||
if err != nil {
|
||||
if fiberErr, ok := err.(*fiber.Error); ok {
|
||||
status = fiberErr.Code
|
||||
} else {
|
||||
status = fiber.StatusInternalServerError
|
||||
}
|
||||
}
|
||||
|
||||
statusText := "success"
|
||||
if status >= fiber.StatusBadRequest {
|
||||
statusText = "failure"
|
||||
}
|
||||
|
||||
// 5. Extract User Context (populated by AuthMiddleware/TenantGuard)
|
||||
userID, _ := c.Locals("user_id").(string)
|
||||
loginID, _ := c.Locals("login_id").(string)
|
||||
tenantID, _ := c.Locals("tenant_id").(string)
|
||||
sessionID, _ := c.Locals("session_id").(string)
|
||||
clientIP := extractClientIP(c)
|
||||
|
||||
// 6. Capture & Mask Body
|
||||
var maskedBody string
|
||||
if config.BodyDump {
|
||||
if c.Method() != fiber.MethodGet && c.Method() != fiber.MethodHead {
|
||||
bodyBytes := c.Body()
|
||||
if len(bodyBytes) > 0 {
|
||||
maskedBytes := utils.MaskSensitiveJSON(bodyBytes)
|
||||
maskedBody = string(maskedBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 7. Construct Details JSON
|
||||
details := map[string]any{
|
||||
"request_id": reqID,
|
||||
"method": c.Method(),
|
||||
"path": c.Path(),
|
||||
"status": status,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
"login_id": loginID,
|
||||
"tenant_id": tenantID,
|
||||
"request_body": maskedBody,
|
||||
}
|
||||
// 핸들러에서 추가한 상세 정보를 병합합니다.
|
||||
if extra := c.Locals("audit_details_extra"); extra != nil {
|
||||
switch v := extra.(type) {
|
||||
case map[string]string:
|
||||
for key, value := range v {
|
||||
details[key] = value
|
||||
}
|
||||
case map[string]any:
|
||||
maps.Copy(details, v)
|
||||
}
|
||||
}
|
||||
if skipTimeline, ok := c.Locals("auth_timeline_skip").(bool); ok && skipTimeline {
|
||||
details["auth_timeline_skip"] = true
|
||||
}
|
||||
if sessionID != "" {
|
||||
details["session_id"] = sessionID
|
||||
}
|
||||
if approvedSessionID, ok := c.Locals("approved_session_id").(string); ok && approvedSessionID != "" {
|
||||
details["approved_session_id"] = approvedSessionID
|
||||
}
|
||||
if err != nil {
|
||||
details["error"] = err.Error()
|
||||
}
|
||||
|
||||
detailsJSON, _ := json.Marshal(details)
|
||||
|
||||
// 8. Create Audit Log Object
|
||||
auditLog := &domain.AuditLog{
|
||||
EventID: reqID,
|
||||
Timestamp: start,
|
||||
UserID: userID,
|
||||
TenantID: tenantID,
|
||||
SessionID: sessionID,
|
||||
EventType: fmt.Sprintf("%s %s", c.Method(), c.Path()),
|
||||
Status: statusText,
|
||||
IPAddress: clientIP,
|
||||
UserAgent: c.Get("User-Agent"),
|
||||
Details: string(detailsJSON),
|
||||
}
|
||||
|
||||
// 9. Store Log (Policy Enforcement)
|
||||
_, isWrite := writeMethods[c.Method()]
|
||||
_, allowRead := config.ReadPaths[c.Path()]
|
||||
|
||||
if isNil(config.Repo) {
|
||||
if isWrite {
|
||||
slog.Warn("Audit repository missing for command, proceeding without audit log", "req_id", reqID, "method", c.Method(), "path", c.Path())
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if isWrite {
|
||||
// Strict Mode: Synchronous write
|
||||
if createErr := config.Repo.Create(auditLog); createErr != nil {
|
||||
slog.Error("Failed to write audit log (sync)", "error", createErr, "req_id", reqID)
|
||||
return fiber.NewError(fiber.StatusServiceUnavailable, "Audit logging failed")
|
||||
}
|
||||
} else if allowRead {
|
||||
// Best Effort: Load Shedding via Buffered Channel
|
||||
select {
|
||||
case auditQueue <- auditLog:
|
||||
// Successfully queued
|
||||
default:
|
||||
// Queue full -> DROP (Load Shedding)
|
||||
slog.Warn("Audit queue full, dropping log (load shedding)", "req_id", reqID, "path", c.Path())
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func extractClientIP(c *fiber.Ctx) string {
|
||||
return utils.ResolveClientIP(c.Get("X-Forwarded-For"), c.Get("X-Real-IP"), c.IP())
|
||||
}
|
||||
265
baron-sso/backend/internal/middleware/audit_middleware_test.go
Normal file
265
baron-sso/backend/internal/middleware/audit_middleware_test.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// MockAuditRepository is a mock implementation of AuditRepository
|
||||
type MockAuditRepository struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockAuditRepository) Create(log *domain.AuditLog) error {
|
||||
args := m.Called(log)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAuditRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor, tenantID string) ([]domain.AuditLog, error) {
|
||||
args := m.Called(ctx, limit, cursor, tenantID)
|
||||
return args.Get(0).([]domain.AuditLog), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAuditRepository) FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]domain.AuditLog, error) {
|
||||
args := m.Called(ctx, userID, eventTypes, limit)
|
||||
return args.Get(0).([]domain.AuditLog), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAuditRepository) CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *MockAuditRepository) CountEventsSince(ctx context.Context, since time.Time) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *MockAuditRepository) CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *MockAuditRepository) Ping(ctx context.Context) error {
|
||||
args := m.Called(ctx)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type recordingAuditRepository struct {
|
||||
mu sync.Mutex
|
||||
count int
|
||||
}
|
||||
|
||||
func (r *recordingAuditRepository) Create(log *domain.AuditLog) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.count++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *recordingAuditRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor, tenantID string) ([]domain.AuditLog, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *recordingAuditRepository) FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]domain.AuditLog, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *recordingAuditRepository) CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *recordingAuditRepository) CountEventsSince(ctx context.Context, since time.Time) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *recordingAuditRepository) CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *recordingAuditRepository) Ping(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *recordingAuditRepository) Calls() int {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.count
|
||||
}
|
||||
|
||||
func TestAuditMiddleware(t *testing.T) {
|
||||
t.Run("POST request - Sync Success", func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockRepo := new(MockAuditRepository)
|
||||
|
||||
app.Use(AuditMiddleware(AuditConfig{
|
||||
Repo: mockRepo,
|
||||
BodyDump: true,
|
||||
}))
|
||||
|
||||
app.Post("/test", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
mockRepo.On("Create", mock.MatchedBy(func(log *domain.AuditLog) bool {
|
||||
var details map[string]any
|
||||
json.Unmarshal([]byte(log.Details), &details)
|
||||
return log.Status == "success" &&
|
||||
details["method"] == "POST" &&
|
||||
details["request_body"] == `{"password":"*****","user":"test"}`
|
||||
})).Return(nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", strings.NewReader(`{"user": "test", "password": "mypassword"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
mockRepo.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("POST request - Merge extra audit details", func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockRepo := new(MockAuditRepository)
|
||||
|
||||
app.Use(AuditMiddleware(AuditConfig{
|
||||
Repo: mockRepo,
|
||||
}))
|
||||
|
||||
app.Post("/test", func(c *fiber.Ctx) error {
|
||||
c.Locals("tenant_id", "tenant-a")
|
||||
c.Locals("audit_details_extra", map[string]any{
|
||||
"client_id": "rp-1",
|
||||
"client_name": "Demo App",
|
||||
})
|
||||
c.Locals("auth_timeline_skip", true)
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
mockRepo.On("Create", mock.MatchedBy(func(log *domain.AuditLog) bool {
|
||||
var details map[string]any
|
||||
if err := json.Unmarshal([]byte(log.Details), &details); err != nil {
|
||||
return false
|
||||
}
|
||||
if details["client_id"] != "rp-1" {
|
||||
return false
|
||||
}
|
||||
if details["client_name"] != "Demo App" {
|
||||
return false
|
||||
}
|
||||
if log.TenantID != "tenant-a" || details["tenant_id"] != "tenant-a" {
|
||||
return false
|
||||
}
|
||||
skip, ok := details["auth_timeline_skip"].(bool)
|
||||
return ok && skip
|
||||
})).Return(nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
mockRepo.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("POST request - Prefer public forwarded IP", func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockRepo := new(MockAuditRepository)
|
||||
|
||||
app.Use(AuditMiddleware(AuditConfig{
|
||||
Repo: mockRepo,
|
||||
}))
|
||||
|
||||
app.Post("/test", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
mockRepo.On("Create", mock.MatchedBy(func(log *domain.AuditLog) bool {
|
||||
return log.IPAddress == "203.0.113.25"
|
||||
})).Return(nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
req.Header.Set("X-Forwarded-For", "100.100.100.1, 203.0.113.25")
|
||||
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
mockRepo.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("POST request - Sync Failure (Strict Mode)", func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockRepo := new(MockAuditRepository)
|
||||
|
||||
app.Use(AuditMiddleware(AuditConfig{
|
||||
Repo: mockRepo,
|
||||
}))
|
||||
|
||||
app.Post("/test", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
mockRepo.On("Create", mock.Anything).Return(errors.New("db error"))
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
// Should return 503 because Audit failed on a Write method
|
||||
assert.Equal(t, fiber.StatusServiceUnavailable, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("GET request - Async Load Shedding", func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockRepo := new(MockAuditRepository)
|
||||
|
||||
// Set very small queue and no workers to force load shedding
|
||||
app.Use(AuditMiddleware(AuditConfig{
|
||||
Repo: mockRepo,
|
||||
QueueSize: 1,
|
||||
WorkerCount: 0, // This will be defaulted to 5 by the code, so let's use another way or just small queue
|
||||
}))
|
||||
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
// 1. First request fills the queue
|
||||
mockRepo.On("Create", mock.Anything).Return(nil)
|
||||
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
resp1, _ := app.Test(req1)
|
||||
assert.Equal(t, fiber.StatusOK, resp1.StatusCode)
|
||||
|
||||
// 2. Second request should be dropped (load shedding) if workers are slow
|
||||
// Since we can't easily pause workers without modifying code,
|
||||
// this test mostly ensures the non-blocking send doesn't hang.
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
resp2, _ := app.Test(req2)
|
||||
assert.Equal(t, fiber.StatusOK, resp2.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("GET request - Read audit is skipped by default", func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
repo := &recordingAuditRepository{}
|
||||
|
||||
app.Use(AuditMiddleware(AuditConfig{
|
||||
Repo: repo,
|
||||
}))
|
||||
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
assert.Equal(t, 0, repo.Calls())
|
||||
})
|
||||
}
|
||||
49
baron-sso/backend/internal/middleware/error_code_enricher.go
Normal file
49
baron-sso/backend/internal/middleware/error_code_enricher.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/response"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// ErrorCodeEnricher injects machine-readable `code` into legacy error responses.
|
||||
// This keeps backward compatibility (`error` stays) while migrating handlers
|
||||
// incrementally to explicit code-based responses.
|
||||
func ErrorCodeEnricher() fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
err := c.Next()
|
||||
|
||||
body := c.Response().Body()
|
||||
if len(body) == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
contentType := string(c.Response().Header.ContentType())
|
||||
if contentType != "" && !strings.Contains(contentType, fiber.MIMEApplicationJSON) {
|
||||
return err
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if unmarshalErr := json.Unmarshal(body, &payload); unmarshalErr != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, hasError := payload["error"]; !hasError {
|
||||
return err
|
||||
}
|
||||
if _, hasCode := payload["code"]; hasCode {
|
||||
return err
|
||||
}
|
||||
|
||||
payload["code"] = response.StatusCode(c.Response().StatusCode())
|
||||
updated, marshalErr := json.Marshal(payload)
|
||||
if marshalErr != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Response().SetBodyRaw(updated)
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func decodeBody(t *testing.T, resp *http.Response) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
var body map[string]any
|
||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||
t.Fatalf("failed to decode body: %v", err)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func TestErrorCodeEnricher_AddsCodeToLegacyErrorResponse(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Use(ErrorCodeEnricher())
|
||||
app.Get("/legacy", func(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
||||
"error": "missing token",
|
||||
})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/legacy", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != fiber.StatusUnauthorized {
|
||||
t.Fatalf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body := decodeBody(t, resp)
|
||||
if body["error"] != "missing token" {
|
||||
t.Fatalf("unexpected error field: %v", body["error"])
|
||||
}
|
||||
if body["code"] != "invalid_session" {
|
||||
t.Fatalf("unexpected code: %v", body["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorCodeEnricher_DoesNotOverrideExistingCode(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Use(ErrorCodeEnricher())
|
||||
app.Get("/already", func(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": "invalid request",
|
||||
"code": "validation_format",
|
||||
})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/already", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
|
||||
body := decodeBody(t, resp)
|
||||
if body["code"] != "validation_format" {
|
||||
t.Fatalf("code should not be overridden: %v", body["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorCodeEnricher_IgnoreSuccessPayload(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Use(ErrorCodeEnricher())
|
||||
app.Get("/ok", func(c *fiber.Ctx) error {
|
||||
return c.JSON(fiber.Map{
|
||||
"message": "ok",
|
||||
})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/ok", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
|
||||
body := decodeBody(t, resp)
|
||||
if _, found := body["code"]; found {
|
||||
t.Fatalf("code should not be injected for success payload")
|
||||
}
|
||||
}
|
||||
17
baron-sso/backend/internal/middleware/error_helper.go
Normal file
17
baron-sso/backend/internal/middleware/error_helper.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/response"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// errorJSON은 legacy error 필드를 유지하면서 status 기반 code를 함께 반환합니다.
|
||||
func errorJSON(c *fiber.Ctx, status int, message string) error {
|
||||
return response.Error(c, status, response.StatusCode(status), message)
|
||||
}
|
||||
|
||||
// errorJSONCode는 상태코드 매핑과 다른 명시 코드가 필요할 때 사용합니다.
|
||||
func errorJSONCode(c *fiber.Ctx, status int, code, message string) error {
|
||||
return response.Error(c, status, code, message)
|
||||
}
|
||||
166
baron-sso/backend/internal/middleware/rbac.go
Normal file
166
baron-sso/backend/internal/middleware/rbac.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"log/slog"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// RBACConfig defines the configuration for RBAC middleware
|
||||
type RBACConfig struct {
|
||||
AllowedRoles []string
|
||||
AuthHandler AuthProfileProvider
|
||||
KetoService service.KetoService
|
||||
}
|
||||
|
||||
// AuthProfileProvider는 미들웨어에서 사용자 정보를 조회하기 위한 최소 인터페이스입니다.
|
||||
type AuthProfileProvider interface {
|
||||
GetEnrichedProfile(c *fiber.Ctx) (*domain.UserProfileResponse, error)
|
||||
}
|
||||
|
||||
func setAuditUserContext(c *fiber.Ctx, profile *domain.UserProfileResponse) {
|
||||
if profile == nil || profile.ID == "" {
|
||||
return
|
||||
}
|
||||
if existingUserID, _ := c.Locals("user_id").(string); existingUserID != "" {
|
||||
return
|
||||
}
|
||||
c.Locals("user_id", profile.ID)
|
||||
}
|
||||
|
||||
// RequireKetoPermission enforces permissions using Ory Keto (ReBAC)
|
||||
func RequireKetoPermission(config RBACConfig, namespace, relation string) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
profile, err := config.AuthHandler.GetEnrichedProfile(c)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusUnauthorized, "unauthorized (trace:rbac_keto)")
|
||||
}
|
||||
|
||||
// Store profile in locals for further use in handlers
|
||||
c.Locals("user_profile", profile)
|
||||
setAuditUserContext(c, profile)
|
||||
|
||||
role := domain.NormalizeRole(profile.Role)
|
||||
|
||||
// Super Admin bypass
|
||||
if role == domain.RoleSuperAdmin {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Get object ID from path (e.g., tenant ID)
|
||||
// Fix: For Tenant namespace, prioritize tenantId param if available
|
||||
objectID := ""
|
||||
if namespace == "Tenant" {
|
||||
objectID = c.Params("tenantId")
|
||||
}
|
||||
|
||||
if objectID == "" {
|
||||
objectID = c.Params("id")
|
||||
}
|
||||
|
||||
if objectID == "" {
|
||||
slog.Error("RBAC Keto check failed: missing object id", "path", c.Path())
|
||||
return errorJSON(c, fiber.StatusBadRequest, "missing object id for permission check")
|
||||
}
|
||||
|
||||
slog.Info("Performing Keto permission check", "userID", profile.ID, "namespace", namespace, "objectID", objectID, "relation", relation)
|
||||
|
||||
// Set tenant_id for audit logging if namespace is Tenant
|
||||
if namespace == "Tenant" {
|
||||
c.Locals("tenant_id", objectID)
|
||||
}
|
||||
|
||||
// Check with Keto - add User: prefix to subject
|
||||
allowed, err := config.KetoService.CheckPermission(c.Context(), "User:"+profile.ID, namespace, objectID, relation)
|
||||
if err != nil {
|
||||
slog.Error("Keto service error", "error", err, "userID", profile.ID, "objectID", objectID)
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "permission check error")
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
slog.Warn("Keto permission denied", "userID", profile.ID, "userRole", role, "namespace", namespace, "objectID", objectID, "relation", relation, "X-Test-Role", c.Get("X-Test-Role"))
|
||||
return errorJSON(c, fiber.StatusForbidden, "forbidden: keto permission denied for "+namespace+":"+objectID)
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireRole enforces that the user has one of the allowed roles
|
||||
func RequireRole(config RBACConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Bypass if already authenticated via API Key
|
||||
if c.Locals("apiKeyName") != nil {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
profile, err := config.AuthHandler.GetEnrichedProfile(c)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusUnauthorized, "unauthorized (trace:rbac_role): "+err.Error())
|
||||
}
|
||||
|
||||
// Store profile in locals for further use in handlers
|
||||
c.Locals("user_profile", profile)
|
||||
setAuditUserContext(c, profile)
|
||||
|
||||
userRole := domain.NormalizeRole(profile.Role)
|
||||
|
||||
// Super Admin always has access
|
||||
if userRole == domain.RoleSuperAdmin {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Check if user's role is in allowed roles
|
||||
roleAllowed := false
|
||||
for _, role := range config.AllowedRoles {
|
||||
if userRole == domain.NormalizeRole(role) {
|
||||
roleAllowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !roleAllowed {
|
||||
slog.Warn("RBAC access denied",
|
||||
"userID", profile.ID,
|
||||
"userRole", userRole,
|
||||
"allowedRoles", config.AllowedRoles,
|
||||
"path", c.Path(),
|
||||
"X-Test-Role", c.Get("X-Test-Role"),
|
||||
)
|
||||
return errorJSON(c, fiber.StatusForbidden, "forbidden: insufficient permissions")
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireTenantMatch enforces that a Tenant Admin can only access their own tenant's data
|
||||
func RequireTenantMatch(config RBACConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Bypass if already authenticated via API Key (System-wide for now)
|
||||
if c.Locals("apiKeyName") != nil {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
profile, err := config.AuthHandler.GetEnrichedProfile(c)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusUnauthorized, "unauthorized (trace:rbac_match)")
|
||||
}
|
||||
|
||||
// Store profile in locals for further use in handlers
|
||||
c.Locals("user_profile", profile)
|
||||
setAuditUserContext(c, profile)
|
||||
|
||||
userRole := domain.NormalizeRole(profile.Role)
|
||||
|
||||
// Super Admin bypass
|
||||
if userRole == domain.RoleSuperAdmin {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Since only Super Admin is maintained for tenant management, others are rejected here
|
||||
return errorJSON(c, fiber.StatusForbidden, "forbidden")
|
||||
}
|
||||
}
|
||||
266
baron-sso/backend/internal/middleware/rbac_test.go
Normal file
266
baron-sso/backend/internal/middleware/rbac_test.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// MockAuthProvider is a mock for AuthProvider interface
|
||||
type MockAuthProvider struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockAuthProvider) GetEnrichedProfile(c *fiber.Ctx) (*domain.UserProfileResponse, error) {
|
||||
args := m.Called(c)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.UserProfileResponse), args.Error(1)
|
||||
}
|
||||
|
||||
// MockKetoService is a mock for KetoService interface
|
||||
type MockKetoService struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockKetoService) CheckPermission(ctx context.Context, subject, namespace, object, relation string) (bool, error) {
|
||||
args := m.Called(ctx, subject, namespace, object, relation)
|
||||
return args.Bool(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockKetoService) CreateRelation(ctx context.Context, namespace, object, relation, subject string) error {
|
||||
args := m.Called(ctx, namespace, object, relation, subject)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockKetoService) DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error {
|
||||
args := m.Called(ctx, namespace, object, relation, subject)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockKetoService) ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]service.RelationTuple, error) {
|
||||
args := m.Called(ctx, namespace, object, relation, subject)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).([]service.RelationTuple), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockKetoService) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
|
||||
args := m.Called(ctx, namespace, relation, subject)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).([]string), args.Error(1)
|
||||
}
|
||||
|
||||
func TestRequireRole_Success(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockAuth := new(MockAuthProvider)
|
||||
config := RBACConfig{
|
||||
AllowedRoles: []string{domain.RoleSuperAdmin},
|
||||
AuthHandler: mockAuth,
|
||||
}
|
||||
|
||||
mockAuth.On("GetEnrichedProfile", mock.Anything).Return(&domain.UserProfileResponse{
|
||||
ID: "user1",
|
||||
Role: domain.RoleSuperAdmin,
|
||||
}, nil)
|
||||
|
||||
app.Get("/test", RequireRole(config), func(c *fiber.Ctx) error {
|
||||
return c.SendString("ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestRequireRole_SetsUserIDForAuditContext(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockAuth := new(MockAuthProvider)
|
||||
config := RBACConfig{
|
||||
AllowedRoles: []string{domain.RoleSuperAdmin},
|
||||
AuthHandler: mockAuth,
|
||||
}
|
||||
|
||||
mockAuth.On("GetEnrichedProfile", mock.Anything).Return(&domain.UserProfileResponse{
|
||||
ID: "user1",
|
||||
Role: domain.RoleSuperAdmin,
|
||||
}, nil)
|
||||
|
||||
app.Get("/test", RequireRole(config), func(c *fiber.Ctx) error {
|
||||
return c.JSON(fiber.Map{
|
||||
"user_id": c.Locals("user_id"),
|
||||
})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
var body map[string]string
|
||||
assert.NoError(t, readJSON(resp, &body))
|
||||
assert.Equal(t, "user1", body["user_id"])
|
||||
}
|
||||
|
||||
func TestRequireRole_PreservesExistingUserID(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockAuth := new(MockAuthProvider)
|
||||
config := RBACConfig{
|
||||
AllowedRoles: []string{domain.RoleSuperAdmin},
|
||||
AuthHandler: mockAuth,
|
||||
}
|
||||
|
||||
mockAuth.On("GetEnrichedProfile", mock.Anything).Return(&domain.UserProfileResponse{
|
||||
ID: "profile-user",
|
||||
Role: domain.RoleSuperAdmin,
|
||||
}, nil)
|
||||
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_id", "existing-user")
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/test", RequireRole(config), func(c *fiber.Ctx) error {
|
||||
return c.JSON(fiber.Map{
|
||||
"user_id": c.Locals("user_id"),
|
||||
})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
var body map[string]string
|
||||
assert.NoError(t, readJSON(resp, &body))
|
||||
assert.Equal(t, "existing-user", body["user_id"])
|
||||
}
|
||||
|
||||
func TestRequireRole_Forbidden(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockAuth := new(MockAuthProvider)
|
||||
config := RBACConfig{
|
||||
AllowedRoles: []string{domain.RoleSuperAdmin},
|
||||
AuthHandler: mockAuth,
|
||||
}
|
||||
|
||||
mockAuth.On("GetEnrichedProfile", mock.Anything).Return(&domain.UserProfileResponse{
|
||||
ID: "user1",
|
||||
Role: "user",
|
||||
}, nil)
|
||||
|
||||
app.Get("/test", RequireRole(config), func(c *fiber.Ctx) error {
|
||||
return c.SendString("ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, 403, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestRequireKetoPermission_Success(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockAuth := new(MockAuthProvider)
|
||||
mockKeto := new(MockKetoService)
|
||||
config := RBACConfig{
|
||||
AuthHandler: mockAuth,
|
||||
KetoService: mockKeto,
|
||||
}
|
||||
|
||||
profile := &domain.UserProfileResponse{ID: "user1", Role: "user"}
|
||||
mockAuth.On("GetEnrichedProfile", mock.Anything).Return(profile, nil)
|
||||
mockKeto.On("CheckPermission", mock.Anything, "User:user1", "tenants", "tenant1", "read").Return(true, nil)
|
||||
|
||||
app.Get("/tenants/:id", RequireKetoPermission(config, "tenants", "read"), func(c *fiber.Ctx) error {
|
||||
return c.SendString("ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/tenants/tenant1", nil)
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestRequireTenantMatch_SuperAdmin(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockAuth := new(MockAuthProvider)
|
||||
config := RBACConfig{
|
||||
AuthHandler: mockAuth,
|
||||
}
|
||||
|
||||
mockAuth.On("GetEnrichedProfile", mock.Anything).Return(&domain.UserProfileResponse{
|
||||
ID: "admin1",
|
||||
Role: domain.RoleSuperAdmin,
|
||||
}, nil)
|
||||
|
||||
app.Get("/tenants/:tenantId/data", RequireTenantMatch(config), func(c *fiber.Ctx) error {
|
||||
return c.SendString("ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/tenants/any-tenant/data", nil)
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestRequireTenantMatch_Forbidden(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockAuth := new(MockAuthProvider)
|
||||
config := RBACConfig{
|
||||
AuthHandler: mockAuth,
|
||||
}
|
||||
|
||||
tenant1 := "tenant1"
|
||||
mockAuth.On("GetEnrichedProfile", mock.Anything).Return(&domain.UserProfileResponse{
|
||||
ID: "user1",
|
||||
Role: "user", // Formerly tenant_admin, now mapped to user which is forbidden here for non-superadmin
|
||||
TenantID: &tenant1,
|
||||
}, nil)
|
||||
|
||||
app.Get("/tenants/:tenantId/data", RequireTenantMatch(config), func(c *fiber.Ctx) error {
|
||||
return c.SendString("ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/tenants/tenant2/data", nil)
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, 403, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestRequireRole_Unauthorized(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockAuth := new(MockAuthProvider)
|
||||
config := RBACConfig{
|
||||
AuthHandler: mockAuth,
|
||||
}
|
||||
|
||||
mockAuth.On("GetEnrichedProfile", mock.Anything).Return(nil, errors.New("unauthorized"))
|
||||
|
||||
app.Get("/test", RequireRole(config), func(c *fiber.Ctx) error {
|
||||
return c.SendString("ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, 401, resp.StatusCode)
|
||||
}
|
||||
|
||||
func readJSON(resp *http.Response, target any) error {
|
||||
defer resp.Body.Close()
|
||||
return json.NewDecoder(resp.Body).Decode(target)
|
||||
}
|
||||
69
baron-sso/backend/internal/middleware/tenant_middleware.go
Normal file
69
baron-sso/backend/internal/middleware/tenant_middleware.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/service"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// TenantContextConfig defines the configuration for Tenant context middleware
|
||||
type TenantContextConfig struct {
|
||||
TenantService service.TenantService
|
||||
}
|
||||
|
||||
// TenantContextMiddleware identifies the tenant based on the Host header (subdomain or custom domain)
|
||||
func TenantContextMiddleware(config TenantContextConfig) fiber.Handler {
|
||||
userfrontURL := os.Getenv("USERFRONT_URL")
|
||||
baseDomain := ""
|
||||
if userfrontURL != "" {
|
||||
if u, err := url.Parse(userfrontURL); err == nil {
|
||||
baseDomain = u.Hostname()
|
||||
}
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
hostname := c.Hostname()
|
||||
if hostname == "" {
|
||||
hostname = string(c.Request().Header.Host())
|
||||
}
|
||||
|
||||
if hostname == "" {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// 1. If it's the exact base domain or localhost, no specific tenant context from host
|
||||
if hostname == baseDomain || hostname == "localhost" || hostname == "127.0.0.1" {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// 2. Try to find by registered custom domain
|
||||
tenant, err := config.TenantService.GetTenantByDomain(c.Context(), hostname)
|
||||
if err == nil && tenant != nil {
|
||||
slog.Debug("Tenant identified by custom domain", "hostname", hostname, "tenantID", tenant.ID)
|
||||
c.Locals("tenant_id", tenant.ID)
|
||||
c.Locals("tenant_slug", tenant.Slug)
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// 3. Try to find by subdomain (slug.baseDomain)
|
||||
if baseDomain != "" && strings.HasSuffix(hostname, "."+baseDomain) {
|
||||
slug := strings.TrimSuffix(hostname, "."+baseDomain)
|
||||
// Handle cases like "www.sso.hmac.kr" if baseDomain is "sso.hmac.kr"
|
||||
if slug != "" && slug != "www" {
|
||||
tenant, err := config.TenantService.GetTenantBySlug(c.Context(), slug)
|
||||
if err == nil && tenant != nil {
|
||||
slog.Debug("Tenant identified by subdomain slug", "slug", slug, "tenantID", tenant.ID)
|
||||
c.Locals("tenant_id", tenant.ID)
|
||||
c.Locals("tenant_slug", tenant.Slug)
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
123
baron-sso/backend/internal/middleware/tenant_middleware_test.go
Normal file
123
baron-sso/backend/internal/middleware/tenant_middleware_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type MockTenantServiceForMiddleware struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error) {
|
||||
args := m.Called(mock.Anything, emailDomain)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
|
||||
args := m.Called(mock.Anything, slug)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) ListTenants(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) ListJoinedTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) ApproveTenant(ctx context.Context, id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) SetKetoService(keto service.KetoService) {}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) DeleteTenantsBulk(ctx context.Context, ids []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestTenantContextMiddleware(t *testing.T) {
|
||||
os.Setenv("USERFRONT_URL", "https://sso.hmac.kr")
|
||||
defer os.Unsetenv("USERFRONT_URL")
|
||||
|
||||
mockSvc := new(MockTenantServiceForMiddleware)
|
||||
app := fiber.New()
|
||||
app.Use(TenantContextMiddleware(TenantContextConfig{TenantService: mockSvc}))
|
||||
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return c.JSON(fiber.Map{
|
||||
"tenant_id": c.Locals("tenant_id"),
|
||||
"tenant_slug": c.Locals("tenant_slug"),
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Base Domain - No Tenant Context", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://sso.hmac.kr/test", nil)
|
||||
req.Host = "sso.hmac.kr"
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
mockSvc.AssertNotCalled(t, "GetTenantByDomain", mock.Anything, "sso.hmac.kr")
|
||||
})
|
||||
|
||||
t.Run("Subdomain - Identify by Slug", func(t *testing.T) {
|
||||
mockSvc.On("GetTenantByDomain", mock.Anything, "tenant1.sso.hmac.kr").Return(nil, nil).Once()
|
||||
mockSvc.On("GetTenantBySlug", mock.Anything, "tenant1").Return(&domain.Tenant{ID: "t1", Slug: "tenant1"}, nil).Once()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Host = "tenant1.sso.hmac.kr"
|
||||
req.Header.Set("Host", "tenant1.sso.hmac.kr")
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Custom Domain - Identify by Domain", func(t *testing.T) {
|
||||
mockSvc.On("GetTenantByDomain", mock.Anything, "company.com").Return(&domain.Tenant{ID: "t2", Slug: "company"}, nil).Once()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Host = "company.com"
|
||||
req.Header.Set("Host", "company.com")
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user