forked from baron/baron-sso
audit 로그 개선. kratos 코드발급 링크로 전송까지 진행 완료 #104
This commit is contained in:
@@ -1,10 +1,14 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrNotSupported는 IDP가 특정 인증 흐름을 지원하지 않을 때 반환합니다.
|
||||
var ErrNotSupported = errors.New("idp: not supported")
|
||||
|
||||
// BrokerUser is the standard user model used within Baron SSO business logic.
|
||||
// It defines the canonical set of fields that must be supported by any underlying IDP.
|
||||
type BrokerUser struct {
|
||||
@@ -24,6 +28,16 @@ type IDPMetadata struct {
|
||||
SupportedFields []string
|
||||
}
|
||||
|
||||
// PasswordPolicy는 비밀번호 정책 정보를 표현합니다.
|
||||
type PasswordPolicy struct {
|
||||
MinLength int
|
||||
Lowercase bool
|
||||
Uppercase bool
|
||||
Number bool
|
||||
NonAlphanumeric bool
|
||||
MinCharacterTypes int
|
||||
}
|
||||
|
||||
// Token represents a session or refresh token.
|
||||
type Token struct {
|
||||
JWT string
|
||||
@@ -38,6 +52,14 @@ type AuthInfo struct {
|
||||
Subject string
|
||||
}
|
||||
|
||||
// LinkLoginInit는 링크 로그인 초기화 결과입니다.
|
||||
type LinkLoginInit struct {
|
||||
FlowID string
|
||||
ExpiresAt time.Time
|
||||
// Mode는 링크 로그인 완료 후 세션 처리 방식입니다. (예: "cookie")
|
||||
Mode string
|
||||
}
|
||||
|
||||
// IdentityProvider is the interface that all IDP adapters must implement.
|
||||
type IdentityProvider interface {
|
||||
Name() string
|
||||
@@ -48,6 +70,16 @@ type IdentityProvider interface {
|
||||
CreateUser(user *BrokerUser, password string) (string, error)
|
||||
// SignIn은 로그인 ID/비밀번호로 인증해 세션 정보를 반환합니다.
|
||||
SignIn(loginID, password string) (*AuthInfo, error)
|
||||
// UserExists는 loginID 기준으로 사용자 존재 여부를 확인합니다.
|
||||
UserExists(loginID string) (bool, error)
|
||||
// IssueSession은 비밀번호 없이 세션을 발급해야 하는 흐름에서 사용합니다.
|
||||
IssueSession(loginID string) (*AuthInfo, error)
|
||||
// InitiateLinkLogin은 링크 기반 로그인 요청을 IDP에 전달합니다.
|
||||
InitiateLinkLogin(loginID, returnTo string) (*LinkLoginInit, error)
|
||||
// VerifyLoginCode는 링크/코드 기반 로그인에서 코드를 제출해 세션을 발급합니다.
|
||||
VerifyLoginCode(loginID, flowID, code string) (*AuthInfo, error)
|
||||
// GetPasswordPolicy는 IDP가 제공하는 비밀번호 정책을 반환합니다.
|
||||
GetPasswordPolicy() (*PasswordPolicy, error)
|
||||
InitiatePasswordReset(loginID, redirectUrl string) error
|
||||
VerifyPasswordResetToken(token string) (*AuthInfo, error)
|
||||
UpdateUserPassword(loginID, newPassword string, r *http.Request) error
|
||||
|
||||
@@ -3,6 +3,8 @@ package handler
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/descope/go-sdk/descope/client"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
@@ -39,3 +41,23 @@ func NewAdminHandler() *AdminHandler {
|
||||
func (h *AdminHandler) CheckAuth(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "ok"})
|
||||
}
|
||||
|
||||
// GetSystemStats returns runtime statistics for monitoring
|
||||
func (h *AdminHandler) GetSystemStats(c *fiber.Ctx) error {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
|
||||
stats := fiber.Map{
|
||||
"goroutines": runtime.NumGoroutine(),
|
||||
"cpus": runtime.NumCPU(),
|
||||
"memory": fiber.Map{
|
||||
"alloc": m.Alloc,
|
||||
"totalAlign": m.TotalAlloc,
|
||||
"sys": m.Sys,
|
||||
"numGC": m.NumGC,
|
||||
},
|
||||
"timestamp": time.Now(),
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(stats)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -72,7 +72,7 @@ func TestCompletePasswordReset_InvalidPasswordPolicy(t *testing.T) {
|
||||
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if got["error"] != "Password must be at least 8 characters long" {
|
||||
if got["error"] != "비밀번호는 최소 12자 이상이어야 합니다" {
|
||||
t.Fatalf("unexpected error message: %v", got["error"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,43 +124,144 @@ func (c *chainedProvider) GetMetadata() (*domain.IDPMetadata, error) {
|
||||
}
|
||||
|
||||
func (c *chainedProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
|
||||
var errs []error
|
||||
for idx, p := range c.providers {
|
||||
for _, p := range c.providers {
|
||||
id, err := p.CreateUser(user, password)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "CreateUser", "error", err)
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
slog.Info("IDP fallback succeeded", "operation", "CreateUser", "provider", p.Name())
|
||||
return "", err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return "", fmt.Errorf("no IDP providers available for CreateUser")
|
||||
}
|
||||
return "", fmt.Errorf("all IDP providers failed for CreateUser: %w", errors.Join(errs...))
|
||||
return "", domain.ErrNotSupported
|
||||
}
|
||||
|
||||
func (c *chainedProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
|
||||
var errs []error
|
||||
for idx, p := range c.providers {
|
||||
for _, p := range c.providers {
|
||||
info, err := p.SignIn(loginID, password)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "SignIn", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
|
||||
func (c *chainedProvider) UserExists(loginID string) (bool, error) {
|
||||
var errs []error
|
||||
for _, p := range c.providers {
|
||||
exists, err := p.UserExists(loginID)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
continue
|
||||
}
|
||||
if exists {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("all IDP providers failed for UserExists: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *chainedProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
|
||||
var errs []error
|
||||
for idx, p := range c.providers {
|
||||
info, err := p.IssueSession(loginID)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "IssueSession", "error", err)
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
slog.Info("IDP fallback succeeded", "operation", "SignIn", "provider", p.Name())
|
||||
slog.Info("IDP fallback succeeded", "operation", "IssueSession", "provider", p.Name())
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return nil, fmt.Errorf("no IDP providers available for SignIn")
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
return nil, fmt.Errorf("all IDP providers failed for SignIn: %w", errors.Join(errs...))
|
||||
return nil, fmt.Errorf("all IDP providers failed for IssueSession: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *chainedProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
var errs []error
|
||||
for idx, p := range c.providers {
|
||||
info, err := p.InitiateLinkLogin(loginID, returnTo)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "InitiateLinkLogin", "error", err)
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
slog.Info("IDP fallback succeeded", "operation", "InitiateLinkLogin", "provider", p.Name())
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
return nil, fmt.Errorf("all IDP providers failed for InitiateLinkLogin: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *chainedProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
||||
var errs []error
|
||||
for idx, p := range c.providers {
|
||||
info, err := p.VerifyLoginCode(loginID, flowID, code)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "VerifyLoginCode", "error", err)
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
slog.Info("IDP fallback succeeded", "operation", "VerifyLoginCode", "provider", p.Name())
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
return nil, fmt.Errorf("all IDP providers failed for VerifyLoginCode: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *chainedProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
|
||||
var errs []error
|
||||
for _, p := range c.providers {
|
||||
policy, err := p.GetPasswordPolicy()
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
continue
|
||||
}
|
||||
if policy != nil {
|
||||
return policy, nil
|
||||
}
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
return nil, fmt.Errorf("all IDP providers failed for GetPasswordPolicy: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *chainedProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
|
||||
|
||||
@@ -10,19 +10,31 @@ import (
|
||||
)
|
||||
|
||||
type stubProvider struct {
|
||||
name string
|
||||
metadata []string
|
||||
createErr error
|
||||
initiateErr error
|
||||
verifyErr error
|
||||
updateErr error
|
||||
signInErr error
|
||||
initiateCalls int
|
||||
verifyCalls int
|
||||
updateCalls int
|
||||
signInCalls int
|
||||
createCalls int
|
||||
verifyResponse *domain.AuthInfo
|
||||
name string
|
||||
metadata []string
|
||||
createErr error
|
||||
initiateErr error
|
||||
verifyErr error
|
||||
updateErr error
|
||||
signInErr error
|
||||
userExistsErr error
|
||||
issueErr error
|
||||
linkInitErr error
|
||||
verifyCodeErr error
|
||||
policyErr error
|
||||
initiateCalls int
|
||||
verifyCalls int
|
||||
updateCalls int
|
||||
signInCalls int
|
||||
createCalls int
|
||||
userExistsCalls int
|
||||
issueCalls int
|
||||
linkInitCalls int
|
||||
verifyCodeCalls int
|
||||
policyCalls int
|
||||
verifyResponse *domain.AuthInfo
|
||||
userExists bool
|
||||
policy *domain.PasswordPolicy
|
||||
}
|
||||
|
||||
func (s *stubProvider) Name() string { return s.name }
|
||||
@@ -47,6 +59,46 @@ func (s *stubProvider) SignIn(loginID, password string) (*domain.AuthInfo, error
|
||||
return &domain.AuthInfo{Subject: "subject-123"}, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) UserExists(loginID string) (bool, error) {
|
||||
s.userExistsCalls++
|
||||
if s.userExistsErr != nil {
|
||||
return false, s.userExistsErr
|
||||
}
|
||||
return s.userExists, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
|
||||
s.issueCalls++
|
||||
if s.issueErr != nil {
|
||||
return nil, s.issueErr
|
||||
}
|
||||
return &domain.AuthInfo{Subject: "issue-subject"}, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
s.linkInitCalls++
|
||||
if s.linkInitErr != nil {
|
||||
return nil, s.linkInitErr
|
||||
}
|
||||
return &domain.LinkLoginInit{FlowID: "flow-123", Mode: "cookie"}, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
||||
s.verifyCodeCalls++
|
||||
if s.verifyCodeErr != nil {
|
||||
return nil, s.verifyCodeErr
|
||||
}
|
||||
return &domain.AuthInfo{Subject: "verify-code-subject"}, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
|
||||
s.policyCalls++
|
||||
if s.policyErr != nil {
|
||||
return nil, s.policyErr
|
||||
}
|
||||
return s.policy, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
|
||||
s.initiateCalls++
|
||||
return s.initiateErr
|
||||
|
||||
183
backend/internal/middleware/audit_middleware.go
Normal file
183
backend/internal/middleware/audit_middleware.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/utils"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditConfig struct {
|
||||
Repo domain.AuditRepository
|
||||
ExcludePaths map[string]struct{}
|
||||
BodyDump bool
|
||||
WorkerCount int
|
||||
QueueSize int
|
||||
}
|
||||
|
||||
// AuditMiddleware provides comprehensive audit logging for all requests.
|
||||
// It enforces strict logging for state-changing commands (POST, PUT, DELETE, PATCH)
|
||||
// and best-effort logging for queries (GET, HEAD, OPTIONS).
|
||||
func AuditMiddleware(config AuditConfig) fiber.Handler {
|
||||
// 0. Initialize Worker Pool for Async Logging
|
||||
if config.WorkerCount <= 0 {
|
||||
config.WorkerCount = 5 // Default workers
|
||||
}
|
||||
if config.QueueSize <= 0 {
|
||||
config.QueueSize = 1000 // Default queue size
|
||||
}
|
||||
|
||||
auditQueue := make(chan *domain.AuditLog, config.QueueSize)
|
||||
var once sync.Once
|
||||
|
||||
// Start workers only once
|
||||
once.Do(func() {
|
||||
for i := 0; i < config.WorkerCount; i++ {
|
||||
go func(workerID int) {
|
||||
slog.Debug("Audit worker started", "id", workerID)
|
||||
for log := range auditQueue {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("Audit worker panic recovery", "reason", r, "req_id", log.EventID)
|
||||
}
|
||||
}()
|
||||
if err := config.Repo.Create(log); err != nil {
|
||||
slog.Warn("Failed to write async audit log", "error", err, "req_id", log.EventID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
})
|
||||
|
||||
// Default methods classification
|
||||
writeMethods := map[string]struct{}{
|
||||
fiber.MethodPost: {},
|
||||
fiber.MethodPut: {},
|
||||
fiber.MethodPatch: {},
|
||||
fiber.MethodDelete: {},
|
||||
}
|
||||
|
||||
if config.ExcludePaths == nil {
|
||||
config.ExcludePaths = map[string]struct{}{}
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
// 1. Check exclusions
|
||||
if _, excluded := config.ExcludePaths[c.Path()]; excluded {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// 2. Setup context variables
|
||||
start := time.Now()
|
||||
reqID := c.Get("X-Request-Id")
|
||||
if reqID == "" {
|
||||
reqID = uuid.New().String()
|
||||
c.Set("X-Request-Id", reqID)
|
||||
}
|
||||
|
||||
// 3. Process Request
|
||||
err := c.Next()
|
||||
|
||||
// 4. Gather Metrics & Context
|
||||
latency := time.Since(start)
|
||||
status := c.Response().StatusCode()
|
||||
|
||||
// If Fiber handler returned an error, status might default to 500 or be in the error
|
||||
if err != nil {
|
||||
if fiberErr, ok := err.(*fiber.Error); ok {
|
||||
status = fiberErr.Code
|
||||
} else {
|
||||
status = fiber.StatusInternalServerError
|
||||
}
|
||||
}
|
||||
|
||||
statusText := "success"
|
||||
if status >= fiber.StatusBadRequest {
|
||||
statusText = "failure"
|
||||
}
|
||||
|
||||
// 5. Extract User Context (populated by AuthMiddleware/TenantGuard)
|
||||
userID, _ := c.Locals("user_id").(string)
|
||||
loginID, _ := c.Locals("login_id").(string)
|
||||
tenantID, _ := c.Locals("tenant_id").(string)
|
||||
|
||||
// 6. Capture & Mask Body
|
||||
var maskedBody string
|
||||
if config.BodyDump {
|
||||
if c.Method() != fiber.MethodGet && c.Method() != fiber.MethodHead {
|
||||
bodyBytes := c.Body()
|
||||
if len(bodyBytes) > 0 {
|
||||
maskedBytes := utils.MaskSensitiveJSON(bodyBytes)
|
||||
maskedBody = string(maskedBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 7. Construct Details JSON
|
||||
details := map[string]any{
|
||||
"request_id": reqID,
|
||||
"method": c.Method(),
|
||||
"path": c.Path(),
|
||||
"status": status,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
"login_id": loginID,
|
||||
"tenant_id": tenantID,
|
||||
"request_body": maskedBody,
|
||||
}
|
||||
if err != nil {
|
||||
details["error"] = err.Error()
|
||||
}
|
||||
|
||||
detailsJSON, _ := json.Marshal(details)
|
||||
|
||||
// 8. Create Audit Log Object
|
||||
auditLog := &domain.AuditLog{
|
||||
EventID: reqID,
|
||||
Timestamp: start,
|
||||
UserID: userID,
|
||||
EventType: fmt.Sprintf("%s %s", c.Method(), c.Path()),
|
||||
Status: statusText,
|
||||
IPAddress: c.IP(),
|
||||
UserAgent: c.Get("User-Agent"),
|
||||
Details: string(detailsJSON),
|
||||
}
|
||||
|
||||
// 9. Store Log (Policy Enforcement)
|
||||
_, isWrite := writeMethods[c.Method()]
|
||||
|
||||
if config.Repo == nil {
|
||||
if isWrite {
|
||||
slog.Error("Audit repository missing for command", "req_id", reqID)
|
||||
return fiber.NewError(fiber.StatusServiceUnavailable, "Audit system unavailable")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if isWrite {
|
||||
// Strict Mode: Synchronous write
|
||||
if createErr := config.Repo.Create(auditLog); createErr != nil {
|
||||
slog.Error("Failed to write audit log (sync)", "error", createErr, "req_id", reqID)
|
||||
return fiber.NewError(fiber.StatusServiceUnavailable, "Audit logging failed")
|
||||
}
|
||||
} else {
|
||||
// Best Effort: Load Shedding via Buffered Channel
|
||||
select {
|
||||
case auditQueue <- auditLog:
|
||||
// Successfully queued
|
||||
default:
|
||||
// Queue full -> DROP (Load Shedding)
|
||||
slog.Warn("Audit queue full, dropping log (load shedding)", "req_id", reqID, "path", c.Path())
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
117
backend/internal/middleware/audit_middleware_test.go
Normal file
117
backend/internal/middleware/audit_middleware_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// MockAuditRepository is a mock implementation of AuditRepository
|
||||
type MockAuditRepository struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockAuditRepository) Create(log *domain.AuditLog) error {
|
||||
args := m.Called(log)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAuditRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor) ([]domain.AuditLog, error) {
|
||||
args := m.Called(ctx, limit, cursor)
|
||||
return args.Get(0).([]domain.AuditLog), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAuditRepository) Ping(ctx context.Context) error {
|
||||
args := m.Called(ctx)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestAuditMiddleware(t *testing.T) {
|
||||
t.Run("POST request - Sync Success", func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockRepo := new(MockAuditRepository)
|
||||
|
||||
app.Use(AuditMiddleware(AuditConfig{
|
||||
Repo: mockRepo,
|
||||
BodyDump: true,
|
||||
}))
|
||||
|
||||
app.Post("/test", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
mockRepo.On("Create", mock.MatchedBy(func(log *domain.AuditLog) bool {
|
||||
var details map[string]any
|
||||
json.Unmarshal([]byte(log.Details), &details)
|
||||
return log.Status == "success" &&
|
||||
details["method"] == "POST" &&
|
||||
details["request_body"] == `{"password":"*****","user":"test"}`
|
||||
})).Return(nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", strings.NewReader(`{"user": "test", "password": "mypassword"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
mockRepo.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("POST request - Sync Failure (Strict Mode)", func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockRepo := new(MockAuditRepository)
|
||||
|
||||
app.Use(AuditMiddleware(AuditConfig{
|
||||
Repo: mockRepo,
|
||||
}))
|
||||
|
||||
app.Post("/test", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
mockRepo.On("Create", mock.Anything).Return(errors.New("db error"))
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
// Should return 503 because Audit failed on a Write method
|
||||
assert.Equal(t, fiber.StatusServiceUnavailable, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("GET request - Async Load Shedding", func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockRepo := new(MockAuditRepository)
|
||||
|
||||
// Set very small queue and no workers to force load shedding
|
||||
app.Use(AuditMiddleware(AuditConfig{
|
||||
Repo: mockRepo,
|
||||
QueueSize: 1,
|
||||
WorkerCount: 0, // This will be defaulted to 5 by the code, so let's use another way or just small queue
|
||||
}))
|
||||
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
// 1. First request fills the queue
|
||||
mockRepo.On("Create", mock.Anything).Return(nil)
|
||||
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
resp1, _ := app.Test(req1)
|
||||
assert.Equal(t, fiber.StatusOK, resp1.StatusCode)
|
||||
|
||||
// 2. Second request should be dropped (load shedding) if workers are slow
|
||||
// Since we can't easily pause workers without modifying code,
|
||||
// this test mostly ensures the non-blocking send doesn't hang.
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
resp2, _ := app.Test(req2)
|
||||
assert.Equal(t, fiber.StatusOK, resp2.StatusCode)
|
||||
})
|
||||
}
|
||||
@@ -1,106 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditRequiredConfig struct {
|
||||
Repo domain.AuditRepository
|
||||
ExcludePaths map[string]struct{}
|
||||
CommandMethods map[string]struct{}
|
||||
}
|
||||
|
||||
func RequireAudit(config AuditRequiredConfig) fiber.Handler {
|
||||
commandMethods := config.CommandMethods
|
||||
if len(commandMethods) == 0 {
|
||||
commandMethods = map[string]struct{}{
|
||||
fiber.MethodPost: {},
|
||||
fiber.MethodPut: {},
|
||||
fiber.MethodPatch: {},
|
||||
fiber.MethodDelete: {},
|
||||
}
|
||||
}
|
||||
|
||||
excludePaths := config.ExcludePaths
|
||||
if excludePaths == nil {
|
||||
excludePaths = map[string]struct{}{}
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
if _, ok := commandMethods[c.Method()]; !ok {
|
||||
return c.Next()
|
||||
}
|
||||
if _, excluded := excludePaths[c.Path()]; excluded {
|
||||
return c.Next()
|
||||
}
|
||||
if config.Repo == nil {
|
||||
return fiber.NewError(fiber.StatusServiceUnavailable, "audit repository unavailable")
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
reqID := c.Get("X-Request-Id")
|
||||
if reqID == "" {
|
||||
reqID = uuid.New().String()
|
||||
c.Set("X-Request-Id", reqID)
|
||||
}
|
||||
|
||||
err := c.Next()
|
||||
latency := time.Since(start)
|
||||
|
||||
status := c.Response().StatusCode()
|
||||
if err != nil {
|
||||
if fiberErr, ok := err.(*fiber.Error); ok {
|
||||
status = fiberErr.Code
|
||||
} else {
|
||||
status = fiber.StatusInternalServerError
|
||||
}
|
||||
}
|
||||
|
||||
statusText := "success"
|
||||
if status >= fiber.StatusBadRequest {
|
||||
statusText = "failure"
|
||||
}
|
||||
|
||||
details := map[string]any{
|
||||
"request_id": reqID,
|
||||
"method": c.Method(),
|
||||
"path": c.Path(),
|
||||
"status": status,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
}
|
||||
if err != nil {
|
||||
details["error"] = err.Error()
|
||||
}
|
||||
|
||||
detailsJSON, jsonErr := json.Marshal(details)
|
||||
if jsonErr != nil {
|
||||
slog.Warn("failed to marshal audit details", "error", jsonErr, "req_id", reqID)
|
||||
}
|
||||
|
||||
auditLog := &domain.AuditLog{
|
||||
EventID: reqID,
|
||||
Timestamp: time.Now(),
|
||||
UserID: "",
|
||||
EventType: fmt.Sprintf("%s %s", c.Method(), c.Path()),
|
||||
Status: statusText,
|
||||
IPAddress: c.IP(),
|
||||
UserAgent: c.Get("User-Agent"),
|
||||
DeviceID: "",
|
||||
Details: string(detailsJSON),
|
||||
}
|
||||
|
||||
if createErr := config.Repo.Create(auditLog); createErr != nil {
|
||||
slog.Error("audit log write failed", "error", createErr, "req_id", reqID, "path", c.Path())
|
||||
return fiber.NewError(fiber.StatusServiceUnavailable, "audit logging unavailable")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -145,6 +145,101 @@ func (d *DescopeProvider) SignIn(loginID, password string) (*domain.AuthInfo, er
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// UserExists는 loginID(이메일/전화번호) 기준으로 사용자가 있는지 확인합니다.
|
||||
func (d *DescopeProvider) UserExists(loginID string) (bool, error) {
|
||||
if d.Client == nil {
|
||||
return false, fmt.Errorf("descope provider: client is nil")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if strings.Contains(loginID, "@") {
|
||||
user, err := d.Client.Management.User().Load(ctx, loginID)
|
||||
if err != nil {
|
||||
if isDescopeNotFound(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return user != nil, nil
|
||||
}
|
||||
|
||||
phone := normalizePhone(loginID)
|
||||
searchOptions := &descope.UserSearchOptions{
|
||||
Phones: []string{phone},
|
||||
Limit: 1,
|
||||
}
|
||||
users, _, err := d.Client.Management.User().SearchAll(ctx, searchOptions)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return len(users) > 0, nil
|
||||
}
|
||||
|
||||
// IssueSession은 비밀번호 없이 로그인 세션을 발급합니다.
|
||||
func (d *DescopeProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
|
||||
if d.Client == nil {
|
||||
return nil, fmt.Errorf("descope provider: client is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
targetLoginID, err := d.resolveLoginID(loginID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
embeddedToken, err := d.Client.Management.User().GenerateEmbeddedLink(ctx, targetLoginID, nil, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("descope provider: generate embedded link failed: %w", err)
|
||||
}
|
||||
|
||||
authInfo, err := d.Client.Auth.MagicLink().Verify(ctx, embeddedToken, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("descope provider: magic link verify failed: %w", err)
|
||||
}
|
||||
|
||||
res := &domain.AuthInfo{
|
||||
SessionToken: &domain.Token{
|
||||
JWT: authInfo.SessionToken.JWT,
|
||||
Expiration: time.Unix(authInfo.SessionToken.Expiration, 0),
|
||||
},
|
||||
Subject: authInfo.User.UserID,
|
||||
}
|
||||
if authInfo.RefreshToken != nil {
|
||||
res.RefreshToken = &domain.Token{
|
||||
JWT: authInfo.RefreshToken.JWT,
|
||||
Expiration: time.Unix(authInfo.RefreshToken.Expiration, 0),
|
||||
}
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (d *DescopeProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
|
||||
func (d *DescopeProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
|
||||
// GetPasswordPolicy는 Descope 비밀번호 정책을 반환합니다.
|
||||
func (d *DescopeProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
|
||||
if d.Client == nil {
|
||||
return nil, fmt.Errorf("descope provider: client is nil")
|
||||
}
|
||||
policy, err := d.Client.Auth.Password().GetPasswordPolicy(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &domain.PasswordPolicy{
|
||||
MinLength: int(policy.MinLength),
|
||||
Lowercase: policy.Lowercase,
|
||||
Uppercase: policy.Uppercase,
|
||||
Number: policy.Number,
|
||||
NonAlphanumeric: policy.NonAlphanumeric,
|
||||
MinCharacterTypes: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *DescopeProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
|
||||
ctx := context.Background()
|
||||
err := d.Client.Auth.Password().SendPasswordReset(ctx, loginID, redirectUrl, nil)
|
||||
@@ -197,3 +292,57 @@ func (d *DescopeProvider) UpdateUserPassword(loginID, newPassword string, r *htt
|
||||
ctx := context.Background()
|
||||
return d.Client.Auth.Password().UpdateUserPassword(ctx, loginID, newPassword, r)
|
||||
}
|
||||
|
||||
func (d *DescopeProvider) resolveLoginID(loginID string) (string, error) {
|
||||
if strings.Contains(loginID, "@") {
|
||||
return loginID, nil
|
||||
}
|
||||
|
||||
phone := normalizePhone(loginID)
|
||||
searchOptions := &descope.UserSearchOptions{
|
||||
Phones: []string{phone},
|
||||
Limit: 1,
|
||||
}
|
||||
users, _, err := d.Client.Management.User().SearchAll(context.Background(), searchOptions)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("descope provider: user search failed: %w", err)
|
||||
}
|
||||
if len(users) == 0 {
|
||||
return "", fmt.Errorf("descope provider: user not found")
|
||||
}
|
||||
if len(users[0].LoginIDs) > 0 {
|
||||
return users[0].LoginIDs[0], nil
|
||||
}
|
||||
if users[0].UserID != "" {
|
||||
return users[0].UserID, nil
|
||||
}
|
||||
return "", fmt.Errorf("descope provider: user found but login id missing")
|
||||
}
|
||||
|
||||
func normalizePhone(phone string) string {
|
||||
normalized := strings.ReplaceAll(phone, "-", "")
|
||||
normalized = strings.ReplaceAll(normalized, " ", "")
|
||||
if strings.HasPrefix(normalized, "010") {
|
||||
return "+82" + normalized[1:]
|
||||
}
|
||||
if strings.HasPrefix(normalized, "82") {
|
||||
return "+" + normalized
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func isDescopeNotFound(err error) bool {
|
||||
if de, ok := err.(*descope.Error); ok {
|
||||
if rawStatus, ok := de.Info[descope.ErrorInfoKeys.HTTPResponseStatusCode]; ok {
|
||||
switch v := rawStatus.(type) {
|
||||
case int:
|
||||
return v == http.StatusNotFound
|
||||
case float64:
|
||||
return int(v) == http.StatusNotFound
|
||||
case string:
|
||||
return v == fmt.Sprintf("%d", http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -63,6 +64,15 @@ func (o *OryProvider) CreateUser(user *domain.BrokerUser, password string) (stri
|
||||
if existingID != "" {
|
||||
return "", fmt.Errorf("ory provider: identity already exists for email=%s", user.Email)
|
||||
}
|
||||
if user.PhoneNumber != "" {
|
||||
existingPhoneID, err := o.findIdentityID(user.PhoneNumber)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ory provider: search identity failed: %w", err)
|
||||
}
|
||||
if existingPhoneID != "" {
|
||||
return "", fmt.Errorf("ory provider: identity already exists for phone=%s", user.PhoneNumber)
|
||||
}
|
||||
}
|
||||
|
||||
traits := map[string]interface{}{
|
||||
"email": user.Email,
|
||||
@@ -84,6 +94,27 @@ func (o *OryProvider) CreateUser(user *domain.BrokerUser, password string) (stri
|
||||
},
|
||||
},
|
||||
}
|
||||
verifiable := []map[string]interface{}{
|
||||
{
|
||||
"value": user.Email,
|
||||
"verified": true,
|
||||
"via": "email",
|
||||
},
|
||||
}
|
||||
if user.PhoneNumber != "" {
|
||||
verifiable = append(verifiable, map[string]interface{}{
|
||||
"value": user.PhoneNumber,
|
||||
"verified": true,
|
||||
"via": "sms",
|
||||
})
|
||||
}
|
||||
payload["verifiable_addresses"] = verifiable
|
||||
payload["recovery_addresses"] = []map[string]interface{}{
|
||||
{
|
||||
"value": user.Email,
|
||||
"via": "email",
|
||||
},
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(payload)
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, fmt.Sprintf("%s/admin/identities", o.KratosAdminURL), bytes.NewReader(body))
|
||||
@@ -119,7 +150,7 @@ func (o *OryProvider) SignIn(loginID, password string) (*domain.AuthInfo, error)
|
||||
return nil, fmt.Errorf("ory provider: loginID and password are required")
|
||||
}
|
||||
|
||||
flowID, err := o.startLoginFlow()
|
||||
flowID, err := o.startLoginFlow("")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -178,6 +209,326 @@ func (o *OryProvider) SignIn(loginID, password string) (*domain.AuthInfo, error)
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UserExists는 Kratos Admin API로 loginID 존재 여부를 확인합니다.
|
||||
func (o *OryProvider) UserExists(loginID string) (bool, error) {
|
||||
if loginID == "" {
|
||||
return false, fmt.Errorf("ory provider: loginID is empty")
|
||||
}
|
||||
identityID, err := o.findIdentityID(loginID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("ory provider: find identity failed: %w", err)
|
||||
}
|
||||
return identityID != "", nil
|
||||
}
|
||||
|
||||
// IssueSession은 Ory에서 별도 세션 발급이 필요할 때 사용합니다. (현재 미지원)
|
||||
func (o *OryProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
|
||||
// InitiateLinkLogin은 Kratos Public API로 링크 로그인 플로우를 시작하고 이메일 전송을 트리거합니다.
|
||||
func (o *OryProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
if loginID == "" {
|
||||
return nil, fmt.Errorf("ory provider: loginID is required")
|
||||
}
|
||||
|
||||
init, err := o.submitLoginCodeInit(loginID, returnTo)
|
||||
if err == nil {
|
||||
return init, nil
|
||||
}
|
||||
|
||||
if shouldBootstrapCodeLogin(err) {
|
||||
if ensureErr := o.ensureCodeLoginIdentifier(loginID); ensureErr == nil {
|
||||
return o.submitLoginCodeInit(loginID, returnTo)
|
||||
} else {
|
||||
slog.Warn("Ory code login bootstrap failed", "loginID", loginID, "error", ensureErr)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (o *OryProvider) submitLoginCodeInit(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
flowID, err := o.startLoginFlow(returnTo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"method": "code",
|
||||
"identifier": loginID,
|
||||
})
|
||||
loginURL := fmt.Sprintf("%s/self-service/login?flow=%s", o.KratosPublicURL, flowID)
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, loginURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: build link login request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: link login request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
if resp.StatusCode >= 300 {
|
||||
init, ok := parseKratosLinkLoginResponse(flowID, respBody)
|
||||
if ok {
|
||||
slog.Info("Ory link login initiated with non-2xx response", "loginID", loginID, "flow_id", flowID, "status", resp.StatusCode)
|
||||
return init, nil
|
||||
}
|
||||
return nil, fmt.Errorf("ory provider: link login failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
_ = json.Unmarshal(respBody, &result)
|
||||
|
||||
slog.Info("Ory link login initiated", "loginID", loginID, "flow_id", flowID)
|
||||
|
||||
return &domain.LinkLoginInit{
|
||||
FlowID: flowID,
|
||||
ExpiresAt: result.ExpiresAt,
|
||||
Mode: "link",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseKratosLinkLoginResponse(flowID string, body []byte) (*domain.LinkLoginInit, bool) {
|
||||
if len(body) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
var parsed struct {
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
State string `json:"state"`
|
||||
Active string `json:"active"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
state := strings.ToLower(parsed.State)
|
||||
active := strings.ToLower(parsed.Active)
|
||||
if strings.Contains(state, "sent") || active == "code" {
|
||||
return &domain.LinkLoginInit{
|
||||
FlowID: flowID,
|
||||
ExpiresAt: parsed.ExpiresAt,
|
||||
Mode: "link",
|
||||
}, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func shouldBootstrapCodeLogin(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
return strings.Contains(msg, "has not setup sign in with code") ||
|
||||
strings.Contains(msg, "4000035")
|
||||
}
|
||||
|
||||
type kratosVerifiableAddress struct {
|
||||
Value string `json:"value"`
|
||||
Via string `json:"via"`
|
||||
Verified bool `json:"verified"`
|
||||
Status string `json:"status,omitempty"`
|
||||
}
|
||||
|
||||
func (o *OryProvider) ensureCodeLoginIdentifier(loginID string) error {
|
||||
identityID, err := o.findIdentityID(loginID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ory provider: find identity failed: %w", err)
|
||||
}
|
||||
if identityID == "" {
|
||||
return fmt.Errorf("ory provider: identity not found for loginID=%s", loginID)
|
||||
}
|
||||
|
||||
identity, err := o.fetchIdentity(identityID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
via := "sms"
|
||||
if strings.Contains(loginID, "@") {
|
||||
via = "email"
|
||||
}
|
||||
|
||||
exists := false
|
||||
existingIndex := -1
|
||||
addresses := make([]kratosVerifiableAddress, 0, len(identity.VerifiableAddresses)+1)
|
||||
for idx, addr := range identity.VerifiableAddresses {
|
||||
addresses = append(addresses, kratosVerifiableAddress{
|
||||
Value: addr.Value,
|
||||
Via: addr.Via,
|
||||
Verified: addr.Verified,
|
||||
Status: addr.Status,
|
||||
})
|
||||
if addr.Value == loginID && addr.Via == via {
|
||||
exists = true
|
||||
existingIndex = idx
|
||||
}
|
||||
}
|
||||
ops := make([]map[string]interface{}, 0, 2)
|
||||
if !exists {
|
||||
ops = append(ops, map[string]interface{}{
|
||||
"op": "add",
|
||||
"path": "/verifiable_addresses/-",
|
||||
"value": map[string]interface{}{
|
||||
"value": loginID,
|
||||
"via": via,
|
||||
"verified": true,
|
||||
"status": "completed",
|
||||
},
|
||||
})
|
||||
} else {
|
||||
addr := identity.VerifiableAddresses[existingIndex]
|
||||
if !addr.Verified {
|
||||
ops = append(ops, map[string]interface{}{
|
||||
"op": "replace",
|
||||
"path": fmt.Sprintf("/verifiable_addresses/%d/verified", existingIndex),
|
||||
"value": true,
|
||||
})
|
||||
}
|
||||
if addr.Status != "" && addr.Status != "completed" {
|
||||
ops = append(ops, map[string]interface{}{
|
||||
"op": "replace",
|
||||
"path": fmt.Sprintf("/verifiable_addresses/%d/status", existingIndex),
|
||||
"value": "completed",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(ops) == 0 {
|
||||
slog.Info("Ory identity verifiable address already ready", "identity_id", identityID, "loginID", loginID, "via", via)
|
||||
return nil
|
||||
}
|
||||
|
||||
return o.patchIdentity(identityID, ops)
|
||||
}
|
||||
|
||||
type kratosIdentity struct {
|
||||
VerifiableAddresses []kratosVerifiableAddress `json:"verifiable_addresses"`
|
||||
}
|
||||
|
||||
func (o *OryProvider) patchIdentity(identityID string, ops []map[string]interface{}) error {
|
||||
body, _ := json.Marshal(ops)
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodPatch, fmt.Sprintf("%s/admin/identities/%s", o.KratosAdminURL, identityID), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("ory provider: build identity patch failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json-patch+json")
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ory provider: identity patch failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return fmt.Errorf("ory provider: identity patch failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
slog.Info("Ory identity patched", "identity_id", identityID, "ops", len(ops))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *OryProvider) fetchIdentity(identityID string) (*kratosIdentity, error) {
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, fmt.Sprintf("%s/admin/identities/%s", o.KratosAdminURL, identityID), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: build identity get failed: %w", err)
|
||||
}
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: identity get failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return nil, fmt.Errorf("ory provider: identity get failed status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var identity kratosIdentity
|
||||
if err := json.NewDecoder(resp.Body).Decode(&identity); err != nil {
|
||||
return nil, fmt.Errorf("ory provider: decode identity failed: %w", err)
|
||||
}
|
||||
return &identity, nil
|
||||
}
|
||||
|
||||
// VerifyLoginCode는 Kratos 로그인 코드 제출로 세션을 발급합니다.
|
||||
func (o *OryProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
||||
if loginID == "" || flowID == "" || code == "" {
|
||||
return nil, fmt.Errorf("ory provider: loginID, flowID and code are required")
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"method": "code",
|
||||
"identifier": loginID,
|
||||
"code": code,
|
||||
})
|
||||
loginURL := fmt.Sprintf("%s/self-service/login?flow=%s", o.KratosPublicURL, flowID)
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, loginURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: build login code request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: login code request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return nil, fmt.Errorf("ory provider: login code failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
SessionToken string `json:"session_token"`
|
||||
SessionTokenExpiresAt time.Time `json:"session_token_expires_at"`
|
||||
Session struct {
|
||||
Identity struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"identity"`
|
||||
} `json:"session"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("ory provider: decode login code response failed: %w", err)
|
||||
}
|
||||
if result.SessionToken == "" {
|
||||
return nil, fmt.Errorf("ory provider: empty session token returned")
|
||||
}
|
||||
|
||||
slog.Info("Ory login code successful",
|
||||
"identity_id", result.Session.Identity.ID,
|
||||
"loginID", loginID,
|
||||
"expires_at", result.SessionTokenExpiresAt,
|
||||
)
|
||||
|
||||
return &domain.AuthInfo{
|
||||
SessionToken: &domain.Token{
|
||||
JWT: result.SessionToken,
|
||||
Expiration: result.SessionTokenExpiresAt,
|
||||
},
|
||||
Subject: result.Session.Identity.ID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetPasswordPolicy는 Ory 환경에서 사용하는 기본 정책을 반환합니다.
|
||||
func (o *OryProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
|
||||
return &domain.PasswordPolicy{
|
||||
MinLength: 12,
|
||||
Lowercase: true,
|
||||
Uppercase: false,
|
||||
Number: true,
|
||||
NonAlphanumeric: true,
|
||||
MinCharacterTypes: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// InitiatePasswordReset는 현재 내부 토큰/메일 흐름을 사용하고 있으므로 NO-OP로 둡니다.
|
||||
func (o *OryProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
|
||||
slog.Info("Ory InitiatePasswordReset bypassed (handled by app internal flow)", "loginID", loginID, "redirect", redirectUrl)
|
||||
@@ -301,8 +652,12 @@ func (o *OryProvider) httpClient() *http.Client {
|
||||
}
|
||||
|
||||
// startLoginFlow는 Kratos Public API에서 login flow ID를 발급받습니다.
|
||||
func (o *OryProvider) startLoginFlow() (string, error) {
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, fmt.Sprintf("%s/self-service/login/api", o.KratosPublicURL), nil)
|
||||
func (o *OryProvider) startLoginFlow(returnTo string) (string, error) {
|
||||
loginURL := fmt.Sprintf("%s/self-service/login/api", o.KratosPublicURL)
|
||||
if returnTo != "" {
|
||||
loginURL = loginURL + "?return_to=" + url.QueryEscape(returnTo)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, loginURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ory provider: build login flow request failed: %w", err)
|
||||
}
|
||||
|
||||
79
backend/internal/utils/masking.go
Normal file
79
backend/internal/utils/masking.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var sensitiveKeys = map[string]struct{}{
|
||||
"password": {},
|
||||
"newpassword": {},
|
||||
"oldpassword": {},
|
||||
"token": {},
|
||||
"accesstoken": {},
|
||||
"access_token": {},
|
||||
"refreshtoken": {},
|
||||
"refresh_token": {},
|
||||
"secret": {},
|
||||
"clientsecret": {},
|
||||
"client_secret": {},
|
||||
"authorization": {},
|
||||
"cookie": {},
|
||||
"set-cookie": {},
|
||||
"verificationcode": {},
|
||||
"verification_code": {},
|
||||
"code": {}, // Auth code (sensitive)
|
||||
}
|
||||
|
||||
// MaskSensitiveJSON parses a JSON byte slice and masks values of sensitive keys.
|
||||
// Returns the original data if it's not valid JSON.
|
||||
func MaskSensitiveJSON(data []byte) []byte {
|
||||
if len(data) == 0 {
|
||||
return data
|
||||
}
|
||||
|
||||
var obj interface{}
|
||||
if err := json.Unmarshal(data, &obj); err != nil {
|
||||
// Not a JSON object/array, return as is
|
||||
return data
|
||||
}
|
||||
|
||||
masked := maskValue(obj)
|
||||
|
||||
result, err := json.Marshal(masked)
|
||||
if err != nil {
|
||||
return data
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func maskValue(v interface{}) interface{} {
|
||||
switch val := v.(type) {
|
||||
case map[string]interface{}:
|
||||
newMap := make(map[string]interface{}, len(val))
|
||||
for k, v := range val {
|
||||
if isSensitive(k) {
|
||||
newMap[k] = "*****"
|
||||
} else {
|
||||
newMap[k] = maskValue(v)
|
||||
}
|
||||
}
|
||||
return newMap
|
||||
case []interface{}:
|
||||
newArr := make([]interface{}, len(val))
|
||||
for i, v := range val {
|
||||
newArr[i] = maskValue(v)
|
||||
}
|
||||
return newArr
|
||||
default:
|
||||
return val
|
||||
}
|
||||
}
|
||||
|
||||
func isSensitive(key string) bool {
|
||||
// Check case-insensitive
|
||||
// Remove common separators for looser matching? No, stick to lowercase check for now.
|
||||
k := strings.ToLower(key)
|
||||
_, ok := sensitiveKeys[k]
|
||||
return ok
|
||||
}
|
||||
59
backend/internal/utils/masking_test.go
Normal file
59
backend/internal/utils/masking_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMaskSensitiveJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string // We'll check containment or specific structure
|
||||
}{
|
||||
{
|
||||
name: "Flat object with password",
|
||||
input: `{"username": "user", "password": "secret123"}`,
|
||||
expected: `{"password":"*****","username":"user"}`,
|
||||
},
|
||||
{
|
||||
name: "Nested object with token",
|
||||
input: `{"data": {"token": "abc-def", "id": 123}}`,
|
||||
expected: `{"data":{"id":123,"token":"*****"}}`,
|
||||
},
|
||||
{
|
||||
name: "Case insensitive key",
|
||||
input: `{"NewPassword": "changed"}`,
|
||||
expected: `{"NewPassword":"*****"}`,
|
||||
},
|
||||
{
|
||||
name: "Array of objects",
|
||||
input: `[{"secret": "s1"}, {"secret": "s2"}]`,
|
||||
expected: `[{"secret":"*****"},{"secret":"*****"}]`,
|
||||
},
|
||||
{
|
||||
name: "Invalid JSON",
|
||||
input: `not-json`,
|
||||
expected: `not-json`,
|
||||
},
|
||||
{
|
||||
name: "Empty JSON",
|
||||
input: ``,
|
||||
expected: ``,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := MaskSensitiveJSON([]byte(tt.input))
|
||||
// Since JSON map order is undefined, exact string match might fail if keys are reordered.
|
||||
// Ideally we should unmarshal and compare maps, or use assert.JSONEq
|
||||
if tt.name == "Invalid JSON" || tt.name == "Empty JSON" {
|
||||
assert.Equal(t, tt.expected, string(result))
|
||||
} else {
|
||||
assert.JSONEq(t, tt.expected, string(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -29,6 +29,26 @@ func (m *MockProvider) SignIn(loginID, password string) (*domain.AuthInfo, error
|
||||
return &domain.AuthInfo{}, nil
|
||||
}
|
||||
|
||||
func (m *MockProvider) UserExists(loginID string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *MockProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
|
||||
func (m *MockProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
|
||||
func (m *MockProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
|
||||
func (m *MockProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
|
||||
// Stub implementations to satisfy the IdentityProvider interface for this unit test.
|
||||
func (m *MockProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user