1
0
forked from baron/baron-sso

audit 로그 개선. kratos 코드발급 링크로 전송까지 진행 완료 #104

This commit is contained in:
Lectom C Han
2026-01-29 01:20:19 +09:00
parent ff17259117
commit b88de7ec91
46 changed files with 2843 additions and 585 deletions

View File

@@ -267,10 +267,13 @@ func main() {
})
app.Use(recover.New())
allowedOrigins := getEnv("CORS_ALLOWED_ORIGINS", "http://localhost:5000")
allowCredentials := allowedOrigins != "*"
app.Use(cors.New(cors.Config{
AllowOrigins: "*", // Adjust in production
AllowHeaders: "Origin, Content-Type, Accept, Authorization",
AllowMethods: "GET, POST, HEAD, PUT, DELETE, PATCH, OPTIONS",
AllowOrigins: allowedOrigins,
AllowHeaders: "Origin, Content-Type, Accept, Authorization",
AllowMethods: "GET, POST, HEAD, PUT, DELETE, PATCH, OPTIONS",
AllowCredentials: allowCredentials,
}))
// Ensure COOKIE_SECRET is exactly 32 bytes for AES-256
@@ -384,12 +387,19 @@ func main() {
// API Group
api := app.Group("/api/v1")
api.Use(middleware.RequireAudit(middleware.AuditRequiredConfig{
workerCount, _ := strconv.Atoi(getEnv("AUDIT_WORKER_COUNT", "5"))
queueSize, _ := strconv.Atoi(getEnv("AUDIT_QUEUE_SIZE", "2000"))
api.Use(middleware.AuditMiddleware(middleware.AuditConfig{
Repo: auditRepo,
ExcludePaths: map[string]struct{}{
"/api/v1/audit": {},
"/api/v1/client-log": {},
},
BodyDump: true,
WorkerCount: workerCount,
QueueSize: queueSize,
}))
api.Post("/audit", auditHandler.CreateLog)
api.Get("/audit", auditHandler.ListLogs)
@@ -399,6 +409,7 @@ func main() {
auth.Post("/enchanted-link/init", authHandler.InitEnchantedLink)
auth.Post("/enchanted-link/poll", authHandler.PollEnchantedLink)
auth.Post("/magic-link/verify", authHandler.VerifyMagicLink)
auth.Post("/login/code/verify", authHandler.VerifyLoginCode)
auth.Post("/password/login", authHandler.PasswordLogin)
auth.Post("/password/reset/initiate", authHandler.InitiatePasswordReset)
// [Changed] Use Interstitial Page for GET to prevent Scanner consumption
@@ -431,6 +442,7 @@ func main() {
// Admin Routes
admin := api.Group("/admin")
admin.Get("/check", adminHandler.CheckAuth)
admin.Get("/stats", adminHandler.GetSystemStats)
admin.Get("/tenants", tenantHandler.ListTenants)
admin.Post("/tenants", tenantHandler.CreateTenant)
admin.Get("/tenants/:id", tenantHandler.GetTenant)
@@ -454,6 +466,9 @@ func main() {
// Webhook for Descope Generic Email Gateway (Fake Email Strategy)
auth.Post("/webhooks/descope-email", authHandler.HandleDescopeEmailRelay)
// Webhook for Kratos courier (HTTP delivery)
auth.Post("/webhooks/kratos-courier", authHandler.HandleKratosCourierRelay)
// Client Logging Route (Standardized & Flattened)
api.Post("/client-log", func(c *fiber.Ctx) error {
type LogReq struct {

View File

@@ -230,7 +230,7 @@ paths:
content:
application/json:
schema:
$ref: "#/components/schemas/MessageResponse"
$ref: "#/components/schemas/MagicLinkVerifyResponse"
/api/v1/auth/sms:
post:
@@ -266,7 +266,7 @@ paths:
content:
application/json:
schema:
$ref: "#/components/schemas/MessageResponse"
$ref: "#/components/schemas/SmsVerifyResponse"
/api/v1/auth/qr/init:
post:
@@ -908,18 +908,28 @@ components:
type: boolean
nonAlphanumeric:
type: boolean
minCharacterTypes:
type: integer
EnchantedLinkInitRequest:
type: object
properties:
loginId:
type: string
uri:
type: string
method:
type: string
EnchantedLinkInitResponse:
type: object
properties:
linkId:
type: string
pendingRef:
type: string
maskedEmail:
type: string
expiresIn:
type: integer
@@ -943,22 +953,36 @@ components:
token:
type: string
MagicLinkVerifyResponse:
type: object
properties:
token:
type: string
message:
type: string
SmsSendRequest:
type: object
properties:
phone:
type: string
message:
phoneNumber:
type: string
SmsVerifyRequest:
type: object
properties:
phone:
phoneNumber:
type: string
code:
type: string
SmsVerifyResponse:
type: object
properties:
token:
type: string
message:
type: string
QrInitResponse:
type: object
properties:

View File

@@ -14,6 +14,7 @@ require (
github.com/gofiber/fiber/v2 v2.52.10
github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1
github.com/stretchr/testify v1.11.1
golang.org/x/crypto v0.46.0
gorm.io/driver/postgres v1.6.0
gorm.io/gorm v1.31.1
@@ -34,6 +35,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/sts v1.41.6 // indirect
github.com/aws/smithy-go v1.24.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/go-faster/city v1.0.1 // indirect
@@ -57,10 +59,12 @@ require (
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/paulmach/orb v0.12.0 // indirect
github.com/pierrec/lz4/v4 v4.1.22 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
github.com/segmentio/asm v1.2.1 // indirect
github.com/shopspring/decimal v1.4.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.51.0 // indirect
github.com/valyala/tcplisten v1.0.0 // indirect
@@ -71,4 +75,5 @@ require (
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.39.0 // indirect
golang.org/x/text v0.32.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -137,6 +137,8 @@ github.com/segmentio/asm v1.2.1/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=

View File

@@ -1,10 +1,14 @@
package domain
import (
"errors"
"net/http"
"time"
)
// ErrNotSupported는 IDP가 특정 인증 흐름을 지원하지 않을 때 반환합니다.
var ErrNotSupported = errors.New("idp: not supported")
// BrokerUser is the standard user model used within Baron SSO business logic.
// It defines the canonical set of fields that must be supported by any underlying IDP.
type BrokerUser struct {
@@ -24,6 +28,16 @@ type IDPMetadata struct {
SupportedFields []string
}
// PasswordPolicy는 비밀번호 정책 정보를 표현합니다.
type PasswordPolicy struct {
MinLength int
Lowercase bool
Uppercase bool
Number bool
NonAlphanumeric bool
MinCharacterTypes int
}
// Token represents a session or refresh token.
type Token struct {
JWT string
@@ -38,6 +52,14 @@ type AuthInfo struct {
Subject string
}
// LinkLoginInit는 링크 로그인 초기화 결과입니다.
type LinkLoginInit struct {
FlowID string
ExpiresAt time.Time
// Mode는 링크 로그인 완료 후 세션 처리 방식입니다. (예: "cookie")
Mode string
}
// IdentityProvider is the interface that all IDP adapters must implement.
type IdentityProvider interface {
Name() string
@@ -48,6 +70,16 @@ type IdentityProvider interface {
CreateUser(user *BrokerUser, password string) (string, error)
// SignIn은 로그인 ID/비밀번호로 인증해 세션 정보를 반환합니다.
SignIn(loginID, password string) (*AuthInfo, error)
// UserExists는 loginID 기준으로 사용자 존재 여부를 확인합니다.
UserExists(loginID string) (bool, error)
// IssueSession은 비밀번호 없이 세션을 발급해야 하는 흐름에서 사용합니다.
IssueSession(loginID string) (*AuthInfo, error)
// InitiateLinkLogin은 링크 기반 로그인 요청을 IDP에 전달합니다.
InitiateLinkLogin(loginID, returnTo string) (*LinkLoginInit, error)
// VerifyLoginCode는 링크/코드 기반 로그인에서 코드를 제출해 세션을 발급합니다.
VerifyLoginCode(loginID, flowID, code string) (*AuthInfo, error)
// GetPasswordPolicy는 IDP가 제공하는 비밀번호 정책을 반환합니다.
GetPasswordPolicy() (*PasswordPolicy, error)
InitiatePasswordReset(loginID, redirectUrl string) error
VerifyPasswordResetToken(token string) (*AuthInfo, error)
UpdateUserPassword(loginID, newPassword string, r *http.Request) error

View File

@@ -3,6 +3,8 @@ package handler
import (
"log/slog"
"os"
"runtime"
"time"
"github.com/descope/go-sdk/descope/client"
"github.com/gofiber/fiber/v2"
@@ -39,3 +41,23 @@ func NewAdminHandler() *AdminHandler {
func (h *AdminHandler) CheckAuth(c *fiber.Ctx) error {
return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "ok"})
}
// GetSystemStats returns runtime statistics for monitoring
func (h *AdminHandler) GetSystemStats(c *fiber.Ctx) error {
var m runtime.MemStats
runtime.ReadMemStats(&m)
stats := fiber.Map{
"goroutines": runtime.NumGoroutine(),
"cpus": runtime.NumCPU(),
"memory": fiber.Map{
"alloc": m.Alloc,
"totalAlign": m.TotalAlloc,
"sys": m.Sys,
"numGC": m.NumGC,
},
"timestamp": time.Now(),
}
return c.Status(fiber.StatusOK).JSON(stats)
}

File diff suppressed because it is too large Load Diff

View File

@@ -72,7 +72,7 @@ func TestCompletePasswordReset_InvalidPasswordPolicy(t *testing.T) {
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if got["error"] != "Password must be at least 8 characters long" {
if got["error"] != "비밀번호는 최소 12자 이상이어야 합니다" {
t.Fatalf("unexpected error message: %v", got["error"])
}
}

View File

@@ -124,43 +124,144 @@ func (c *chainedProvider) GetMetadata() (*domain.IDPMetadata, error) {
}
func (c *chainedProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
var errs []error
for idx, p := range c.providers {
for _, p := range c.providers {
id, err := p.CreateUser(user, password)
if err != nil {
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
if errors.Is(err, domain.ErrNotSupported) {
continue
}
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "CreateUser", "error", err)
continue
}
if idx > 0 {
slog.Info("IDP fallback succeeded", "operation", "CreateUser", "provider", p.Name())
return "", err
}
return id, nil
}
if len(errs) == 0 {
return "", fmt.Errorf("no IDP providers available for CreateUser")
}
return "", fmt.Errorf("all IDP providers failed for CreateUser: %w", errors.Join(errs...))
return "", domain.ErrNotSupported
}
func (c *chainedProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
var errs []error
for idx, p := range c.providers {
for _, p := range c.providers {
info, err := p.SignIn(loginID, password)
if err != nil {
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
if errors.Is(err, domain.ErrNotSupported) {
continue
}
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "SignIn", "error", err)
return nil, err
}
return info, nil
}
return nil, domain.ErrNotSupported
}
func (c *chainedProvider) UserExists(loginID string) (bool, error) {
var errs []error
for _, p := range c.providers {
exists, err := p.UserExists(loginID)
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
continue
}
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
continue
}
if exists {
return true, nil
}
}
if len(errs) == 0 {
return false, nil
}
return false, fmt.Errorf("all IDP providers failed for UserExists: %w", errors.Join(errs...))
}
func (c *chainedProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
var errs []error
for idx, p := range c.providers {
info, err := p.IssueSession(loginID)
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
continue
}
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "IssueSession", "error", err)
continue
}
if idx > 0 {
slog.Info("IDP fallback succeeded", "operation", "SignIn", "provider", p.Name())
slog.Info("IDP fallback succeeded", "operation", "IssueSession", "provider", p.Name())
}
return info, nil
}
if len(errs) == 0 {
return nil, fmt.Errorf("no IDP providers available for SignIn")
return nil, domain.ErrNotSupported
}
return nil, fmt.Errorf("all IDP providers failed for SignIn: %w", errors.Join(errs...))
return nil, fmt.Errorf("all IDP providers failed for IssueSession: %w", errors.Join(errs...))
}
func (c *chainedProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
var errs []error
for idx, p := range c.providers {
info, err := p.InitiateLinkLogin(loginID, returnTo)
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
continue
}
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "InitiateLinkLogin", "error", err)
continue
}
if idx > 0 {
slog.Info("IDP fallback succeeded", "operation", "InitiateLinkLogin", "provider", p.Name())
}
return info, nil
}
if len(errs) == 0 {
return nil, domain.ErrNotSupported
}
return nil, fmt.Errorf("all IDP providers failed for InitiateLinkLogin: %w", errors.Join(errs...))
}
func (c *chainedProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
var errs []error
for idx, p := range c.providers {
info, err := p.VerifyLoginCode(loginID, flowID, code)
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
continue
}
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "VerifyLoginCode", "error", err)
continue
}
if idx > 0 {
slog.Info("IDP fallback succeeded", "operation", "VerifyLoginCode", "provider", p.Name())
}
return info, nil
}
if len(errs) == 0 {
return nil, domain.ErrNotSupported
}
return nil, fmt.Errorf("all IDP providers failed for VerifyLoginCode: %w", errors.Join(errs...))
}
func (c *chainedProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
var errs []error
for _, p := range c.providers {
policy, err := p.GetPasswordPolicy()
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
continue
}
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
continue
}
if policy != nil {
return policy, nil
}
}
if len(errs) == 0 {
return nil, domain.ErrNotSupported
}
return nil, fmt.Errorf("all IDP providers failed for GetPasswordPolicy: %w", errors.Join(errs...))
}
func (c *chainedProvider) InitiatePasswordReset(loginID, redirectUrl string) error {

View File

@@ -10,19 +10,31 @@ import (
)
type stubProvider struct {
name string
metadata []string
createErr error
initiateErr error
verifyErr error
updateErr error
signInErr error
initiateCalls int
verifyCalls int
updateCalls int
signInCalls int
createCalls int
verifyResponse *domain.AuthInfo
name string
metadata []string
createErr error
initiateErr error
verifyErr error
updateErr error
signInErr error
userExistsErr error
issueErr error
linkInitErr error
verifyCodeErr error
policyErr error
initiateCalls int
verifyCalls int
updateCalls int
signInCalls int
createCalls int
userExistsCalls int
issueCalls int
linkInitCalls int
verifyCodeCalls int
policyCalls int
verifyResponse *domain.AuthInfo
userExists bool
policy *domain.PasswordPolicy
}
func (s *stubProvider) Name() string { return s.name }
@@ -47,6 +59,46 @@ func (s *stubProvider) SignIn(loginID, password string) (*domain.AuthInfo, error
return &domain.AuthInfo{Subject: "subject-123"}, nil
}
func (s *stubProvider) UserExists(loginID string) (bool, error) {
s.userExistsCalls++
if s.userExistsErr != nil {
return false, s.userExistsErr
}
return s.userExists, nil
}
func (s *stubProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
s.issueCalls++
if s.issueErr != nil {
return nil, s.issueErr
}
return &domain.AuthInfo{Subject: "issue-subject"}, nil
}
func (s *stubProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
s.linkInitCalls++
if s.linkInitErr != nil {
return nil, s.linkInitErr
}
return &domain.LinkLoginInit{FlowID: "flow-123", Mode: "cookie"}, nil
}
func (s *stubProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
s.verifyCodeCalls++
if s.verifyCodeErr != nil {
return nil, s.verifyCodeErr
}
return &domain.AuthInfo{Subject: "verify-code-subject"}, nil
}
func (s *stubProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
s.policyCalls++
if s.policyErr != nil {
return nil, s.policyErr
}
return s.policy, nil
}
func (s *stubProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
s.initiateCalls++
return s.initiateErr

View File

@@ -0,0 +1,183 @@
package middleware
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/utils"
"encoding/json"
"fmt"
"log/slog"
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
)
type AuditConfig struct {
Repo domain.AuditRepository
ExcludePaths map[string]struct{}
BodyDump bool
WorkerCount int
QueueSize int
}
// AuditMiddleware provides comprehensive audit logging for all requests.
// It enforces strict logging for state-changing commands (POST, PUT, DELETE, PATCH)
// and best-effort logging for queries (GET, HEAD, OPTIONS).
func AuditMiddleware(config AuditConfig) fiber.Handler {
// 0. Initialize Worker Pool for Async Logging
if config.WorkerCount <= 0 {
config.WorkerCount = 5 // Default workers
}
if config.QueueSize <= 0 {
config.QueueSize = 1000 // Default queue size
}
auditQueue := make(chan *domain.AuditLog, config.QueueSize)
var once sync.Once
// Start workers only once
once.Do(func() {
for i := 0; i < config.WorkerCount; i++ {
go func(workerID int) {
slog.Debug("Audit worker started", "id", workerID)
for log := range auditQueue {
func() {
defer func() {
if r := recover(); r != nil {
slog.Error("Audit worker panic recovery", "reason", r, "req_id", log.EventID)
}
}()
if err := config.Repo.Create(log); err != nil {
slog.Warn("Failed to write async audit log", "error", err, "req_id", log.EventID)
}
}()
}
}(i)
}
})
// Default methods classification
writeMethods := map[string]struct{}{
fiber.MethodPost: {},
fiber.MethodPut: {},
fiber.MethodPatch: {},
fiber.MethodDelete: {},
}
if config.ExcludePaths == nil {
config.ExcludePaths = map[string]struct{}{}
}
return func(c *fiber.Ctx) error {
// 1. Check exclusions
if _, excluded := config.ExcludePaths[c.Path()]; excluded {
return c.Next()
}
// 2. Setup context variables
start := time.Now()
reqID := c.Get("X-Request-Id")
if reqID == "" {
reqID = uuid.New().String()
c.Set("X-Request-Id", reqID)
}
// 3. Process Request
err := c.Next()
// 4. Gather Metrics & Context
latency := time.Since(start)
status := c.Response().StatusCode()
// If Fiber handler returned an error, status might default to 500 or be in the error
if err != nil {
if fiberErr, ok := err.(*fiber.Error); ok {
status = fiberErr.Code
} else {
status = fiber.StatusInternalServerError
}
}
statusText := "success"
if status >= fiber.StatusBadRequest {
statusText = "failure"
}
// 5. Extract User Context (populated by AuthMiddleware/TenantGuard)
userID, _ := c.Locals("user_id").(string)
loginID, _ := c.Locals("login_id").(string)
tenantID, _ := c.Locals("tenant_id").(string)
// 6. Capture & Mask Body
var maskedBody string
if config.BodyDump {
if c.Method() != fiber.MethodGet && c.Method() != fiber.MethodHead {
bodyBytes := c.Body()
if len(bodyBytes) > 0 {
maskedBytes := utils.MaskSensitiveJSON(bodyBytes)
maskedBody = string(maskedBytes)
}
}
}
// 7. Construct Details JSON
details := map[string]any{
"request_id": reqID,
"method": c.Method(),
"path": c.Path(),
"status": status,
"latency_ms": latency.Milliseconds(),
"login_id": loginID,
"tenant_id": tenantID,
"request_body": maskedBody,
}
if err != nil {
details["error"] = err.Error()
}
detailsJSON, _ := json.Marshal(details)
// 8. Create Audit Log Object
auditLog := &domain.AuditLog{
EventID: reqID,
Timestamp: start,
UserID: userID,
EventType: fmt.Sprintf("%s %s", c.Method(), c.Path()),
Status: statusText,
IPAddress: c.IP(),
UserAgent: c.Get("User-Agent"),
Details: string(detailsJSON),
}
// 9. Store Log (Policy Enforcement)
_, isWrite := writeMethods[c.Method()]
if config.Repo == nil {
if isWrite {
slog.Error("Audit repository missing for command", "req_id", reqID)
return fiber.NewError(fiber.StatusServiceUnavailable, "Audit system unavailable")
}
return err
}
if isWrite {
// Strict Mode: Synchronous write
if createErr := config.Repo.Create(auditLog); createErr != nil {
slog.Error("Failed to write audit log (sync)", "error", createErr, "req_id", reqID)
return fiber.NewError(fiber.StatusServiceUnavailable, "Audit logging failed")
}
} else {
// Best Effort: Load Shedding via Buffered Channel
select {
case auditQueue <- auditLog:
// Successfully queued
default:
// Queue full -> DROP (Load Shedding)
slog.Warn("Audit queue full, dropping log (load shedding)", "req_id", reqID, "path", c.Path())
}
}
return err
}
}

View File

@@ -0,0 +1,117 @@
package middleware
import (
"baron-sso-backend/internal/domain"
"context"
"encoding/json"
"errors"
"net/http/httptest"
"strings"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// MockAuditRepository is a mock implementation of AuditRepository
type MockAuditRepository struct {
mock.Mock
}
func (m *MockAuditRepository) Create(log *domain.AuditLog) error {
args := m.Called(log)
return args.Error(0)
}
func (m *MockAuditRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor) ([]domain.AuditLog, error) {
args := m.Called(ctx, limit, cursor)
return args.Get(0).([]domain.AuditLog), args.Error(1)
}
func (m *MockAuditRepository) Ping(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}
func TestAuditMiddleware(t *testing.T) {
t.Run("POST request - Sync Success", func(t *testing.T) {
app := fiber.New()
mockRepo := new(MockAuditRepository)
app.Use(AuditMiddleware(AuditConfig{
Repo: mockRepo,
BodyDump: true,
}))
app.Post("/test", func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
mockRepo.On("Create", mock.MatchedBy(func(log *domain.AuditLog) bool {
var details map[string]any
json.Unmarshal([]byte(log.Details), &details)
return log.Status == "success" &&
details["method"] == "POST" &&
details["request_body"] == `{"password":"*****","user":"test"}`
})).Return(nil)
req := httptest.NewRequest("POST", "/test", strings.NewReader(`{"user": "test", "password": "mypassword"}`))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
mockRepo.AssertExpectations(t)
})
t.Run("POST request - Sync Failure (Strict Mode)", func(t *testing.T) {
app := fiber.New()
mockRepo := new(MockAuditRepository)
app.Use(AuditMiddleware(AuditConfig{
Repo: mockRepo,
}))
app.Post("/test", func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
mockRepo.On("Create", mock.Anything).Return(errors.New("db error"))
req := httptest.NewRequest("POST", "/test", nil)
resp, _ := app.Test(req)
// Should return 503 because Audit failed on a Write method
assert.Equal(t, fiber.StatusServiceUnavailable, resp.StatusCode)
})
t.Run("GET request - Async Load Shedding", func(t *testing.T) {
app := fiber.New()
mockRepo := new(MockAuditRepository)
// Set very small queue and no workers to force load shedding
app.Use(AuditMiddleware(AuditConfig{
Repo: mockRepo,
QueueSize: 1,
WorkerCount: 0, // This will be defaulted to 5 by the code, so let's use another way or just small queue
}))
app.Get("/test", func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
// 1. First request fills the queue
mockRepo.On("Create", mock.Anything).Return(nil)
req1 := httptest.NewRequest("GET", "/test", nil)
resp1, _ := app.Test(req1)
assert.Equal(t, fiber.StatusOK, resp1.StatusCode)
// 2. Second request should be dropped (load shedding) if workers are slow
// Since we can't easily pause workers without modifying code,
// this test mostly ensures the non-blocking send doesn't hang.
req2 := httptest.NewRequest("GET", "/test", nil)
resp2, _ := app.Test(req2)
assert.Equal(t, fiber.StatusOK, resp2.StatusCode)
})
}

View File

@@ -1,106 +0,0 @@
package middleware
import (
"baron-sso-backend/internal/domain"
"encoding/json"
"fmt"
"log/slog"
"time"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
)
type AuditRequiredConfig struct {
Repo domain.AuditRepository
ExcludePaths map[string]struct{}
CommandMethods map[string]struct{}
}
func RequireAudit(config AuditRequiredConfig) fiber.Handler {
commandMethods := config.CommandMethods
if len(commandMethods) == 0 {
commandMethods = map[string]struct{}{
fiber.MethodPost: {},
fiber.MethodPut: {},
fiber.MethodPatch: {},
fiber.MethodDelete: {},
}
}
excludePaths := config.ExcludePaths
if excludePaths == nil {
excludePaths = map[string]struct{}{}
}
return func(c *fiber.Ctx) error {
if _, ok := commandMethods[c.Method()]; !ok {
return c.Next()
}
if _, excluded := excludePaths[c.Path()]; excluded {
return c.Next()
}
if config.Repo == nil {
return fiber.NewError(fiber.StatusServiceUnavailable, "audit repository unavailable")
}
start := time.Now()
reqID := c.Get("X-Request-Id")
if reqID == "" {
reqID = uuid.New().String()
c.Set("X-Request-Id", reqID)
}
err := c.Next()
latency := time.Since(start)
status := c.Response().StatusCode()
if err != nil {
if fiberErr, ok := err.(*fiber.Error); ok {
status = fiberErr.Code
} else {
status = fiber.StatusInternalServerError
}
}
statusText := "success"
if status >= fiber.StatusBadRequest {
statusText = "failure"
}
details := map[string]any{
"request_id": reqID,
"method": c.Method(),
"path": c.Path(),
"status": status,
"latency_ms": latency.Milliseconds(),
}
if err != nil {
details["error"] = err.Error()
}
detailsJSON, jsonErr := json.Marshal(details)
if jsonErr != nil {
slog.Warn("failed to marshal audit details", "error", jsonErr, "req_id", reqID)
}
auditLog := &domain.AuditLog{
EventID: reqID,
Timestamp: time.Now(),
UserID: "",
EventType: fmt.Sprintf("%s %s", c.Method(), c.Path()),
Status: statusText,
IPAddress: c.IP(),
UserAgent: c.Get("User-Agent"),
DeviceID: "",
Details: string(detailsJSON),
}
if createErr := config.Repo.Create(auditLog); createErr != nil {
slog.Error("audit log write failed", "error", createErr, "req_id", reqID, "path", c.Path())
return fiber.NewError(fiber.StatusServiceUnavailable, "audit logging unavailable")
}
return err
}
}

View File

@@ -145,6 +145,101 @@ func (d *DescopeProvider) SignIn(loginID, password string) (*domain.AuthInfo, er
return res, nil
}
// UserExists는 loginID(이메일/전화번호) 기준으로 사용자가 있는지 확인합니다.
func (d *DescopeProvider) UserExists(loginID string) (bool, error) {
if d.Client == nil {
return false, fmt.Errorf("descope provider: client is nil")
}
ctx := context.Background()
if strings.Contains(loginID, "@") {
user, err := d.Client.Management.User().Load(ctx, loginID)
if err != nil {
if isDescopeNotFound(err) {
return false, nil
}
return false, err
}
return user != nil, nil
}
phone := normalizePhone(loginID)
searchOptions := &descope.UserSearchOptions{
Phones: []string{phone},
Limit: 1,
}
users, _, err := d.Client.Management.User().SearchAll(ctx, searchOptions)
if err != nil {
return false, err
}
return len(users) > 0, nil
}
// IssueSession은 비밀번호 없이 로그인 세션을 발급합니다.
func (d *DescopeProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
if d.Client == nil {
return nil, fmt.Errorf("descope provider: client is nil")
}
ctx := context.Background()
targetLoginID, err := d.resolveLoginID(loginID)
if err != nil {
return nil, err
}
embeddedToken, err := d.Client.Management.User().GenerateEmbeddedLink(ctx, targetLoginID, nil, 0)
if err != nil {
return nil, fmt.Errorf("descope provider: generate embedded link failed: %w", err)
}
authInfo, err := d.Client.Auth.MagicLink().Verify(ctx, embeddedToken, nil)
if err != nil {
return nil, fmt.Errorf("descope provider: magic link verify failed: %w", err)
}
res := &domain.AuthInfo{
SessionToken: &domain.Token{
JWT: authInfo.SessionToken.JWT,
Expiration: time.Unix(authInfo.SessionToken.Expiration, 0),
},
Subject: authInfo.User.UserID,
}
if authInfo.RefreshToken != nil {
res.RefreshToken = &domain.Token{
JWT: authInfo.RefreshToken.JWT,
Expiration: time.Unix(authInfo.RefreshToken.Expiration, 0),
}
}
return res, nil
}
func (d *DescopeProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
return nil, domain.ErrNotSupported
}
func (d *DescopeProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
return nil, domain.ErrNotSupported
}
// GetPasswordPolicy는 Descope 비밀번호 정책을 반환합니다.
func (d *DescopeProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
if d.Client == nil {
return nil, fmt.Errorf("descope provider: client is nil")
}
policy, err := d.Client.Auth.Password().GetPasswordPolicy(context.Background())
if err != nil {
return nil, err
}
return &domain.PasswordPolicy{
MinLength: int(policy.MinLength),
Lowercase: policy.Lowercase,
Uppercase: policy.Uppercase,
Number: policy.Number,
NonAlphanumeric: policy.NonAlphanumeric,
MinCharacterTypes: 0,
}, nil
}
func (d *DescopeProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
ctx := context.Background()
err := d.Client.Auth.Password().SendPasswordReset(ctx, loginID, redirectUrl, nil)
@@ -197,3 +292,57 @@ func (d *DescopeProvider) UpdateUserPassword(loginID, newPassword string, r *htt
ctx := context.Background()
return d.Client.Auth.Password().UpdateUserPassword(ctx, loginID, newPassword, r)
}
func (d *DescopeProvider) resolveLoginID(loginID string) (string, error) {
if strings.Contains(loginID, "@") {
return loginID, nil
}
phone := normalizePhone(loginID)
searchOptions := &descope.UserSearchOptions{
Phones: []string{phone},
Limit: 1,
}
users, _, err := d.Client.Management.User().SearchAll(context.Background(), searchOptions)
if err != nil {
return "", fmt.Errorf("descope provider: user search failed: %w", err)
}
if len(users) == 0 {
return "", fmt.Errorf("descope provider: user not found")
}
if len(users[0].LoginIDs) > 0 {
return users[0].LoginIDs[0], nil
}
if users[0].UserID != "" {
return users[0].UserID, nil
}
return "", fmt.Errorf("descope provider: user found but login id missing")
}
func normalizePhone(phone string) string {
normalized := strings.ReplaceAll(phone, "-", "")
normalized = strings.ReplaceAll(normalized, " ", "")
if strings.HasPrefix(normalized, "010") {
return "+82" + normalized[1:]
}
if strings.HasPrefix(normalized, "82") {
return "+" + normalized
}
return normalized
}
func isDescopeNotFound(err error) bool {
if de, ok := err.(*descope.Error); ok {
if rawStatus, ok := de.Info[descope.ErrorInfoKeys.HTTPResponseStatusCode]; ok {
switch v := rawStatus.(type) {
case int:
return v == http.StatusNotFound
case float64:
return int(v) == http.StatusNotFound
case string:
return v == fmt.Sprintf("%d", http.StatusNotFound)
}
}
}
return false
}

View File

@@ -12,6 +12,7 @@ import (
"net/http"
"net/url"
"os"
"strings"
"time"
)
@@ -63,6 +64,15 @@ func (o *OryProvider) CreateUser(user *domain.BrokerUser, password string) (stri
if existingID != "" {
return "", fmt.Errorf("ory provider: identity already exists for email=%s", user.Email)
}
if user.PhoneNumber != "" {
existingPhoneID, err := o.findIdentityID(user.PhoneNumber)
if err != nil {
return "", fmt.Errorf("ory provider: search identity failed: %w", err)
}
if existingPhoneID != "" {
return "", fmt.Errorf("ory provider: identity already exists for phone=%s", user.PhoneNumber)
}
}
traits := map[string]interface{}{
"email": user.Email,
@@ -84,6 +94,27 @@ func (o *OryProvider) CreateUser(user *domain.BrokerUser, password string) (stri
},
},
}
verifiable := []map[string]interface{}{
{
"value": user.Email,
"verified": true,
"via": "email",
},
}
if user.PhoneNumber != "" {
verifiable = append(verifiable, map[string]interface{}{
"value": user.PhoneNumber,
"verified": true,
"via": "sms",
})
}
payload["verifiable_addresses"] = verifiable
payload["recovery_addresses"] = []map[string]interface{}{
{
"value": user.Email,
"via": "email",
},
}
body, _ := json.Marshal(payload)
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, fmt.Sprintf("%s/admin/identities", o.KratosAdminURL), bytes.NewReader(body))
@@ -119,7 +150,7 @@ func (o *OryProvider) SignIn(loginID, password string) (*domain.AuthInfo, error)
return nil, fmt.Errorf("ory provider: loginID and password are required")
}
flowID, err := o.startLoginFlow()
flowID, err := o.startLoginFlow("")
if err != nil {
return nil, err
}
@@ -178,6 +209,326 @@ func (o *OryProvider) SignIn(loginID, password string) (*domain.AuthInfo, error)
}, nil
}
// UserExists는 Kratos Admin API로 loginID 존재 여부를 확인합니다.
func (o *OryProvider) UserExists(loginID string) (bool, error) {
if loginID == "" {
return false, fmt.Errorf("ory provider: loginID is empty")
}
identityID, err := o.findIdentityID(loginID)
if err != nil {
return false, fmt.Errorf("ory provider: find identity failed: %w", err)
}
return identityID != "", nil
}
// IssueSession은 Ory에서 별도 세션 발급이 필요할 때 사용합니다. (현재 미지원)
func (o *OryProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
return nil, domain.ErrNotSupported
}
// InitiateLinkLogin은 Kratos Public API로 링크 로그인 플로우를 시작하고 이메일 전송을 트리거합니다.
func (o *OryProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
if loginID == "" {
return nil, fmt.Errorf("ory provider: loginID is required")
}
init, err := o.submitLoginCodeInit(loginID, returnTo)
if err == nil {
return init, nil
}
if shouldBootstrapCodeLogin(err) {
if ensureErr := o.ensureCodeLoginIdentifier(loginID); ensureErr == nil {
return o.submitLoginCodeInit(loginID, returnTo)
} else {
slog.Warn("Ory code login bootstrap failed", "loginID", loginID, "error", ensureErr)
}
}
return nil, err
}
func (o *OryProvider) submitLoginCodeInit(loginID, returnTo string) (*domain.LinkLoginInit, error) {
flowID, err := o.startLoginFlow(returnTo)
if err != nil {
return nil, err
}
body, _ := json.Marshal(map[string]string{
"method": "code",
"identifier": loginID,
})
loginURL := fmt.Sprintf("%s/self-service/login?flow=%s", o.KratosPublicURL, flowID)
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, loginURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("ory provider: build link login request failed: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := o.httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("ory provider: link login request failed: %w", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
if resp.StatusCode >= 300 {
init, ok := parseKratosLinkLoginResponse(flowID, respBody)
if ok {
slog.Info("Ory link login initiated with non-2xx response", "loginID", loginID, "flow_id", flowID, "status", resp.StatusCode)
return init, nil
}
return nil, fmt.Errorf("ory provider: link login failed status=%d body=%s", resp.StatusCode, string(respBody))
}
var result struct {
ExpiresAt time.Time `json:"expires_at"`
}
_ = json.Unmarshal(respBody, &result)
slog.Info("Ory link login initiated", "loginID", loginID, "flow_id", flowID)
return &domain.LinkLoginInit{
FlowID: flowID,
ExpiresAt: result.ExpiresAt,
Mode: "link",
}, nil
}
func parseKratosLinkLoginResponse(flowID string, body []byte) (*domain.LinkLoginInit, bool) {
if len(body) == 0 {
return nil, false
}
var parsed struct {
ExpiresAt time.Time `json:"expires_at"`
State string `json:"state"`
Active string `json:"active"`
}
if err := json.Unmarshal(body, &parsed); err != nil {
return nil, false
}
state := strings.ToLower(parsed.State)
active := strings.ToLower(parsed.Active)
if strings.Contains(state, "sent") || active == "code" {
return &domain.LinkLoginInit{
FlowID: flowID,
ExpiresAt: parsed.ExpiresAt,
Mode: "link",
}, true
}
return nil, false
}
func shouldBootstrapCodeLogin(err error) bool {
if err == nil {
return false
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "has not setup sign in with code") ||
strings.Contains(msg, "4000035")
}
type kratosVerifiableAddress struct {
Value string `json:"value"`
Via string `json:"via"`
Verified bool `json:"verified"`
Status string `json:"status,omitempty"`
}
func (o *OryProvider) ensureCodeLoginIdentifier(loginID string) error {
identityID, err := o.findIdentityID(loginID)
if err != nil {
return fmt.Errorf("ory provider: find identity failed: %w", err)
}
if identityID == "" {
return fmt.Errorf("ory provider: identity not found for loginID=%s", loginID)
}
identity, err := o.fetchIdentity(identityID)
if err != nil {
return err
}
via := "sms"
if strings.Contains(loginID, "@") {
via = "email"
}
exists := false
existingIndex := -1
addresses := make([]kratosVerifiableAddress, 0, len(identity.VerifiableAddresses)+1)
for idx, addr := range identity.VerifiableAddresses {
addresses = append(addresses, kratosVerifiableAddress{
Value: addr.Value,
Via: addr.Via,
Verified: addr.Verified,
Status: addr.Status,
})
if addr.Value == loginID && addr.Via == via {
exists = true
existingIndex = idx
}
}
ops := make([]map[string]interface{}, 0, 2)
if !exists {
ops = append(ops, map[string]interface{}{
"op": "add",
"path": "/verifiable_addresses/-",
"value": map[string]interface{}{
"value": loginID,
"via": via,
"verified": true,
"status": "completed",
},
})
} else {
addr := identity.VerifiableAddresses[existingIndex]
if !addr.Verified {
ops = append(ops, map[string]interface{}{
"op": "replace",
"path": fmt.Sprintf("/verifiable_addresses/%d/verified", existingIndex),
"value": true,
})
}
if addr.Status != "" && addr.Status != "completed" {
ops = append(ops, map[string]interface{}{
"op": "replace",
"path": fmt.Sprintf("/verifiable_addresses/%d/status", existingIndex),
"value": "completed",
})
}
}
if len(ops) == 0 {
slog.Info("Ory identity verifiable address already ready", "identity_id", identityID, "loginID", loginID, "via", via)
return nil
}
return o.patchIdentity(identityID, ops)
}
type kratosIdentity struct {
VerifiableAddresses []kratosVerifiableAddress `json:"verifiable_addresses"`
}
func (o *OryProvider) patchIdentity(identityID string, ops []map[string]interface{}) error {
body, _ := json.Marshal(ops)
req, err := http.NewRequestWithContext(context.Background(), http.MethodPatch, fmt.Sprintf("%s/admin/identities/%s", o.KratosAdminURL, identityID), bytes.NewReader(body))
if err != nil {
return fmt.Errorf("ory provider: build identity patch failed: %w", err)
}
req.Header.Set("Content-Type", "application/json-patch+json")
resp, err := o.httpClient().Do(req)
if err != nil {
return fmt.Errorf("ory provider: identity patch failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return fmt.Errorf("ory provider: identity patch failed status=%d body=%s", resp.StatusCode, string(respBody))
}
slog.Info("Ory identity patched", "identity_id", identityID, "ops", len(ops))
return nil
}
func (o *OryProvider) fetchIdentity(identityID string) (*kratosIdentity, error) {
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, fmt.Sprintf("%s/admin/identities/%s", o.KratosAdminURL, identityID), nil)
if err != nil {
return nil, fmt.Errorf("ory provider: build identity get failed: %w", err)
}
resp, err := o.httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("ory provider: identity get failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return nil, fmt.Errorf("ory provider: identity get failed status=%d body=%s", resp.StatusCode, string(body))
}
var identity kratosIdentity
if err := json.NewDecoder(resp.Body).Decode(&identity); err != nil {
return nil, fmt.Errorf("ory provider: decode identity failed: %w", err)
}
return &identity, nil
}
// VerifyLoginCode는 Kratos 로그인 코드 제출로 세션을 발급합니다.
func (o *OryProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
if loginID == "" || flowID == "" || code == "" {
return nil, fmt.Errorf("ory provider: loginID, flowID and code are required")
}
body, _ := json.Marshal(map[string]string{
"method": "code",
"identifier": loginID,
"code": code,
})
loginURL := fmt.Sprintf("%s/self-service/login?flow=%s", o.KratosPublicURL, flowID)
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, loginURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("ory provider: build login code request failed: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := o.httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("ory provider: login code request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return nil, fmt.Errorf("ory provider: login code failed status=%d body=%s", resp.StatusCode, string(respBody))
}
var result struct {
SessionToken string `json:"session_token"`
SessionTokenExpiresAt time.Time `json:"session_token_expires_at"`
Session struct {
Identity struct {
ID string `json:"id"`
} `json:"identity"`
} `json:"session"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("ory provider: decode login code response failed: %w", err)
}
if result.SessionToken == "" {
return nil, fmt.Errorf("ory provider: empty session token returned")
}
slog.Info("Ory login code successful",
"identity_id", result.Session.Identity.ID,
"loginID", loginID,
"expires_at", result.SessionTokenExpiresAt,
)
return &domain.AuthInfo{
SessionToken: &domain.Token{
JWT: result.SessionToken,
Expiration: result.SessionTokenExpiresAt,
},
Subject: result.Session.Identity.ID,
}, nil
}
// GetPasswordPolicy는 Ory 환경에서 사용하는 기본 정책을 반환합니다.
func (o *OryProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
return &domain.PasswordPolicy{
MinLength: 12,
Lowercase: true,
Uppercase: false,
Number: true,
NonAlphanumeric: true,
MinCharacterTypes: 0,
}, nil
}
// InitiatePasswordReset는 현재 내부 토큰/메일 흐름을 사용하고 있으므로 NO-OP로 둡니다.
func (o *OryProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
slog.Info("Ory InitiatePasswordReset bypassed (handled by app internal flow)", "loginID", loginID, "redirect", redirectUrl)
@@ -301,8 +652,12 @@ func (o *OryProvider) httpClient() *http.Client {
}
// startLoginFlow는 Kratos Public API에서 login flow ID를 발급받습니다.
func (o *OryProvider) startLoginFlow() (string, error) {
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, fmt.Sprintf("%s/self-service/login/api", o.KratosPublicURL), nil)
func (o *OryProvider) startLoginFlow(returnTo string) (string, error) {
loginURL := fmt.Sprintf("%s/self-service/login/api", o.KratosPublicURL)
if returnTo != "" {
loginURL = loginURL + "?return_to=" + url.QueryEscape(returnTo)
}
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, loginURL, nil)
if err != nil {
return "", fmt.Errorf("ory provider: build login flow request failed: %w", err)
}

View File

@@ -0,0 +1,79 @@
package utils
import (
"encoding/json"
"strings"
)
var sensitiveKeys = map[string]struct{}{
"password": {},
"newpassword": {},
"oldpassword": {},
"token": {},
"accesstoken": {},
"access_token": {},
"refreshtoken": {},
"refresh_token": {},
"secret": {},
"clientsecret": {},
"client_secret": {},
"authorization": {},
"cookie": {},
"set-cookie": {},
"verificationcode": {},
"verification_code": {},
"code": {}, // Auth code (sensitive)
}
// MaskSensitiveJSON parses a JSON byte slice and masks values of sensitive keys.
// Returns the original data if it's not valid JSON.
func MaskSensitiveJSON(data []byte) []byte {
if len(data) == 0 {
return data
}
var obj interface{}
if err := json.Unmarshal(data, &obj); err != nil {
// Not a JSON object/array, return as is
return data
}
masked := maskValue(obj)
result, err := json.Marshal(masked)
if err != nil {
return data
}
return result
}
func maskValue(v interface{}) interface{} {
switch val := v.(type) {
case map[string]interface{}:
newMap := make(map[string]interface{}, len(val))
for k, v := range val {
if isSensitive(k) {
newMap[k] = "*****"
} else {
newMap[k] = maskValue(v)
}
}
return newMap
case []interface{}:
newArr := make([]interface{}, len(val))
for i, v := range val {
newArr[i] = maskValue(v)
}
return newArr
default:
return val
}
}
func isSensitive(key string) bool {
// Check case-insensitive
// Remove common separators for looser matching? No, stick to lowercase check for now.
k := strings.ToLower(key)
_, ok := sensitiveKeys[k]
return ok
}

View File

@@ -0,0 +1,59 @@
package utils
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestMaskSensitiveJSON(t *testing.T) {
tests := []struct {
name string
input string
expected string // We'll check containment or specific structure
}{
{
name: "Flat object with password",
input: `{"username": "user", "password": "secret123"}`,
expected: `{"password":"*****","username":"user"}`,
},
{
name: "Nested object with token",
input: `{"data": {"token": "abc-def", "id": 123}}`,
expected: `{"data":{"id":123,"token":"*****"}}`,
},
{
name: "Case insensitive key",
input: `{"NewPassword": "changed"}`,
expected: `{"NewPassword":"*****"}`,
},
{
name: "Array of objects",
input: `[{"secret": "s1"}, {"secret": "s2"}]`,
expected: `[{"secret":"*****"},{"secret":"*****"}]`,
},
{
name: "Invalid JSON",
input: `not-json`,
expected: `not-json`,
},
{
name: "Empty JSON",
input: ``,
expected: ``,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := MaskSensitiveJSON([]byte(tt.input))
// Since JSON map order is undefined, exact string match might fail if keys are reordered.
// Ideally we should unmarshal and compare maps, or use assert.JSONEq
if tt.name == "Invalid JSON" || tt.name == "Empty JSON" {
assert.Equal(t, tt.expected, string(result))
} else {
assert.JSONEq(t, tt.expected, string(result))
}
})
}
}

View File

@@ -29,6 +29,26 @@ func (m *MockProvider) SignIn(loginID, password string) (*domain.AuthInfo, error
return &domain.AuthInfo{}, nil
}
func (m *MockProvider) UserExists(loginID string) (bool, error) {
return false, nil
}
func (m *MockProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
return nil, domain.ErrNotSupported
}
func (m *MockProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
return nil, domain.ErrNotSupported
}
func (m *MockProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
return nil, domain.ErrNotSupported
}
func (m *MockProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
return nil, domain.ErrNotSupported
}
// Stub implementations to satisfy the IdentityProvider interface for this unit test.
func (m *MockProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
return nil