forked from baron/baron-sso
레포 업데이트
This commit is contained in:
@@ -116,3 +116,24 @@ func TestNewErrorHandler_MapsUnauthorizedCode(t *testing.T) {
|
||||
t.Fatalf("unexpected error code: %v", body["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldEnableDocs_DisabledInProductionLikeEnv(t *testing.T) {
|
||||
testCases := []struct {
|
||||
appEnv string
|
||||
want bool
|
||||
}{
|
||||
{appEnv: "production", want: false},
|
||||
{appEnv: "prod", want: false},
|
||||
{appEnv: "stage", want: false},
|
||||
{appEnv: "staging", want: false},
|
||||
{appEnv: "dev", want: true},
|
||||
{appEnv: "development", want: true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
got := shouldEnableDocs(tc.appEnv)
|
||||
if got != tc.want {
|
||||
t.Fatalf("appEnv=%s expected shouldEnableDocs=%v, got %v", tc.appEnv, tc.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
419
backend/cmd/server/headless_login_e2e_test.go
Normal file
419
backend/cmd/server/headless_login_e2e_test.go
Normal file
@@ -0,0 +1,419 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
authhandler "baron-sso-backend/internal/handler"
|
||||
"baron-sso-backend/internal/middleware"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
josejwt "github.com/go-jose/go-jose/v4/jwt"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/recover"
|
||||
"github.com/gofiber/fiber/v2/middleware/requestid"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type roundTripFunc func(req *http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
type e2eMockIdentityProvider struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *e2eMockIdentityProvider) Name() string {
|
||||
return "mock-idp"
|
||||
}
|
||||
|
||||
func (m *e2eMockIdentityProvider) GetMetadata() (*domain.IDPMetadata, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *e2eMockIdentityProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *e2eMockIdentityProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
|
||||
args := m.Called(loginID, password)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.AuthInfo), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *e2eMockIdentityProvider) UserExists(loginID string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *e2eMockIdentityProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *e2eMockIdentityProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *e2eMockIdentityProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *e2eMockIdentityProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *e2eMockIdentityProvider) InitiatePasswordReset(loginID, redirectURL string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *e2eMockIdentityProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *e2eMockIdentityProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type e2eMockKratosAdminService struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *e2eMockKratosAdminService) FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error) {
|
||||
args := m.Called(ctx, identifier)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *e2eMockKratosAdminService) GetIdentity(ctx context.Context, id string) (*service.KratosIdentity, error) {
|
||||
args := m.Called(ctx, id)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*service.KratosIdentity), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *e2eMockKratosAdminService) ListIdentities(ctx context.Context) ([]service.KratosIdentity, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *e2eMockKratosAdminService) UpdateIdentity(ctx context.Context, identityID string, traits map[string]interface{}, state string) (*service.KratosIdentity, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *e2eMockKratosAdminService) UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *e2eMockKratosAdminService) DeleteIdentity(ctx context.Context, identityID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newHeadlessLoginE2EApp(h *authhandler.AuthHandler, appEnv string) *fiber.App {
|
||||
app := fiber.New(fiber.Config{
|
||||
DisableStartupMessage: true,
|
||||
ErrorHandler: newErrorHandler(appEnv),
|
||||
})
|
||||
|
||||
app.Use(requestid.New(requestid.Config{
|
||||
Generator: func() string {
|
||||
return "req-e2e-headless"
|
||||
},
|
||||
}))
|
||||
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
start := time.Now()
|
||||
err := c.Next()
|
||||
|
||||
status := c.Response().StatusCode()
|
||||
if status < 400 {
|
||||
return err
|
||||
}
|
||||
|
||||
msg := "http_request"
|
||||
if err != nil {
|
||||
msg = "http_request_error"
|
||||
}
|
||||
|
||||
slog.Info(msg,
|
||||
"status", status,
|
||||
"method", c.Method(),
|
||||
"path", c.Path(),
|
||||
"latency", time.Since(start).String(),
|
||||
"ip", c.IP(),
|
||||
"req_id", c.GetRespHeader(fiber.HeaderXRequestID),
|
||||
)
|
||||
return err
|
||||
})
|
||||
|
||||
app.Use(recover.New(recover.Config{EnableStackTrace: true}))
|
||||
app.Use(middleware.ErrorCodeEnricher())
|
||||
|
||||
api := app.Group("/api/v1")
|
||||
auth := api.Group("/auth")
|
||||
auth.Post("/headless/password/login", h.HeadlessPasswordLogin)
|
||||
|
||||
return app
|
||||
}
|
||||
|
||||
func mustE2EHeadlessRSAJWK(t *testing.T) (*rsa.PrivateKey, map[string]any) {
|
||||
t.Helper()
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate rsa key: %v", err)
|
||||
}
|
||||
|
||||
keySet := jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{
|
||||
{
|
||||
Key: &privateKey.PublicKey,
|
||||
KeyID: "test-kid",
|
||||
Use: "sig",
|
||||
Algorithm: string(jose.RS256),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(keySet)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal jwks: %v", err)
|
||||
}
|
||||
|
||||
var jwks map[string]any
|
||||
if err := json.Unmarshal(raw, &jwks); err != nil {
|
||||
t.Fatalf("failed to decode jwks map: %v", err)
|
||||
}
|
||||
|
||||
return privateKey, jwks
|
||||
}
|
||||
|
||||
func mustE2EHeadlessClientAssertion(t *testing.T, privateKey *rsa.PrivateKey, clientID, audience string) string {
|
||||
t.Helper()
|
||||
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.RS256,
|
||||
Key: jose.JSONWebKey{
|
||||
Key: privateKey,
|
||||
KeyID: "test-kid",
|
||||
Use: "sig",
|
||||
Algorithm: string(jose.RS256),
|
||||
},
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create signer: %v", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
raw, err := josejwt.Signed(signer).Claims(josejwt.Claims{
|
||||
Issuer: clientID,
|
||||
Subject: clientID,
|
||||
Audience: josejwt.Audience{audience},
|
||||
Expiry: josejwt.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
IssuedAt: josejwt.NewNumericDate(now),
|
||||
NotBefore: josejwt.NewNumericDate(now.Add(-1 * time.Minute)),
|
||||
ID: "assertion-e2e",
|
||||
}).Serialize()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to sign client assertion: %v", err)
|
||||
}
|
||||
|
||||
return raw
|
||||
}
|
||||
|
||||
func mockHydraTransportForE2E(handler http.Handler) http.RoundTripper {
|
||||
return roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
return w.Result(), nil
|
||||
})
|
||||
}
|
||||
|
||||
func runHeadlessPasswordLoginE2E(
|
||||
t *testing.T,
|
||||
logger *slog.Logger,
|
||||
appEnv string,
|
||||
jwks map[string]any,
|
||||
clientAssertion string,
|
||||
) (*http.Response, string) {
|
||||
t.Helper()
|
||||
|
||||
logBuffer := &bytes.Buffer{}
|
||||
if logger == nil {
|
||||
logger = slog.New(slog.NewJSONHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelInfo}))
|
||||
}
|
||||
|
||||
previous := slog.Default()
|
||||
slog.SetDefault(logger)
|
||||
t.Cleanup(func() {
|
||||
slog.SetDefault(previous)
|
||||
})
|
||||
|
||||
mockIDP := new(e2eMockIdentityProvider)
|
||||
mockIDP.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{
|
||||
SessionToken: &domain.Token{JWT: "valid-jwt"},
|
||||
Subject: "kratos-identity-id",
|
||||
}, nil)
|
||||
|
||||
mockKratos := new(e2eMockKratosAdminService)
|
||||
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil)
|
||||
|
||||
jwksBody, err := json.Marshal(jwks)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal jwks body: %v", err)
|
||||
}
|
||||
|
||||
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(jwksBody)
|
||||
}))
|
||||
t.Cleanup(jwksServer.Close)
|
||||
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
|
||||
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||
Challenge: "challenge-123",
|
||||
Client: domain.HydraClient{
|
||||
ClientID: "headless-login-client",
|
||||
TokenEndpointAuthMethod: "none",
|
||||
Metadata: map[string]interface{}{
|
||||
"status": "active",
|
||||
"headless_login_enabled": true,
|
||||
"headless_token_endpoint_auth_method": "private_key_jwt",
|
||||
"headless_jwks_uri": jwksServer.URL + "/.well-known/jwks.json",
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
|
||||
return
|
||||
}
|
||||
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
h := &authhandler.AuthHandler{
|
||||
IdpProvider: mockIDP,
|
||||
KratosAdmin: mockKratos,
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: mockHydraTransportForE2E(hydraHandler)},
|
||||
},
|
||||
}
|
||||
|
||||
app := newHeadlessLoginE2EApp(h, appEnv)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"client_id": "headless-login-client",
|
||||
"client_assertion": clientAssertion,
|
||||
"loginId": "employee001",
|
||||
"password": "password",
|
||||
"login_challenge": "challenge-123",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
|
||||
return resp, logBuffer.String()
|
||||
}
|
||||
|
||||
func TestHeadlessPasswordLogin_E2E_ResponseIncludesDetailedCodeAndLogs(t *testing.T) {
|
||||
privateKey, jwks := mustE2EHeadlessRSAJWK(t)
|
||||
clientAssertion := mustE2EHeadlessClientAssertion(
|
||||
t,
|
||||
privateKey,
|
||||
"headless-login-client",
|
||||
"https://rp.example.com/oidc/token",
|
||||
)
|
||||
|
||||
logBuffer := &bytes.Buffer{}
|
||||
logger := slog.New(slog.NewJSONHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelInfo}))
|
||||
|
||||
resp, _ := runHeadlessPasswordLoginE2E(t, logger, "production", jwks, clientAssertion)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("expected 401, got %d, body=%s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||
t.Fatalf("failed to decode response body: %v", err)
|
||||
}
|
||||
|
||||
if got["code"] != "invalid_client_assertion_audience" {
|
||||
t.Fatalf("expected detailed code, got=%v", got["code"])
|
||||
}
|
||||
if got["error"] != "Client assertion audience mismatch" {
|
||||
t.Fatalf("expected detailed error message, got=%v", got["error"])
|
||||
}
|
||||
|
||||
output := logBuffer.String()
|
||||
if !strings.Contains(output, "\"reason_code\":\"invalid_client_assertion_audience\"") {
|
||||
t.Fatalf("expected headless failure log to include detailed reason code, got=%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "\"req_id\":\"req-e2e-headless\"") {
|
||||
t.Fatalf("expected logs to include request id, got=%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "\"path\":\"/api/v1/auth/headless/password/login\"") {
|
||||
t.Fatalf("expected request path in logs, got=%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadlessPasswordLogin_E2E_DebugLogsIncludeDiagnostics(t *testing.T) {
|
||||
privateKey, jwks := mustE2EHeadlessRSAJWK(t)
|
||||
const receivedAudience = "https://sso.hmac.kr/api/v1/auth/headless/password/login"
|
||||
clientAssertion := mustE2EHeadlessClientAssertion(
|
||||
t,
|
||||
privateKey,
|
||||
"headless-login-client",
|
||||
receivedAudience,
|
||||
)
|
||||
|
||||
logBuffer := &bytes.Buffer{}
|
||||
logger := slog.New(slog.NewJSONHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
|
||||
resp, _ := runHeadlessPasswordLoginE2E(t, logger, "production", jwks, clientAssertion)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("expected 401, got %d, body=%s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
output := logBuffer.String()
|
||||
if !strings.Contains(output, "\"expected_audiences\"") {
|
||||
t.Fatalf("expected debug logs to include expected_audiences, got=%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "\"received_audiences\"") {
|
||||
t.Fatalf("expected debug logs to include received_audiences, got=%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "\"received_audiences_text\":\""+receivedAudience+"\"") {
|
||||
t.Fatalf("expected debug logs to include received_audiences_text with full URL, got=%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "\"expected_audiences_text\":\"http://example.com/api/v1/auth/headless/password/login, /api/v1/auth/headless/password/login\"") {
|
||||
t.Fatalf("expected debug logs to include expected_audiences_text, got=%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "\"login_challenge_prefix\":\"challenge-12\"") {
|
||||
t.Fatalf("expected debug logs to include login challenge prefix, got=%s", output)
|
||||
}
|
||||
}
|
||||
@@ -50,6 +50,10 @@ func normalizeDocsPrefix(prefix string) string {
|
||||
return strings.TrimRight(trimmed, "/")
|
||||
}
|
||||
|
||||
func shouldEnableDocs(appEnv string) bool {
|
||||
return !logger.IsProductionLikeEnv(appEnv)
|
||||
}
|
||||
|
||||
func registerDocsRoutes(app *fiber.App, prefix string) {
|
||||
base := normalizeDocsPrefix(prefix)
|
||||
docsPath := base + "/docs"
|
||||
@@ -90,9 +94,11 @@ func main() {
|
||||
}
|
||||
|
||||
// 0. Initialize Logger
|
||||
appEnvForLogger := getEnv("APP_ENV", getEnv("GO_ENV", "dev"))
|
||||
logger.Init(logger.Config{
|
||||
ServiceName: "baron-sso",
|
||||
Environment: getEnv("GO_ENV", "dev"),
|
||||
ServiceName: "baron-sso",
|
||||
Environment: appEnvForLogger,
|
||||
LevelOverride: getEnv("BACKEND_LOG_LEVEL", ""),
|
||||
})
|
||||
// Initialize Snowflake Node (Node 2 for Baron)
|
||||
node, err := snowflake.NewNode(2)
|
||||
@@ -407,7 +413,7 @@ func main() {
|
||||
}))
|
||||
|
||||
// [Security] Disable Swagger/ReDoc in Production
|
||||
if appEnv != "production" {
|
||||
if shouldEnableDocs(appEnv) {
|
||||
docsPrefix := getEnv("DOCS_BASE_PATH", "/api")
|
||||
registerDocsRoutes(app, "")
|
||||
if normalized := normalizeDocsPrefix(docsPrefix); normalized != "" {
|
||||
@@ -415,7 +421,7 @@ func main() {
|
||||
}
|
||||
slog.Info("📚 API Docs enabled", "swagger", "/docs", "redoc", "/redoc", "docs_prefix", docsPrefix)
|
||||
} else {
|
||||
slog.Info("🔒 API Docs disabled in production")
|
||||
slog.Info("🔒 API Docs disabled in production-like environment", "app_env", appEnv)
|
||||
}
|
||||
slog.Info("Client log policy configured",
|
||||
"app_env", appEnv,
|
||||
|
||||
@@ -1770,6 +1770,21 @@ func containsHeadlessAudience(expected []string, actual headlessAssertionAud) bo
|
||||
return false
|
||||
}
|
||||
|
||||
func joinHeadlessAudiences(values []string) string {
|
||||
if len(values) == 0 {
|
||||
return ""
|
||||
}
|
||||
trimmed := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
trimmed = append(trimmed, value)
|
||||
}
|
||||
return strings.Join(trimmed, ", ")
|
||||
}
|
||||
|
||||
func headlessRequestID(c *fiber.Ctx) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
@@ -1894,14 +1909,18 @@ func (h *AuthHandler) loadHeadlessJWKS(ctx context.Context, client domain.HydraC
|
||||
|
||||
func validateHeadlessClientAssertionClaims(c *fiber.Ctx, claims headlessClientAssertionClaims, clientID string) *headlessLoginFailure {
|
||||
now := time.Now().Unix()
|
||||
expectedAudiences := headlessAssertionAudiences(c)
|
||||
receivedAudiences := []string(claims.Audience)
|
||||
debugFields := map[string]any{
|
||||
"claim_issuer": claims.Issuer,
|
||||
"claim_subject": claims.Subject,
|
||||
"claim_expires_at": claims.ExpiresAt,
|
||||
"claim_not_before": claims.NotBefore,
|
||||
"claim_issued_at": claims.IssuedAt,
|
||||
"received_audiences": []string(claims.Audience),
|
||||
"expected_audiences": headlessAssertionAudiences(c),
|
||||
"claim_issuer": claims.Issuer,
|
||||
"claim_subject": claims.Subject,
|
||||
"claim_expires_at": claims.ExpiresAt,
|
||||
"claim_not_before": claims.NotBefore,
|
||||
"claim_issued_at": claims.IssuedAt,
|
||||
"received_audiences": receivedAudiences,
|
||||
"expected_audiences": expectedAudiences,
|
||||
"received_audiences_text": joinHeadlessAudiences(receivedAudiences),
|
||||
"expected_audiences_text": joinHeadlessAudiences(expectedAudiences),
|
||||
}
|
||||
if claims.Issuer != clientID || claims.Subject != clientID {
|
||||
return newHeadlessLoginFailure(
|
||||
@@ -1939,7 +1958,7 @@ func validateHeadlessClientAssertionClaims(c *fiber.Ctx, claims headlessClientAs
|
||||
debugFields,
|
||||
)
|
||||
}
|
||||
if !containsHeadlessAudience(headlessAssertionAudiences(c), claims.Audience) {
|
||||
if !containsHeadlessAudience(expectedAudiences, claims.Audience) {
|
||||
return newHeadlessLoginFailure(
|
||||
fiber.StatusUnauthorized,
|
||||
"invalid_client_assertion_audience",
|
||||
|
||||
@@ -34,8 +34,7 @@ var (
|
||||
)
|
||||
|
||||
func IsProductionEnv(appEnv string) bool {
|
||||
env := strings.ToLower(strings.TrimSpace(appEnv))
|
||||
return env == "prod" || env == "production"
|
||||
return IsProductionLikeEnv(appEnv)
|
||||
}
|
||||
|
||||
func parseBoolFlag(raw string) bool {
|
||||
|
||||
@@ -15,6 +15,8 @@ func TestClientDebugEnabled(t *testing.T) {
|
||||
t.Run("production disables debug by default", func(t *testing.T) {
|
||||
assert.False(t, ClientDebugEnabled("production", ""))
|
||||
assert.False(t, ClientDebugEnabled("prod", "false"))
|
||||
assert.False(t, ClientDebugEnabled("stage", ""))
|
||||
assert.False(t, ClientDebugEnabled("staging", "false"))
|
||||
})
|
||||
|
||||
t.Run("production accepts explicit debug override", func(t *testing.T) {
|
||||
@@ -27,14 +29,19 @@ func TestClientDebugEnabled(t *testing.T) {
|
||||
func TestShouldAcceptClientLog(t *testing.T) {
|
||||
assert.False(t, ShouldAcceptClientLog("production", "", "INFO"))
|
||||
assert.False(t, ShouldAcceptClientLog("production", "", "DEBUG"))
|
||||
assert.False(t, ShouldAcceptClientLog("stage", "", "INFO"))
|
||||
assert.False(t, ShouldAcceptClientLog("stage", "", "DEBUG"))
|
||||
assert.True(t, ShouldAcceptClientLog("production", "", "WARN"))
|
||||
assert.True(t, ShouldAcceptClientLog("production", "", "ERROR"))
|
||||
assert.True(t, ShouldAcceptClientLog("stage", "", "WARN"))
|
||||
assert.True(t, ShouldAcceptClientLog("stage", "", "ERROR"))
|
||||
assert.True(t, ShouldAcceptClientLog("production", "true", "INFO"))
|
||||
assert.True(t, ShouldAcceptClientLog("dev", "", "INFO"))
|
||||
}
|
||||
|
||||
func TestShouldFilterNoisyClientInfo(t *testing.T) {
|
||||
assert.True(t, ShouldFilterNoisyClientInfo("production", "", "Navigating to /ko/signin"))
|
||||
assert.True(t, ShouldFilterNoisyClientInfo("stage", "", "Navigating to /ko/signin"))
|
||||
assert.False(t, ShouldFilterNoisyClientInfo("production", "true", "Navigating to /ko/signin"))
|
||||
assert.False(t, ShouldFilterNoisyClientInfo("dev", "", "Navigating to /ko/signin"))
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -8,18 +9,28 @@ import (
|
||||
|
||||
// Config holds the logger configuration
|
||||
type Config struct {
|
||||
ServiceName string
|
||||
Environment string // "dev", "local", "production"
|
||||
ServiceName string
|
||||
Environment string // APP_ENV 기준
|
||||
LevelOverride string
|
||||
Output io.Writer
|
||||
}
|
||||
|
||||
func IsProductionLikeEnv(appEnv string) bool {
|
||||
env := strings.ToLower(strings.TrimSpace(appEnv))
|
||||
return env == "prod" || env == "production" || env == "stage" || env == "staging"
|
||||
}
|
||||
|
||||
// Init initializes the global logger with slog.
|
||||
// It detects the environment to switch between TextHandler (dev) and JSONHandler (prod).
|
||||
func Init(cfg Config) {
|
||||
var handler slog.Handler
|
||||
output := cfg.Output
|
||||
if output == nil {
|
||||
output = os.Stdout
|
||||
}
|
||||
|
||||
opts := &slog.HandlerOptions{
|
||||
// Default level
|
||||
Level: slog.LevelInfo,
|
||||
Level: ResolveBackendLogLevel(cfg.Environment, cfg.LevelOverride),
|
||||
// Customize attributes (Time format)
|
||||
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
|
||||
if a.Key == slog.TimeKey {
|
||||
@@ -32,11 +43,10 @@ func Init(cfg Config) {
|
||||
// Adjust level and format based on environment
|
||||
env := strings.ToLower(cfg.Environment)
|
||||
if env == "dev" || env == "local" || env == "development" {
|
||||
opts.Level = slog.LevelDebug
|
||||
handler = slog.NewTextHandler(os.Stdout, opts)
|
||||
handler = slog.NewTextHandler(output, opts)
|
||||
} else {
|
||||
// Production defaults to JSON
|
||||
handler = slog.NewJSONHandler(os.Stdout, opts)
|
||||
handler = slog.NewJSONHandler(output, opts)
|
||||
}
|
||||
|
||||
// Create logger with common attributes
|
||||
@@ -47,3 +57,22 @@ func Init(cfg Config) {
|
||||
// Set as global default logger
|
||||
slog.SetDefault(logger)
|
||||
}
|
||||
|
||||
func ResolveBackendLogLevel(appEnv, override string) slog.Level {
|
||||
switch strings.ToLower(strings.TrimSpace(override)) {
|
||||
case "debug":
|
||||
return slog.LevelDebug
|
||||
case "info":
|
||||
return slog.LevelInfo
|
||||
case "warn", "warning":
|
||||
return slog.LevelWarn
|
||||
case "error":
|
||||
return slog.LevelError
|
||||
}
|
||||
|
||||
env := strings.ToLower(strings.TrimSpace(appEnv))
|
||||
if env == "dev" || env == "local" || env == "development" {
|
||||
return slog.LevelDebug
|
||||
}
|
||||
return slog.LevelInfo
|
||||
}
|
||||
|
||||
69
backend/internal/logger/logger_test.go
Normal file
69
backend/internal/logger/logger_test.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResolveBackendLogLevel_DefaultsByAppEnv(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
appEnv string
|
||||
wantLevel slog.Level
|
||||
}{
|
||||
{name: "dev uses debug", appEnv: "dev", wantLevel: slog.LevelDebug},
|
||||
{name: "local uses debug", appEnv: "local", wantLevel: slog.LevelDebug},
|
||||
{name: "development uses debug", appEnv: "development", wantLevel: slog.LevelDebug},
|
||||
{name: "stage uses info", appEnv: "stage", wantLevel: slog.LevelInfo},
|
||||
{name: "production uses info", appEnv: "production", wantLevel: slog.LevelInfo},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := ResolveBackendLogLevel(tc.appEnv, "")
|
||||
if got != tc.wantLevel {
|
||||
t.Fatalf("expected level %v, got %v", tc.wantLevel, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveBackendLogLevel_OverrideWins(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := ResolveBackendLogLevel("production", "debug")
|
||||
if got != slog.LevelDebug {
|
||||
t.Fatalf("expected debug override, got %v", got)
|
||||
}
|
||||
|
||||
got = ResolveBackendLogLevel("dev", "warn")
|
||||
if got != slog.LevelWarn {
|
||||
t.Fatalf("expected warn override, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInit_UsesResolvedBackendLogLevel(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
previous := slog.Default()
|
||||
defer slog.SetDefault(previous)
|
||||
|
||||
Init(Config{
|
||||
ServiceName: "baron-sso",
|
||||
Environment: "stage",
|
||||
LevelOverride: "debug",
|
||||
Output: &buf,
|
||||
})
|
||||
|
||||
slog.Debug("debug message should be visible")
|
||||
|
||||
if !strings.Contains(buf.String(), "debug message should be visible") {
|
||||
t.Fatalf("expected debug log to be written, got=%s", buf.String())
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/logger"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
@@ -10,7 +11,7 @@ func IsProductionEnv() bool {
|
||||
if env == "" {
|
||||
env = strings.ToLower(os.Getenv("GO_ENV"))
|
||||
}
|
||||
return env == "prod" || env == "production"
|
||||
return logger.IsProductionLikeEnv(env)
|
||||
}
|
||||
|
||||
func IsDryRunAllowed() bool {
|
||||
|
||||
43
backend/internal/service/dry_run_service_test.go
Normal file
43
backend/internal/service/dry_run_service_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsProductionEnv_StageIsProductionLike(t *testing.T) {
|
||||
t.Setenv("APP_ENV", "stage")
|
||||
t.Setenv("GO_ENV", "")
|
||||
|
||||
if !IsProductionEnv() {
|
||||
t.Fatalf("expected stage to be treated as production-like")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsDryRunAllowed_DisabledInStage(t *testing.T) {
|
||||
t.Setenv("APP_ENV", "stage")
|
||||
t.Setenv("GO_ENV", "")
|
||||
|
||||
if IsDryRunAllowed() {
|
||||
t.Fatalf("expected dry-run to be disabled in stage")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsProductionEnv_FallsBackToGoEnv(t *testing.T) {
|
||||
originalAppEnv, hadAppEnv := os.LookupEnv("APP_ENV")
|
||||
if hadAppEnv {
|
||||
t.Cleanup(func() {
|
||||
_ = os.Setenv("APP_ENV", originalAppEnv)
|
||||
})
|
||||
} else {
|
||||
t.Cleanup(func() {
|
||||
_ = os.Unsetenv("APP_ENV")
|
||||
})
|
||||
}
|
||||
_ = os.Unsetenv("APP_ENV")
|
||||
t.Setenv("GO_ENV", "production")
|
||||
|
||||
if !IsProductionEnv() {
|
||||
t.Fatalf("expected GO_ENV=production fallback to be production-like")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user