forked from baron/baron-sso
Resolve merge conflicts with main
This commit is contained in:
132
backend/cmd/server/health_monitor.go
Normal file
132
backend/cmd/server/health_monitor.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type HTTPProbe struct {
|
||||
name string
|
||||
url string
|
||||
interval time.Duration
|
||||
timeout time.Duration
|
||||
client *http.Client
|
||||
mu sync.RWMutex
|
||||
status string
|
||||
lastError string
|
||||
lastChecked time.Time
|
||||
lastSuccess time.Time
|
||||
}
|
||||
|
||||
type ProbeSnapshot struct {
|
||||
Status string
|
||||
Error string
|
||||
LastChecked time.Time
|
||||
LastSuccess time.Time
|
||||
}
|
||||
|
||||
func NewHTTPProbe(name, url string, interval, timeout time.Duration) *HTTPProbe {
|
||||
if interval <= 0 {
|
||||
interval = 10 * time.Second
|
||||
}
|
||||
if timeout <= 0 {
|
||||
timeout = 2 * time.Second
|
||||
}
|
||||
|
||||
return &HTTPProbe{
|
||||
name: name,
|
||||
url: url,
|
||||
interval: interval,
|
||||
timeout: timeout,
|
||||
client: &http.Client{
|
||||
Timeout: timeout,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Start는 프로브를 백그라운드에서 주기적으로 실행합니다.
|
||||
func (p *HTTPProbe) Start() {
|
||||
go func() {
|
||||
p.checkOnce()
|
||||
ticker := time.NewTicker(p.interval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
p.checkOnce()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (p *HTTPProbe) Snapshot() ProbeSnapshot {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
return ProbeSnapshot{
|
||||
Status: p.status,
|
||||
Error: p.lastError,
|
||||
LastChecked: p.lastChecked,
|
||||
LastSuccess: p.lastSuccess,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *HTTPProbe) StatusText() string {
|
||||
s := p.Snapshot()
|
||||
if s.Status == "ok" {
|
||||
return "ok"
|
||||
}
|
||||
if s.Status == "" {
|
||||
return "unknown"
|
||||
}
|
||||
if s.Error == "" {
|
||||
return "error"
|
||||
}
|
||||
return "error: " + s.Error
|
||||
}
|
||||
|
||||
func (p *HTTPProbe) checkOnce() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.url, nil)
|
||||
if err != nil {
|
||||
p.update("error", fmt.Sprintf("request build failed: %v", err), false)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
p.update("error", err.Error(), false)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
p.update("error", fmt.Sprintf("status=%d", resp.StatusCode), false)
|
||||
return
|
||||
}
|
||||
|
||||
p.update("ok", "", true)
|
||||
}
|
||||
|
||||
func (p *HTTPProbe) update(status, errMsg string, success bool) {
|
||||
p.mu.Lock()
|
||||
prevStatus := p.status
|
||||
p.status = status
|
||||
p.lastError = errMsg
|
||||
p.lastChecked = time.Now()
|
||||
if success {
|
||||
p.lastSuccess = p.lastChecked
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
if prevStatus == status {
|
||||
return
|
||||
}
|
||||
if status == "ok" {
|
||||
slog.Info("Service probe recovered", "name", p.name, "url", p.url)
|
||||
return
|
||||
}
|
||||
slog.Error("Service probe failed", "name", p.name, "url", p.url, "error", errMsg)
|
||||
}
|
||||
@@ -158,6 +158,28 @@ func main() {
|
||||
slog.Warn("Failed to connect to Redis. Auth features may fail.", "error", err)
|
||||
}
|
||||
|
||||
// Oathkeeper 상태를 주기적으로 확인해 다운을 감지합니다.
|
||||
var oathkeeperProbe *HTTPProbe
|
||||
if strings.ToLower(getEnv("OATHKEEPER_HEALTH_ENABLED", "true")) != "false" {
|
||||
intervalSec, err := strconv.Atoi(getEnv("OATHKEEPER_HEALTH_INTERVAL_SECONDS", "10"))
|
||||
if err != nil || intervalSec <= 0 {
|
||||
intervalSec = 10
|
||||
}
|
||||
timeoutSec, err := strconv.Atoi(getEnv("OATHKEEPER_HEALTH_TIMEOUT_SECONDS", "2"))
|
||||
if err != nil || timeoutSec <= 0 {
|
||||
timeoutSec = 2
|
||||
}
|
||||
oathkeeperProbe = NewHTTPProbe(
|
||||
"oathkeeper",
|
||||
getEnv("OATHKEEPER_HEALTH_URL", "http://oathkeeper:4456/health/ready"),
|
||||
time.Duration(intervalSec)*time.Second,
|
||||
time.Duration(timeoutSec)*time.Second,
|
||||
)
|
||||
oathkeeperProbe.Start()
|
||||
} else {
|
||||
slog.Info("Oathkeeper probe disabled")
|
||||
}
|
||||
|
||||
// 2. Initialize Handlers
|
||||
auditHandler := handler.NewAuditHandler(auditRepo)
|
||||
authHandler := handler.NewAuthHandler(redisService, idpProvider)
|
||||
@@ -249,10 +271,14 @@ func main() {
|
||||
app.Use(recover.New(recover.Config{
|
||||
EnableStackTrace: true,
|
||||
}))
|
||||
|
||||
allowedOrigins := getEnv("CORS_ALLOWED_ORIGINS", "http://localhost:5000")
|
||||
allowCredentials := allowedOrigins != "*"
|
||||
app.Use(cors.New(cors.Config{
|
||||
AllowOrigins: "*", // Adjust in production
|
||||
AllowHeaders: "Origin, Content-Type, Accept, Authorization",
|
||||
AllowMethods: "GET, POST, HEAD, PUT, DELETE, PATCH, OPTIONS",
|
||||
AllowOrigins: allowedOrigins,
|
||||
AllowHeaders: "Origin, Content-Type, Accept, Authorization",
|
||||
AllowMethods: "GET, POST, HEAD, PUT, DELETE, PATCH, OPTIONS",
|
||||
AllowCredentials: allowCredentials,
|
||||
}))
|
||||
|
||||
// Ensure COOKIE_SECRET is exactly 32 bytes for AES-256
|
||||
@@ -271,24 +297,30 @@ func main() {
|
||||
Key: cookieSecret,
|
||||
}))
|
||||
|
||||
app.Get("/docs", func(c *fiber.Ctx) error {
|
||||
return c.SendFile("./docs/swagger-ui/index.html")
|
||||
})
|
||||
app.Get("/docs/", func(c *fiber.Ctx) error {
|
||||
return c.SendFile("./docs/swagger-ui/index.html")
|
||||
})
|
||||
app.Static("/docs", "./docs/swagger-ui")
|
||||
app.Get("/redoc", func(c *fiber.Ctx) error {
|
||||
return c.SendFile("./docs/redoc/index.html")
|
||||
})
|
||||
app.Get("/redoc/", func(c *fiber.Ctx) error {
|
||||
return c.SendFile("./docs/redoc/index.html")
|
||||
})
|
||||
app.Static("/redoc", "./docs/redoc")
|
||||
app.Get("/openapi.yaml", func(c *fiber.Ctx) error {
|
||||
c.Type("yaml")
|
||||
return c.SendFile("./docs/openapi.yaml")
|
||||
})
|
||||
// [Security] Disable Swagger/ReDoc in Production
|
||||
if appEnv != "production" {
|
||||
app.Get("/docs", func(c *fiber.Ctx) error {
|
||||
return c.SendFile("./docs/swagger-ui/index.html")
|
||||
})
|
||||
app.Get("/docs/", func(c *fiber.Ctx) error {
|
||||
return c.SendFile("./docs/swagger-ui/index.html")
|
||||
})
|
||||
app.Static("/docs", "./docs/swagger-ui")
|
||||
app.Get("/redoc", func(c *fiber.Ctx) error {
|
||||
return c.SendFile("./docs/redoc/index.html")
|
||||
})
|
||||
app.Get("/redoc/", func(c *fiber.Ctx) error {
|
||||
return c.SendFile("./docs/redoc/index.html")
|
||||
})
|
||||
app.Static("/redoc", "./docs/redoc")
|
||||
app.Get("/openapi.yaml", func(c *fiber.Ctx) error {
|
||||
c.Type("yaml")
|
||||
return c.SendFile("./docs/openapi.yaml")
|
||||
})
|
||||
slog.Info("📚 API Docs enabled", "swagger", "/docs", "redoc", "/redoc")
|
||||
} else {
|
||||
slog.Info("🔒 API Docs disabled in production")
|
||||
}
|
||||
|
||||
// Routes
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
@@ -325,6 +357,32 @@ func main() {
|
||||
status = "degraded"
|
||||
}
|
||||
|
||||
// Check Oathkeeper
|
||||
if oathkeeperProbe != nil {
|
||||
snapshot := oathkeeperProbe.Snapshot()
|
||||
switch snapshot.Status {
|
||||
case "ok":
|
||||
checks["oathkeeper"] = "ok"
|
||||
case "":
|
||||
checks["oathkeeper"] = "unknown"
|
||||
if status != "error" {
|
||||
status = "degraded"
|
||||
}
|
||||
default:
|
||||
if snapshot.Error == "" {
|
||||
checks["oathkeeper"] = "error"
|
||||
} else {
|
||||
checks["oathkeeper"] = "error: " + snapshot.Error
|
||||
}
|
||||
status = "error"
|
||||
}
|
||||
} else {
|
||||
checks["oathkeeper"] = "disabled"
|
||||
if status != "error" {
|
||||
status = "degraded"
|
||||
}
|
||||
}
|
||||
|
||||
if status == "error" {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{
|
||||
"status": status,
|
||||
@@ -340,12 +398,19 @@ func main() {
|
||||
|
||||
// API Group
|
||||
api := app.Group("/api/v1")
|
||||
api.Use(middleware.RequireAudit(middleware.AuditRequiredConfig{
|
||||
|
||||
workerCount, _ := strconv.Atoi(getEnv("AUDIT_WORKER_COUNT", "5"))
|
||||
queueSize, _ := strconv.Atoi(getEnv("AUDIT_QUEUE_SIZE", "2000"))
|
||||
|
||||
api.Use(middleware.AuditMiddleware(middleware.AuditConfig{
|
||||
Repo: auditRepo,
|
||||
ExcludePaths: map[string]struct{}{
|
||||
"/api/v1/audit": {},
|
||||
"/api/v1/client-log": {},
|
||||
},
|
||||
BodyDump: true,
|
||||
WorkerCount: workerCount,
|
||||
QueueSize: queueSize,
|
||||
}))
|
||||
api.Post("/audit", auditHandler.CreateLog)
|
||||
api.Get("/audit", auditHandler.ListLogs)
|
||||
@@ -355,6 +420,8 @@ func main() {
|
||||
auth.Post("/enchanted-link/init", authHandler.InitEnchantedLink)
|
||||
auth.Post("/enchanted-link/poll", authHandler.PollEnchantedLink)
|
||||
auth.Post("/magic-link/verify", authHandler.VerifyMagicLink)
|
||||
auth.Post("/login/code/verify", authHandler.VerifyLoginCode)
|
||||
auth.Post("/login/code/verify-short", authHandler.VerifyLoginShortCode)
|
||||
auth.Post("/password/login", authHandler.PasswordLogin)
|
||||
auth.Post("/password/reset/initiate", authHandler.InitiatePasswordReset)
|
||||
// [Changed] Use Interstitial Page for GET to prevent Scanner consumption
|
||||
@@ -388,6 +455,7 @@ func main() {
|
||||
admin := api.Group("/admin")
|
||||
admin.Use(middleware.ApiKeyAuth(middleware.ApiKeyAuthConfig{DB: db})) // API Key 인증 추가
|
||||
admin.Get("/check", adminHandler.CheckAuth)
|
||||
admin.Get("/stats", adminHandler.GetSystemStats)
|
||||
admin.Get("/tenants", tenantHandler.ListTenants)
|
||||
admin.Post("/tenants", tenantHandler.CreateTenant)
|
||||
admin.Get("/tenants/:id", tenantHandler.GetTenant)
|
||||
@@ -423,6 +491,9 @@ func main() {
|
||||
// Webhook for Descope Generic Email Gateway (Fake Email Strategy)
|
||||
auth.Post("/webhooks/descope-email", authHandler.HandleDescopeEmailRelay)
|
||||
|
||||
// Webhook for Kratos courier (HTTP delivery)
|
||||
auth.Post("/webhooks/kratos-courier", authHandler.HandleKratosCourierRelay)
|
||||
|
||||
// Client Logging Route (Standardized & Flattened)
|
||||
api.Post("/client-log", func(c *fiber.Ctx) error {
|
||||
type LogReq struct {
|
||||
|
||||
@@ -230,7 +230,7 @@ paths:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/MessageResponse"
|
||||
$ref: "#/components/schemas/MagicLinkVerifyResponse"
|
||||
|
||||
/api/v1/auth/sms:
|
||||
post:
|
||||
@@ -266,7 +266,7 @@ paths:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/MessageResponse"
|
||||
$ref: "#/components/schemas/SmsVerifyResponse"
|
||||
|
||||
/api/v1/auth/qr/init:
|
||||
post:
|
||||
@@ -908,18 +908,28 @@ components:
|
||||
type: boolean
|
||||
nonAlphanumeric:
|
||||
type: boolean
|
||||
minCharacterTypes:
|
||||
type: integer
|
||||
|
||||
EnchantedLinkInitRequest:
|
||||
type: object
|
||||
properties:
|
||||
loginId:
|
||||
type: string
|
||||
uri:
|
||||
type: string
|
||||
method:
|
||||
type: string
|
||||
|
||||
EnchantedLinkInitResponse:
|
||||
type: object
|
||||
properties:
|
||||
linkId:
|
||||
type: string
|
||||
pendingRef:
|
||||
type: string
|
||||
maskedEmail:
|
||||
type: string
|
||||
expiresIn:
|
||||
type: integer
|
||||
|
||||
@@ -943,22 +953,36 @@ components:
|
||||
token:
|
||||
type: string
|
||||
|
||||
MagicLinkVerifyResponse:
|
||||
type: object
|
||||
properties:
|
||||
token:
|
||||
type: string
|
||||
message:
|
||||
type: string
|
||||
|
||||
SmsSendRequest:
|
||||
type: object
|
||||
properties:
|
||||
phone:
|
||||
type: string
|
||||
message:
|
||||
phoneNumber:
|
||||
type: string
|
||||
|
||||
SmsVerifyRequest:
|
||||
type: object
|
||||
properties:
|
||||
phone:
|
||||
phoneNumber:
|
||||
type: string
|
||||
code:
|
||||
type: string
|
||||
|
||||
SmsVerifyResponse:
|
||||
type: object
|
||||
properties:
|
||||
token:
|
||||
type: string
|
||||
message:
|
||||
type: string
|
||||
|
||||
QrInitResponse:
|
||||
type: object
|
||||
properties:
|
||||
|
||||
@@ -14,6 +14,7 @@ require (
|
||||
github.com/gofiber/fiber/v2 v2.52.10
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/stretchr/testify v1.11.1
|
||||
golang.org/x/crypto v0.46.0
|
||||
gorm.io/driver/postgres v1.6.0
|
||||
gorm.io/gorm v1.31.1
|
||||
@@ -34,6 +35,7 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.6 // indirect
|
||||
github.com/aws/smithy-go v1.24.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/go-faster/city v1.0.1 // indirect
|
||||
@@ -57,10 +59,12 @@ require (
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/paulmach/orb v0.12.0 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.22 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||
github.com/segmentio/asm v1.2.1 // indirect
|
||||
github.com/shopspring/decimal v1.4.0 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasthttp v1.51.0 // indirect
|
||||
github.com/valyala/tcplisten v1.0.0 // indirect
|
||||
@@ -71,4 +75,5 @@ require (
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
golang.org/x/text v0.32.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -137,6 +137,8 @@ github.com/segmentio/asm v1.2.1/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr
|
||||
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
|
||||
@@ -4,6 +4,7 @@ type EnchantedLinkInitRequest struct {
|
||||
LoginID string `json:"loginId"`
|
||||
URI string `json:"uri,omitempty"` // Redirect URI (optional for polling flow)
|
||||
Method string `json:"method,omitempty"` // "email" or "sms"
|
||||
CodeOnly bool `json:"codeOnly,omitempty"`
|
||||
}
|
||||
|
||||
type EnchantedLinkInitResponse struct {
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrNotSupported는 IDP가 특정 인증 흐름을 지원하지 않을 때 반환합니다.
|
||||
var ErrNotSupported = errors.New("idp: not supported")
|
||||
|
||||
// BrokerUser is the standard user model used within Baron SSO business logic.
|
||||
// It defines the canonical set of fields that must be supported by any underlying IDP.
|
||||
type BrokerUser struct {
|
||||
@@ -24,6 +28,16 @@ type IDPMetadata struct {
|
||||
SupportedFields []string
|
||||
}
|
||||
|
||||
// PasswordPolicy는 비밀번호 정책 정보를 표현합니다.
|
||||
type PasswordPolicy struct {
|
||||
MinLength int
|
||||
Lowercase bool
|
||||
Uppercase bool
|
||||
Number bool
|
||||
NonAlphanumeric bool
|
||||
MinCharacterTypes int
|
||||
}
|
||||
|
||||
// Token represents a session or refresh token.
|
||||
type Token struct {
|
||||
JWT string
|
||||
@@ -38,6 +52,16 @@ type AuthInfo struct {
|
||||
Subject string
|
||||
}
|
||||
|
||||
// LinkLoginInit는 링크 로그인 초기화 결과입니다.
|
||||
type LinkLoginInit struct {
|
||||
FlowID string
|
||||
ExpiresAt time.Time
|
||||
// Mode는 링크 로그인 완료 후 세션 처리 방식입니다. (예: "cookie")
|
||||
Mode string
|
||||
// LoginID는 IDP에 실제 전달된 식별자입니다.
|
||||
LoginID string
|
||||
}
|
||||
|
||||
// IdentityProvider is the interface that all IDP adapters must implement.
|
||||
type IdentityProvider interface {
|
||||
Name() string
|
||||
@@ -48,6 +72,16 @@ type IdentityProvider interface {
|
||||
CreateUser(user *BrokerUser, password string) (string, error)
|
||||
// SignIn은 로그인 ID/비밀번호로 인증해 세션 정보를 반환합니다.
|
||||
SignIn(loginID, password string) (*AuthInfo, error)
|
||||
// UserExists는 loginID 기준으로 사용자 존재 여부를 확인합니다.
|
||||
UserExists(loginID string) (bool, error)
|
||||
// IssueSession은 비밀번호 없이 세션을 발급해야 하는 흐름에서 사용합니다.
|
||||
IssueSession(loginID string) (*AuthInfo, error)
|
||||
// InitiateLinkLogin은 링크 기반 로그인 요청을 IDP에 전달합니다.
|
||||
InitiateLinkLogin(loginID, returnTo string) (*LinkLoginInit, error)
|
||||
// VerifyLoginCode는 링크/코드 기반 로그인에서 코드를 제출해 세션을 발급합니다.
|
||||
VerifyLoginCode(loginID, flowID, code string) (*AuthInfo, error)
|
||||
// GetPasswordPolicy는 IDP가 제공하는 비밀번호 정책을 반환합니다.
|
||||
GetPasswordPolicy() (*PasswordPolicy, error)
|
||||
InitiatePasswordReset(loginID, redirectUrl string) error
|
||||
VerifyPasswordResetToken(token string) (*AuthInfo, error)
|
||||
UpdateUserPassword(loginID, newPassword string, r *http.Request) error
|
||||
|
||||
@@ -3,6 +3,8 @@ package handler
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/descope/go-sdk/descope/client"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
@@ -39,3 +41,23 @@ func NewAdminHandler() *AdminHandler {
|
||||
func (h *AdminHandler) CheckAuth(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "ok"})
|
||||
}
|
||||
|
||||
// GetSystemStats returns runtime statistics for monitoring
|
||||
func (h *AdminHandler) GetSystemStats(c *fiber.Ctx) error {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
|
||||
stats := fiber.Map{
|
||||
"goroutines": runtime.NumGoroutine(),
|
||||
"cpus": runtime.NumCPU(),
|
||||
"memory": fiber.Map{
|
||||
"alloc": m.Alloc,
|
||||
"totalAlign": m.TotalAlloc,
|
||||
"sys": m.Sys,
|
||||
"numGC": m.NumGC,
|
||||
},
|
||||
"timestamp": time.Now(),
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(stats)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -72,7 +72,7 @@ func TestCompletePasswordReset_InvalidPasswordPolicy(t *testing.T) {
|
||||
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if got["error"] != "Password must be at least 8 characters long" {
|
||||
if got["error"] != "비밀번호는 최소 12자 이상이어야 합니다" {
|
||||
t.Fatalf("unexpected error message: %v", got["error"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,43 +124,144 @@ func (c *chainedProvider) GetMetadata() (*domain.IDPMetadata, error) {
|
||||
}
|
||||
|
||||
func (c *chainedProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
|
||||
var errs []error
|
||||
for idx, p := range c.providers {
|
||||
for _, p := range c.providers {
|
||||
id, err := p.CreateUser(user, password)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "CreateUser", "error", err)
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
slog.Info("IDP fallback succeeded", "operation", "CreateUser", "provider", p.Name())
|
||||
return "", err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return "", fmt.Errorf("no IDP providers available for CreateUser")
|
||||
}
|
||||
return "", fmt.Errorf("all IDP providers failed for CreateUser: %w", errors.Join(errs...))
|
||||
return "", domain.ErrNotSupported
|
||||
}
|
||||
|
||||
func (c *chainedProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
|
||||
var errs []error
|
||||
for idx, p := range c.providers {
|
||||
for _, p := range c.providers {
|
||||
info, err := p.SignIn(loginID, password)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "SignIn", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
|
||||
func (c *chainedProvider) UserExists(loginID string) (bool, error) {
|
||||
var errs []error
|
||||
for _, p := range c.providers {
|
||||
exists, err := p.UserExists(loginID)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
continue
|
||||
}
|
||||
if exists {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("all IDP providers failed for UserExists: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *chainedProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
|
||||
var errs []error
|
||||
for idx, p := range c.providers {
|
||||
info, err := p.IssueSession(loginID)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "IssueSession", "error", err)
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
slog.Info("IDP fallback succeeded", "operation", "SignIn", "provider", p.Name())
|
||||
slog.Info("IDP fallback succeeded", "operation", "IssueSession", "provider", p.Name())
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return nil, fmt.Errorf("no IDP providers available for SignIn")
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
return nil, fmt.Errorf("all IDP providers failed for SignIn: %w", errors.Join(errs...))
|
||||
return nil, fmt.Errorf("all IDP providers failed for IssueSession: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *chainedProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
var errs []error
|
||||
for idx, p := range c.providers {
|
||||
info, err := p.InitiateLinkLogin(loginID, returnTo)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "InitiateLinkLogin", "error", err)
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
slog.Info("IDP fallback succeeded", "operation", "InitiateLinkLogin", "provider", p.Name())
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
return nil, fmt.Errorf("all IDP providers failed for InitiateLinkLogin: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *chainedProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
||||
var errs []error
|
||||
for idx, p := range c.providers {
|
||||
info, err := p.VerifyLoginCode(loginID, flowID, code)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "VerifyLoginCode", "error", err)
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
slog.Info("IDP fallback succeeded", "operation", "VerifyLoginCode", "provider", p.Name())
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
return nil, fmt.Errorf("all IDP providers failed for VerifyLoginCode: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *chainedProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
|
||||
var errs []error
|
||||
for _, p := range c.providers {
|
||||
policy, err := p.GetPasswordPolicy()
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
continue
|
||||
}
|
||||
if policy != nil {
|
||||
return policy, nil
|
||||
}
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
return nil, fmt.Errorf("all IDP providers failed for GetPasswordPolicy: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *chainedProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
|
||||
|
||||
@@ -10,19 +10,31 @@ import (
|
||||
)
|
||||
|
||||
type stubProvider struct {
|
||||
name string
|
||||
metadata []string
|
||||
createErr error
|
||||
initiateErr error
|
||||
verifyErr error
|
||||
updateErr error
|
||||
signInErr error
|
||||
initiateCalls int
|
||||
verifyCalls int
|
||||
updateCalls int
|
||||
signInCalls int
|
||||
createCalls int
|
||||
verifyResponse *domain.AuthInfo
|
||||
name string
|
||||
metadata []string
|
||||
createErr error
|
||||
initiateErr error
|
||||
verifyErr error
|
||||
updateErr error
|
||||
signInErr error
|
||||
userExistsErr error
|
||||
issueErr error
|
||||
linkInitErr error
|
||||
verifyCodeErr error
|
||||
policyErr error
|
||||
initiateCalls int
|
||||
verifyCalls int
|
||||
updateCalls int
|
||||
signInCalls int
|
||||
createCalls int
|
||||
userExistsCalls int
|
||||
issueCalls int
|
||||
linkInitCalls int
|
||||
verifyCodeCalls int
|
||||
policyCalls int
|
||||
verifyResponse *domain.AuthInfo
|
||||
userExists bool
|
||||
policy *domain.PasswordPolicy
|
||||
}
|
||||
|
||||
func (s *stubProvider) Name() string { return s.name }
|
||||
@@ -47,6 +59,46 @@ func (s *stubProvider) SignIn(loginID, password string) (*domain.AuthInfo, error
|
||||
return &domain.AuthInfo{Subject: "subject-123"}, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) UserExists(loginID string) (bool, error) {
|
||||
s.userExistsCalls++
|
||||
if s.userExistsErr != nil {
|
||||
return false, s.userExistsErr
|
||||
}
|
||||
return s.userExists, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
|
||||
s.issueCalls++
|
||||
if s.issueErr != nil {
|
||||
return nil, s.issueErr
|
||||
}
|
||||
return &domain.AuthInfo{Subject: "issue-subject"}, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
s.linkInitCalls++
|
||||
if s.linkInitErr != nil {
|
||||
return nil, s.linkInitErr
|
||||
}
|
||||
return &domain.LinkLoginInit{FlowID: "flow-123", Mode: "cookie"}, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
||||
s.verifyCodeCalls++
|
||||
if s.verifyCodeErr != nil {
|
||||
return nil, s.verifyCodeErr
|
||||
}
|
||||
return &domain.AuthInfo{Subject: "verify-code-subject"}, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
|
||||
s.policyCalls++
|
||||
if s.policyErr != nil {
|
||||
return nil, s.policyErr
|
||||
}
|
||||
return s.policy, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
|
||||
s.initiateCalls++
|
||||
return s.initiateErr
|
||||
|
||||
192
backend/internal/middleware/audit_middleware.go
Normal file
192
backend/internal/middleware/audit_middleware.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/utils"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditConfig struct {
|
||||
Repo domain.AuditRepository
|
||||
ExcludePaths map[string]struct{}
|
||||
BodyDump bool
|
||||
WorkerCount int
|
||||
QueueSize int
|
||||
}
|
||||
|
||||
func isNil(i any) bool {
|
||||
if i == nil {
|
||||
return true
|
||||
}
|
||||
v := reflect.ValueOf(i)
|
||||
return v.Kind() == reflect.Ptr && v.IsNil()
|
||||
}
|
||||
|
||||
// AuditMiddleware provides comprehensive audit logging for all requests.
|
||||
// It enforces strict logging for state-changing commands (POST, PUT, DELETE, PATCH)
|
||||
// and best-effort logging for queries (GET, HEAD, OPTIONS).
|
||||
func AuditMiddleware(config AuditConfig) fiber.Handler {
|
||||
// 0. Initialize Worker Pool for Async Logging
|
||||
if config.WorkerCount <= 0 {
|
||||
config.WorkerCount = 5 // Default workers
|
||||
}
|
||||
if config.QueueSize <= 0 {
|
||||
config.QueueSize = 1000 // Default queue size
|
||||
}
|
||||
|
||||
auditQueue := make(chan *domain.AuditLog, config.QueueSize)
|
||||
var once sync.Once
|
||||
|
||||
// Start workers only once
|
||||
once.Do(func() {
|
||||
for i := 0; i < config.WorkerCount; i++ {
|
||||
go func(workerID int) {
|
||||
slog.Debug("Audit worker started", "id", workerID)
|
||||
for log := range auditQueue {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("Audit worker panic recovery", "reason", r, "req_id", log.EventID)
|
||||
}
|
||||
}()
|
||||
if err := config.Repo.Create(log); err != nil {
|
||||
slog.Warn("Failed to write async audit log", "error", err, "req_id", log.EventID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
})
|
||||
|
||||
// Default methods classification
|
||||
writeMethods := map[string]struct{}{
|
||||
fiber.MethodPost: {},
|
||||
fiber.MethodPut: {},
|
||||
fiber.MethodPatch: {},
|
||||
fiber.MethodDelete: {},
|
||||
}
|
||||
|
||||
if config.ExcludePaths == nil {
|
||||
config.ExcludePaths = map[string]struct{}{}
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
// 1. Check exclusions
|
||||
if _, excluded := config.ExcludePaths[c.Path()]; excluded {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// 2. Setup context variables
|
||||
start := time.Now()
|
||||
reqID := c.Get("X-Request-Id")
|
||||
if reqID == "" {
|
||||
reqID = uuid.New().String()
|
||||
c.Set("X-Request-Id", reqID)
|
||||
}
|
||||
|
||||
// 3. Process Request
|
||||
err := c.Next()
|
||||
|
||||
// 4. Gather Metrics & Context
|
||||
latency := time.Since(start)
|
||||
status := c.Response().StatusCode()
|
||||
|
||||
// If Fiber handler returned an error, status might default to 500 or be in the error
|
||||
if err != nil {
|
||||
if fiberErr, ok := err.(*fiber.Error); ok {
|
||||
status = fiberErr.Code
|
||||
} else {
|
||||
status = fiber.StatusInternalServerError
|
||||
}
|
||||
}
|
||||
|
||||
statusText := "success"
|
||||
if status >= fiber.StatusBadRequest {
|
||||
statusText = "failure"
|
||||
}
|
||||
|
||||
// 5. Extract User Context (populated by AuthMiddleware/TenantGuard)
|
||||
userID, _ := c.Locals("user_id").(string)
|
||||
loginID, _ := c.Locals("login_id").(string)
|
||||
tenantID, _ := c.Locals("tenant_id").(string)
|
||||
|
||||
// 6. Capture & Mask Body
|
||||
var maskedBody string
|
||||
if config.BodyDump {
|
||||
if c.Method() != fiber.MethodGet && c.Method() != fiber.MethodHead {
|
||||
bodyBytes := c.Body()
|
||||
if len(bodyBytes) > 0 {
|
||||
maskedBytes := utils.MaskSensitiveJSON(bodyBytes)
|
||||
maskedBody = string(maskedBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 7. Construct Details JSON
|
||||
details := map[string]any{
|
||||
"request_id": reqID,
|
||||
"method": c.Method(),
|
||||
"path": c.Path(),
|
||||
"status": status,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
"login_id": loginID,
|
||||
"tenant_id": tenantID,
|
||||
"request_body": maskedBody,
|
||||
}
|
||||
if err != nil {
|
||||
details["error"] = err.Error()
|
||||
}
|
||||
|
||||
detailsJSON, _ := json.Marshal(details)
|
||||
|
||||
// 8. Create Audit Log Object
|
||||
auditLog := &domain.AuditLog{
|
||||
EventID: reqID,
|
||||
Timestamp: start,
|
||||
UserID: userID,
|
||||
EventType: fmt.Sprintf("%s %s", c.Method(), c.Path()),
|
||||
Status: statusText,
|
||||
IPAddress: c.IP(),
|
||||
UserAgent: c.Get("User-Agent"),
|
||||
Details: string(detailsJSON),
|
||||
}
|
||||
|
||||
// 9. Store Log (Policy Enforcement)
|
||||
_, isWrite := writeMethods[c.Method()]
|
||||
|
||||
if isNil(config.Repo) {
|
||||
if isWrite {
|
||||
slog.Error("Audit repository missing for command", "req_id", reqID)
|
||||
return fiber.NewError(fiber.StatusServiceUnavailable, "Audit system unavailable")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if isWrite {
|
||||
// Strict Mode: Synchronous write
|
||||
if createErr := config.Repo.Create(auditLog); createErr != nil {
|
||||
slog.Error("Failed to write audit log (sync)", "error", createErr, "req_id", reqID)
|
||||
return fiber.NewError(fiber.StatusServiceUnavailable, "Audit logging failed")
|
||||
}
|
||||
} else {
|
||||
// Best Effort: Load Shedding via Buffered Channel
|
||||
select {
|
||||
case auditQueue <- auditLog:
|
||||
// Successfully queued
|
||||
default:
|
||||
// Queue full -> DROP (Load Shedding)
|
||||
slog.Warn("Audit queue full, dropping log (load shedding)", "req_id", reqID, "path", c.Path())
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
117
backend/internal/middleware/audit_middleware_test.go
Normal file
117
backend/internal/middleware/audit_middleware_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// MockAuditRepository is a mock implementation of AuditRepository
|
||||
type MockAuditRepository struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockAuditRepository) Create(log *domain.AuditLog) error {
|
||||
args := m.Called(log)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAuditRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor) ([]domain.AuditLog, error) {
|
||||
args := m.Called(ctx, limit, cursor)
|
||||
return args.Get(0).([]domain.AuditLog), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAuditRepository) Ping(ctx context.Context) error {
|
||||
args := m.Called(ctx)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestAuditMiddleware(t *testing.T) {
|
||||
t.Run("POST request - Sync Success", func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockRepo := new(MockAuditRepository)
|
||||
|
||||
app.Use(AuditMiddleware(AuditConfig{
|
||||
Repo: mockRepo,
|
||||
BodyDump: true,
|
||||
}))
|
||||
|
||||
app.Post("/test", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
mockRepo.On("Create", mock.MatchedBy(func(log *domain.AuditLog) bool {
|
||||
var details map[string]any
|
||||
json.Unmarshal([]byte(log.Details), &details)
|
||||
return log.Status == "success" &&
|
||||
details["method"] == "POST" &&
|
||||
details["request_body"] == `{"password":"*****","user":"test"}`
|
||||
})).Return(nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", strings.NewReader(`{"user": "test", "password": "mypassword"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
mockRepo.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("POST request - Sync Failure (Strict Mode)", func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockRepo := new(MockAuditRepository)
|
||||
|
||||
app.Use(AuditMiddleware(AuditConfig{
|
||||
Repo: mockRepo,
|
||||
}))
|
||||
|
||||
app.Post("/test", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
mockRepo.On("Create", mock.Anything).Return(errors.New("db error"))
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
// Should return 503 because Audit failed on a Write method
|
||||
assert.Equal(t, fiber.StatusServiceUnavailable, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("GET request - Async Load Shedding", func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockRepo := new(MockAuditRepository)
|
||||
|
||||
// Set very small queue and no workers to force load shedding
|
||||
app.Use(AuditMiddleware(AuditConfig{
|
||||
Repo: mockRepo,
|
||||
QueueSize: 1,
|
||||
WorkerCount: 0, // This will be defaulted to 5 by the code, so let's use another way or just small queue
|
||||
}))
|
||||
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
// 1. First request fills the queue
|
||||
mockRepo.On("Create", mock.Anything).Return(nil)
|
||||
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
resp1, _ := app.Test(req1)
|
||||
assert.Equal(t, fiber.StatusOK, resp1.StatusCode)
|
||||
|
||||
// 2. Second request should be dropped (load shedding) if workers are slow
|
||||
// Since we can't easily pause workers without modifying code,
|
||||
// this test mostly ensures the non-blocking send doesn't hang.
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
resp2, _ := app.Test(req2)
|
||||
assert.Equal(t, fiber.StatusOK, resp2.StatusCode)
|
||||
})
|
||||
}
|
||||
@@ -1,117 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditRequiredConfig struct {
|
||||
Repo domain.AuditRepository
|
||||
ExcludePaths map[string]struct{}
|
||||
CommandMethods map[string]struct{}
|
||||
}
|
||||
|
||||
func isNil(i any) bool {
|
||||
if i == nil {
|
||||
return true
|
||||
}
|
||||
v := reflect.ValueOf(i)
|
||||
return v.Kind() == reflect.Ptr && v.IsNil()
|
||||
}
|
||||
|
||||
func RequireAudit(config AuditRequiredConfig) fiber.Handler {
|
||||
commandMethods := config.CommandMethods
|
||||
if len(commandMethods) == 0 {
|
||||
commandMethods = map[string]struct{}{
|
||||
fiber.MethodPost: {},
|
||||
fiber.MethodPut: {},
|
||||
fiber.MethodPatch: {},
|
||||
fiber.MethodDelete: {},
|
||||
}
|
||||
}
|
||||
|
||||
excludePaths := config.ExcludePaths
|
||||
if excludePaths == nil {
|
||||
excludePaths = map[string]struct{}{}
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
if _, ok := commandMethods[c.Method()]; !ok {
|
||||
return c.Next()
|
||||
}
|
||||
if _, excluded := excludePaths[c.Path()]; excluded {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
if isNil(config.Repo) {
|
||||
slog.Warn("audit repository is nil, skipping audit log creation", "path", c.Path())
|
||||
return c.Next() // Don't block the request, just skip audit
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
reqID := c.Get("X-Request-Id")
|
||||
if reqID == "" {
|
||||
reqID = uuid.New().String()
|
||||
c.Set("X-Request-Id", reqID)
|
||||
}
|
||||
|
||||
err := c.Next()
|
||||
latency := time.Since(start)
|
||||
|
||||
status := c.Response().StatusCode()
|
||||
if err != nil {
|
||||
if fiberErr, ok := err.(*fiber.Error); ok {
|
||||
status = fiberErr.Code
|
||||
} else {
|
||||
status = fiber.StatusInternalServerError
|
||||
}
|
||||
}
|
||||
|
||||
statusText := "success"
|
||||
if status >= fiber.StatusBadRequest {
|
||||
statusText = "failure"
|
||||
}
|
||||
|
||||
details := map[string]any{
|
||||
"request_id": reqID,
|
||||
"method": c.Method(),
|
||||
"path": c.Path(),
|
||||
"status": status,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
}
|
||||
if err != nil {
|
||||
details["error"] = err.Error()
|
||||
}
|
||||
|
||||
detailsJSON, jsonErr := json.Marshal(details)
|
||||
if jsonErr != nil {
|
||||
slog.Warn("failed to marshal audit details", "error", jsonErr, "req_id", reqID)
|
||||
}
|
||||
|
||||
auditLog := &domain.AuditLog{
|
||||
EventID: reqID,
|
||||
Timestamp: time.Now(),
|
||||
UserID: "",
|
||||
EventType: fmt.Sprintf("%s %s", c.Method(), c.Path()),
|
||||
Status: statusText,
|
||||
IPAddress: c.IP(),
|
||||
UserAgent: c.Get("User-Agent"),
|
||||
DeviceID: "",
|
||||
Details: string(detailsJSON),
|
||||
}
|
||||
|
||||
if createErr := config.Repo.Create(auditLog); createErr != nil {
|
||||
slog.Error("audit log write failed", "error", createErr, "req_id", reqID, "path", c.Path())
|
||||
return fiber.NewError(fiber.StatusServiceUnavailable, "audit logging unavailable")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -145,6 +145,101 @@ func (d *DescopeProvider) SignIn(loginID, password string) (*domain.AuthInfo, er
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// UserExists는 loginID(이메일/전화번호) 기준으로 사용자가 있는지 확인합니다.
|
||||
func (d *DescopeProvider) UserExists(loginID string) (bool, error) {
|
||||
if d.Client == nil {
|
||||
return false, fmt.Errorf("descope provider: client is nil")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if strings.Contains(loginID, "@") {
|
||||
user, err := d.Client.Management.User().Load(ctx, loginID)
|
||||
if err != nil {
|
||||
if isDescopeNotFound(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return user != nil, nil
|
||||
}
|
||||
|
||||
phone := normalizePhone(loginID)
|
||||
searchOptions := &descope.UserSearchOptions{
|
||||
Phones: []string{phone},
|
||||
Limit: 1,
|
||||
}
|
||||
users, _, err := d.Client.Management.User().SearchAll(ctx, searchOptions)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return len(users) > 0, nil
|
||||
}
|
||||
|
||||
// IssueSession은 비밀번호 없이 로그인 세션을 발급합니다.
|
||||
func (d *DescopeProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
|
||||
if d.Client == nil {
|
||||
return nil, fmt.Errorf("descope provider: client is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
targetLoginID, err := d.resolveLoginID(loginID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
embeddedToken, err := d.Client.Management.User().GenerateEmbeddedLink(ctx, targetLoginID, nil, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("descope provider: generate embedded link failed: %w", err)
|
||||
}
|
||||
|
||||
authInfo, err := d.Client.Auth.MagicLink().Verify(ctx, embeddedToken, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("descope provider: magic link verify failed: %w", err)
|
||||
}
|
||||
|
||||
res := &domain.AuthInfo{
|
||||
SessionToken: &domain.Token{
|
||||
JWT: authInfo.SessionToken.JWT,
|
||||
Expiration: time.Unix(authInfo.SessionToken.Expiration, 0),
|
||||
},
|
||||
Subject: authInfo.User.UserID,
|
||||
}
|
||||
if authInfo.RefreshToken != nil {
|
||||
res.RefreshToken = &domain.Token{
|
||||
JWT: authInfo.RefreshToken.JWT,
|
||||
Expiration: time.Unix(authInfo.RefreshToken.Expiration, 0),
|
||||
}
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (d *DescopeProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
|
||||
func (d *DescopeProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
|
||||
// GetPasswordPolicy는 Descope 비밀번호 정책을 반환합니다.
|
||||
func (d *DescopeProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
|
||||
if d.Client == nil {
|
||||
return nil, fmt.Errorf("descope provider: client is nil")
|
||||
}
|
||||
policy, err := d.Client.Auth.Password().GetPasswordPolicy(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &domain.PasswordPolicy{
|
||||
MinLength: int(policy.MinLength),
|
||||
Lowercase: policy.Lowercase,
|
||||
Uppercase: policy.Uppercase,
|
||||
Number: policy.Number,
|
||||
NonAlphanumeric: policy.NonAlphanumeric,
|
||||
MinCharacterTypes: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *DescopeProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
|
||||
ctx := context.Background()
|
||||
err := d.Client.Auth.Password().SendPasswordReset(ctx, loginID, redirectUrl, nil)
|
||||
@@ -197,3 +292,57 @@ func (d *DescopeProvider) UpdateUserPassword(loginID, newPassword string, r *htt
|
||||
ctx := context.Background()
|
||||
return d.Client.Auth.Password().UpdateUserPassword(ctx, loginID, newPassword, r)
|
||||
}
|
||||
|
||||
func (d *DescopeProvider) resolveLoginID(loginID string) (string, error) {
|
||||
if strings.Contains(loginID, "@") {
|
||||
return loginID, nil
|
||||
}
|
||||
|
||||
phone := normalizePhone(loginID)
|
||||
searchOptions := &descope.UserSearchOptions{
|
||||
Phones: []string{phone},
|
||||
Limit: 1,
|
||||
}
|
||||
users, _, err := d.Client.Management.User().SearchAll(context.Background(), searchOptions)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("descope provider: user search failed: %w", err)
|
||||
}
|
||||
if len(users) == 0 {
|
||||
return "", fmt.Errorf("descope provider: user not found")
|
||||
}
|
||||
if len(users[0].LoginIDs) > 0 {
|
||||
return users[0].LoginIDs[0], nil
|
||||
}
|
||||
if users[0].UserID != "" {
|
||||
return users[0].UserID, nil
|
||||
}
|
||||
return "", fmt.Errorf("descope provider: user found but login id missing")
|
||||
}
|
||||
|
||||
func normalizePhone(phone string) string {
|
||||
normalized := strings.ReplaceAll(phone, "-", "")
|
||||
normalized = strings.ReplaceAll(normalized, " ", "")
|
||||
if strings.HasPrefix(normalized, "010") {
|
||||
return "+82" + normalized[1:]
|
||||
}
|
||||
if strings.HasPrefix(normalized, "82") {
|
||||
return "+" + normalized
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func isDescopeNotFound(err error) bool {
|
||||
if de, ok := err.(*descope.Error); ok {
|
||||
if rawStatus, ok := de.Info[descope.ErrorInfoKeys.HTTPResponseStatusCode]; ok {
|
||||
switch v := rawStatus.(type) {
|
||||
case int:
|
||||
return v == http.StatusNotFound
|
||||
case float64:
|
||||
return int(v) == http.StatusNotFound
|
||||
case string:
|
||||
return v == fmt.Sprintf("%d", http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -63,6 +64,15 @@ func (o *OryProvider) CreateUser(user *domain.BrokerUser, password string) (stri
|
||||
if existingID != "" {
|
||||
return "", fmt.Errorf("ory provider: identity already exists for email=%s", user.Email)
|
||||
}
|
||||
if user.PhoneNumber != "" {
|
||||
existingPhoneID, err := o.findIdentityID(user.PhoneNumber)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ory provider: search identity failed: %w", err)
|
||||
}
|
||||
if existingPhoneID != "" {
|
||||
return "", fmt.Errorf("ory provider: identity already exists for phone=%s", user.PhoneNumber)
|
||||
}
|
||||
}
|
||||
|
||||
traits := map[string]interface{}{
|
||||
"email": user.Email,
|
||||
@@ -84,6 +94,27 @@ func (o *OryProvider) CreateUser(user *domain.BrokerUser, password string) (stri
|
||||
},
|
||||
},
|
||||
}
|
||||
verifiable := []map[string]interface{}{
|
||||
{
|
||||
"value": user.Email,
|
||||
"verified": true,
|
||||
"via": "email",
|
||||
},
|
||||
}
|
||||
if user.PhoneNumber != "" {
|
||||
verifiable = append(verifiable, map[string]interface{}{
|
||||
"value": user.PhoneNumber,
|
||||
"verified": true,
|
||||
"via": "sms",
|
||||
})
|
||||
}
|
||||
payload["verifiable_addresses"] = verifiable
|
||||
payload["recovery_addresses"] = []map[string]interface{}{
|
||||
{
|
||||
"value": user.Email,
|
||||
"via": "email",
|
||||
},
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(payload)
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, fmt.Sprintf("%s/admin/identities", o.KratosAdminURL), bytes.NewReader(body))
|
||||
@@ -119,7 +150,7 @@ func (o *OryProvider) SignIn(loginID, password string) (*domain.AuthInfo, error)
|
||||
return nil, fmt.Errorf("ory provider: loginID and password are required")
|
||||
}
|
||||
|
||||
flowID, err := o.startLoginFlow()
|
||||
flowID, err := o.startLoginFlow("")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -178,6 +209,462 @@ func (o *OryProvider) SignIn(loginID, password string) (*domain.AuthInfo, error)
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UserExists는 Kratos Admin API로 loginID 존재 여부를 확인합니다.
|
||||
func (o *OryProvider) UserExists(loginID string) (bool, error) {
|
||||
if loginID == "" {
|
||||
return false, fmt.Errorf("ory provider: loginID is empty")
|
||||
}
|
||||
identityID, err := o.findIdentityID(loginID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("ory provider: find identity failed: %w", err)
|
||||
}
|
||||
return identityID != "", nil
|
||||
}
|
||||
|
||||
// IssueSession은 Ory에서 별도 세션 발급이 필요할 때 사용합니다. (현재 미지원)
|
||||
func (o *OryProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
|
||||
// InitiateLinkLogin은 Kratos Public API로 링크 로그인 플로우를 시작하고 이메일 전송을 트리거합니다.
|
||||
func (o *OryProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
if loginID == "" {
|
||||
return nil, fmt.Errorf("ory provider: loginID is required")
|
||||
}
|
||||
|
||||
effectiveLoginID, err := o.resolveEffectiveLoginID(loginID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := o.ensureCodeLoginIdentifier(effectiveLoginID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
init, err := o.submitLoginCodeInit(effectiveLoginID, returnTo)
|
||||
if err == nil {
|
||||
init.LoginID = effectiveLoginID
|
||||
return init, nil
|
||||
}
|
||||
|
||||
if shouldBootstrapCodeLogin(err) {
|
||||
if ensureErr := o.ensureCodeLoginIdentifier(effectiveLoginID); ensureErr == nil {
|
||||
init, initErr := o.submitLoginCodeInit(effectiveLoginID, returnTo)
|
||||
if initErr == nil {
|
||||
init.LoginID = effectiveLoginID
|
||||
}
|
||||
return init, initErr
|
||||
} else {
|
||||
slog.Warn("Ory code login bootstrap failed", "loginID", effectiveLoginID, "error", ensureErr)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (o *OryProvider) resolveEffectiveLoginID(loginID string) (string, error) {
|
||||
if strings.Contains(loginID, "@") {
|
||||
return loginID, nil
|
||||
}
|
||||
|
||||
identityID, err := o.findIdentityID(loginID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if identityID == "" {
|
||||
return "", fmt.Errorf("ory provider: identity not found for loginID=%s", loginID)
|
||||
}
|
||||
|
||||
fullIdentity, err := o.fetchIdentityFull(identityID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if fullIdentity != nil {
|
||||
if emailRaw, ok := fullIdentity.Traits["email"]; ok {
|
||||
if email, ok := emailRaw.(string); ok && email != "" {
|
||||
return email, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("ory provider: email trait missing for loginID=%s", loginID)
|
||||
}
|
||||
|
||||
func (o *OryProvider) submitLoginCodeInit(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
flowID, err := o.startLoginFlow(returnTo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"method": "code",
|
||||
"identifier": loginID,
|
||||
})
|
||||
loginURL := fmt.Sprintf("%s/self-service/login?flow=%s", o.KratosPublicURL, flowID)
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, loginURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: build link login request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: link login request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
if resp.StatusCode >= 300 {
|
||||
init, ok := parseKratosLinkLoginResponse(flowID, respBody)
|
||||
if ok {
|
||||
slog.Info("Ory link login initiated with non-2xx response", "loginID", loginID, "flow_id", flowID, "status", resp.StatusCode)
|
||||
return init, nil
|
||||
}
|
||||
return nil, fmt.Errorf("ory provider: link login failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
_ = json.Unmarshal(respBody, &result)
|
||||
|
||||
slog.Info("Ory link login initiated", "loginID", loginID, "flow_id", flowID)
|
||||
|
||||
return &domain.LinkLoginInit{
|
||||
FlowID: flowID,
|
||||
ExpiresAt: result.ExpiresAt,
|
||||
Mode: "link",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseKratosLinkLoginResponse(flowID string, body []byte) (*domain.LinkLoginInit, bool) {
|
||||
if len(body) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
var parsed struct {
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
State string `json:"state"`
|
||||
Active string `json:"active"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
state := strings.ToLower(parsed.State)
|
||||
active := strings.ToLower(parsed.Active)
|
||||
if strings.Contains(state, "sent") || active == "code" {
|
||||
return &domain.LinkLoginInit{
|
||||
FlowID: flowID,
|
||||
ExpiresAt: parsed.ExpiresAt,
|
||||
Mode: "link",
|
||||
}, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func shouldBootstrapCodeLogin(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
return strings.Contains(msg, "has not setup sign in with code") ||
|
||||
strings.Contains(msg, "4000035")
|
||||
}
|
||||
|
||||
type kratosVerifiableAddress struct {
|
||||
Value string `json:"value"`
|
||||
Via string `json:"via"`
|
||||
Verified bool `json:"verified"`
|
||||
Status string `json:"status,omitempty"`
|
||||
}
|
||||
|
||||
func (o *OryProvider) ensureCodeLoginIdentifier(loginID string) error {
|
||||
identityID, err := o.findIdentityID(loginID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ory provider: find identity failed: %w", err)
|
||||
}
|
||||
if identityID == "" {
|
||||
return fmt.Errorf("ory provider: identity not found for loginID=%s", loginID)
|
||||
}
|
||||
|
||||
identity, err := o.fetchIdentity(identityID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
via := "sms"
|
||||
if strings.Contains(loginID, "@") {
|
||||
via = "email"
|
||||
}
|
||||
|
||||
exists := false
|
||||
existingIndex := -1
|
||||
addresses := make([]kratosVerifiableAddress, 0, len(identity.VerifiableAddresses)+1)
|
||||
for idx, addr := range identity.VerifiableAddresses {
|
||||
addresses = append(addresses, kratosVerifiableAddress{
|
||||
Value: addr.Value,
|
||||
Via: addr.Via,
|
||||
Verified: addr.Verified,
|
||||
Status: addr.Status,
|
||||
})
|
||||
if addr.Value == loginID && addr.Via == via {
|
||||
exists = true
|
||||
existingIndex = idx
|
||||
}
|
||||
}
|
||||
ops := make([]map[string]interface{}, 0, 2)
|
||||
if !exists {
|
||||
ops = append(ops, map[string]interface{}{
|
||||
"op": "add",
|
||||
"path": "/verifiable_addresses/-",
|
||||
"value": map[string]interface{}{
|
||||
"value": loginID,
|
||||
"via": via,
|
||||
"verified": true,
|
||||
"status": "completed",
|
||||
},
|
||||
})
|
||||
} else {
|
||||
addr := identity.VerifiableAddresses[existingIndex]
|
||||
if !addr.Verified {
|
||||
ops = append(ops, map[string]interface{}{
|
||||
"op": "replace",
|
||||
"path": fmt.Sprintf("/verifiable_addresses/%d/verified", existingIndex),
|
||||
"value": true,
|
||||
})
|
||||
}
|
||||
if addr.Status != "" && addr.Status != "completed" {
|
||||
ops = append(ops, map[string]interface{}{
|
||||
"op": "replace",
|
||||
"path": fmt.Sprintf("/verifiable_addresses/%d/status", existingIndex),
|
||||
"value": "completed",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(ops) == 0 {
|
||||
slog.Info("Ory identity verifiable address already ready", "identity_id", identityID, "loginID", loginID, "via", via)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := o.patchIdentity(identityID, ops); err != nil {
|
||||
slog.Warn("Ory identity patch failed, trying full update", "identity_id", identityID, "error", err)
|
||||
}
|
||||
|
||||
fullIdentity, err := o.fetchIdentityFull(identityID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addresses = make([]kratosVerifiableAddress, 0, len(fullIdentity.VerifiableAddresses)+1)
|
||||
found := false
|
||||
for _, addr := range fullIdentity.VerifiableAddresses {
|
||||
addresses = append(addresses, kratosVerifiableAddress{
|
||||
Value: addr.Value,
|
||||
Via: addr.Via,
|
||||
Verified: addr.Verified,
|
||||
Status: addr.Status,
|
||||
})
|
||||
if addr.Value == loginID && addr.Via == via {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
addresses = append(addresses, kratosVerifiableAddress{
|
||||
Value: loginID,
|
||||
Via: via,
|
||||
Verified: true,
|
||||
Status: "completed",
|
||||
})
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"schema_id": fullIdentity.SchemaID,
|
||||
"traits": fullIdentity.Traits,
|
||||
"verifiable_addresses": addresses,
|
||||
}
|
||||
if len(fullIdentity.RecoveryAddresses) > 0 {
|
||||
payload["recovery_addresses"] = fullIdentity.RecoveryAddresses
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(payload)
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodPut, fmt.Sprintf("%s/admin/identities/%s", o.KratosAdminURL, identityID), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("ory provider: build identity update failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ory provider: identity update failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return fmt.Errorf("ory provider: identity update failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
slog.Info("Ory identity updated with verifiable address", "identity_id", identityID, "loginID", loginID, "via", via)
|
||||
return nil
|
||||
}
|
||||
|
||||
type kratosIdentity struct {
|
||||
VerifiableAddresses []kratosVerifiableAddress `json:"verifiable_addresses"`
|
||||
}
|
||||
|
||||
type kratosRecoveryAddress struct {
|
||||
Value string `json:"value"`
|
||||
Via string `json:"via"`
|
||||
}
|
||||
|
||||
type kratosIdentityFull struct {
|
||||
SchemaID string `json:"schema_id"`
|
||||
Traits map[string]interface{} `json:"traits"`
|
||||
VerifiableAddresses []kratosVerifiableAddress `json:"verifiable_addresses"`
|
||||
RecoveryAddresses []kratosRecoveryAddress `json:"recovery_addresses"`
|
||||
}
|
||||
|
||||
func (o *OryProvider) patchIdentity(identityID string, ops []map[string]interface{}) error {
|
||||
body, _ := json.Marshal(ops)
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodPatch, fmt.Sprintf("%s/admin/identities/%s", o.KratosAdminURL, identityID), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("ory provider: build identity patch failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json-patch+json")
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ory provider: identity patch failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return fmt.Errorf("ory provider: identity patch failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
slog.Info("Ory identity patched", "identity_id", identityID, "ops", len(ops))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *OryProvider) fetchIdentity(identityID string) (*kratosIdentity, error) {
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, fmt.Sprintf("%s/admin/identities/%s", o.KratosAdminURL, identityID), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: build identity get failed: %w", err)
|
||||
}
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: identity get failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return nil, fmt.Errorf("ory provider: identity get failed status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var identity kratosIdentity
|
||||
if err := json.NewDecoder(resp.Body).Decode(&identity); err != nil {
|
||||
return nil, fmt.Errorf("ory provider: decode identity failed: %w", err)
|
||||
}
|
||||
return &identity, nil
|
||||
}
|
||||
|
||||
func (o *OryProvider) fetchIdentityFull(identityID string) (*kratosIdentityFull, error) {
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, fmt.Sprintf("%s/admin/identities/%s", o.KratosAdminURL, identityID), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: build identity get failed: %w", err)
|
||||
}
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: identity get failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return nil, fmt.Errorf("ory provider: identity get failed status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var identity kratosIdentityFull
|
||||
if err := json.NewDecoder(resp.Body).Decode(&identity); err != nil {
|
||||
return nil, fmt.Errorf("ory provider: decode identity failed: %w", err)
|
||||
}
|
||||
return &identity, nil
|
||||
}
|
||||
|
||||
// VerifyLoginCode는 Kratos 로그인 코드 제출로 세션을 발급합니다.
|
||||
func (o *OryProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
||||
if loginID == "" || flowID == "" || code == "" {
|
||||
return nil, fmt.Errorf("ory provider: loginID, flowID and code are required")
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"method": "code",
|
||||
"identifier": loginID,
|
||||
"code": code,
|
||||
})
|
||||
loginURL := fmt.Sprintf("%s/self-service/login?flow=%s", o.KratosPublicURL, flowID)
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, loginURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: build login code request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: login code request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return nil, fmt.Errorf("ory provider: login code failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
SessionToken string `json:"session_token"`
|
||||
SessionTokenExpiresAt time.Time `json:"session_token_expires_at"`
|
||||
Session struct {
|
||||
Identity struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"identity"`
|
||||
} `json:"session"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("ory provider: decode login code response failed: %w", err)
|
||||
}
|
||||
if result.SessionToken == "" {
|
||||
return nil, fmt.Errorf("ory provider: empty session token returned")
|
||||
}
|
||||
|
||||
slog.Info("Ory login code successful",
|
||||
"identity_id", result.Session.Identity.ID,
|
||||
"loginID", loginID,
|
||||
"expires_at", result.SessionTokenExpiresAt,
|
||||
)
|
||||
|
||||
return &domain.AuthInfo{
|
||||
SessionToken: &domain.Token{
|
||||
JWT: result.SessionToken,
|
||||
Expiration: result.SessionTokenExpiresAt,
|
||||
},
|
||||
Subject: result.Session.Identity.ID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetPasswordPolicy는 Ory 환경에서 사용하는 기본 정책을 반환합니다.
|
||||
func (o *OryProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
|
||||
return &domain.PasswordPolicy{
|
||||
MinLength: 12,
|
||||
Lowercase: true,
|
||||
Uppercase: false,
|
||||
Number: true,
|
||||
NonAlphanumeric: true,
|
||||
MinCharacterTypes: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// InitiatePasswordReset는 현재 내부 토큰/메일 흐름을 사용하고 있으므로 NO-OP로 둡니다.
|
||||
func (o *OryProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
|
||||
slog.Info("Ory InitiatePasswordReset bypassed (handled by app internal flow)", "loginID", loginID, "redirect", redirectUrl)
|
||||
@@ -301,8 +788,12 @@ func (o *OryProvider) httpClient() *http.Client {
|
||||
}
|
||||
|
||||
// startLoginFlow는 Kratos Public API에서 login flow ID를 발급받습니다.
|
||||
func (o *OryProvider) startLoginFlow() (string, error) {
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, fmt.Sprintf("%s/self-service/login/api", o.KratosPublicURL), nil)
|
||||
func (o *OryProvider) startLoginFlow(returnTo string) (string, error) {
|
||||
loginURL := fmt.Sprintf("%s/self-service/login/api", o.KratosPublicURL)
|
||||
if returnTo != "" {
|
||||
loginURL = loginURL + "?return_to=" + url.QueryEscape(returnTo)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, loginURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ory provider: build login flow request failed: %w", err)
|
||||
}
|
||||
|
||||
79
backend/internal/utils/masking.go
Normal file
79
backend/internal/utils/masking.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var sensitiveKeys = map[string]struct{}{
|
||||
"password": {},
|
||||
"newpassword": {},
|
||||
"oldpassword": {},
|
||||
"token": {},
|
||||
"accesstoken": {},
|
||||
"access_token": {},
|
||||
"refreshtoken": {},
|
||||
"refresh_token": {},
|
||||
"secret": {},
|
||||
"clientsecret": {},
|
||||
"client_secret": {},
|
||||
"authorization": {},
|
||||
"cookie": {},
|
||||
"set-cookie": {},
|
||||
"verificationcode": {},
|
||||
"verification_code": {},
|
||||
"code": {}, // Auth code (sensitive)
|
||||
}
|
||||
|
||||
// MaskSensitiveJSON parses a JSON byte slice and masks values of sensitive keys.
|
||||
// Returns the original data if it's not valid JSON.
|
||||
func MaskSensitiveJSON(data []byte) []byte {
|
||||
if len(data) == 0 {
|
||||
return data
|
||||
}
|
||||
|
||||
var obj interface{}
|
||||
if err := json.Unmarshal(data, &obj); err != nil {
|
||||
// Not a JSON object/array, return as is
|
||||
return data
|
||||
}
|
||||
|
||||
masked := maskValue(obj)
|
||||
|
||||
result, err := json.Marshal(masked)
|
||||
if err != nil {
|
||||
return data
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func maskValue(v interface{}) interface{} {
|
||||
switch val := v.(type) {
|
||||
case map[string]interface{}:
|
||||
newMap := make(map[string]interface{}, len(val))
|
||||
for k, v := range val {
|
||||
if isSensitive(k) {
|
||||
newMap[k] = "*****"
|
||||
} else {
|
||||
newMap[k] = maskValue(v)
|
||||
}
|
||||
}
|
||||
return newMap
|
||||
case []interface{}:
|
||||
newArr := make([]interface{}, len(val))
|
||||
for i, v := range val {
|
||||
newArr[i] = maskValue(v)
|
||||
}
|
||||
return newArr
|
||||
default:
|
||||
return val
|
||||
}
|
||||
}
|
||||
|
||||
func isSensitive(key string) bool {
|
||||
// Check case-insensitive
|
||||
// Remove common separators for looser matching? No, stick to lowercase check for now.
|
||||
k := strings.ToLower(key)
|
||||
_, ok := sensitiveKeys[k]
|
||||
return ok
|
||||
}
|
||||
59
backend/internal/utils/masking_test.go
Normal file
59
backend/internal/utils/masking_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMaskSensitiveJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string // We'll check containment or specific structure
|
||||
}{
|
||||
{
|
||||
name: "Flat object with password",
|
||||
input: `{"username": "user", "password": "secret123"}`,
|
||||
expected: `{"password":"*****","username":"user"}`,
|
||||
},
|
||||
{
|
||||
name: "Nested object with token",
|
||||
input: `{"data": {"token": "abc-def", "id": 123}}`,
|
||||
expected: `{"data":{"id":123,"token":"*****"}}`,
|
||||
},
|
||||
{
|
||||
name: "Case insensitive key",
|
||||
input: `{"NewPassword": "changed"}`,
|
||||
expected: `{"NewPassword":"*****"}`,
|
||||
},
|
||||
{
|
||||
name: "Array of objects",
|
||||
input: `[{"secret": "s1"}, {"secret": "s2"}]`,
|
||||
expected: `[{"secret":"*****"},{"secret":"*****"}]`,
|
||||
},
|
||||
{
|
||||
name: "Invalid JSON",
|
||||
input: `not-json`,
|
||||
expected: `not-json`,
|
||||
},
|
||||
{
|
||||
name: "Empty JSON",
|
||||
input: ``,
|
||||
expected: ``,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := MaskSensitiveJSON([]byte(tt.input))
|
||||
// Since JSON map order is undefined, exact string match might fail if keys are reordered.
|
||||
// Ideally we should unmarshal and compare maps, or use assert.JSONEq
|
||||
if tt.name == "Invalid JSON" || tt.name == "Empty JSON" {
|
||||
assert.Equal(t, tt.expected, string(result))
|
||||
} else {
|
||||
assert.JSONEq(t, tt.expected, string(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -29,6 +29,26 @@ func (m *MockProvider) SignIn(loginID, password string) (*domain.AuthInfo, error
|
||||
return &domain.AuthInfo{}, nil
|
||||
}
|
||||
|
||||
func (m *MockProvider) UserExists(loginID string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *MockProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
|
||||
func (m *MockProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
|
||||
func (m *MockProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
|
||||
func (m *MockProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
|
||||
// Stub implementations to satisfy the IdentityProvider interface for this unit test.
|
||||
func (m *MockProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user