1
0
forked from baron/baron-sso

레포 업데이트

This commit is contained in:
Lectom C Han
2026-04-01 20:32:09 +09:00
parent 8bab8d44cc
commit 4b0fbdde98
31 changed files with 1636 additions and 43 deletions

View File

@@ -116,3 +116,24 @@ func TestNewErrorHandler_MapsUnauthorizedCode(t *testing.T) {
t.Fatalf("unexpected error code: %v", body["code"])
}
}
func TestShouldEnableDocs_DisabledInProductionLikeEnv(t *testing.T) {
testCases := []struct {
appEnv string
want bool
}{
{appEnv: "production", want: false},
{appEnv: "prod", want: false},
{appEnv: "stage", want: false},
{appEnv: "staging", want: false},
{appEnv: "dev", want: true},
{appEnv: "development", want: true},
}
for _, tc := range testCases {
got := shouldEnableDocs(tc.appEnv)
if got != tc.want {
t.Fatalf("appEnv=%s expected shouldEnableDocs=%v, got %v", tc.appEnv, tc.want, got)
}
}
}

View File

@@ -0,0 +1,419 @@
package main
import (
"baron-sso-backend/internal/domain"
authhandler "baron-sso-backend/internal/handler"
"baron-sso-backend/internal/middleware"
"baron-sso-backend/internal/service"
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/go-jose/go-jose/v4"
josejwt "github.com/go-jose/go-jose/v4/jwt"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/recover"
"github.com/gofiber/fiber/v2/middleware/requestid"
"github.com/stretchr/testify/mock"
)
type roundTripFunc func(req *http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
type e2eMockIdentityProvider struct {
mock.Mock
}
func (m *e2eMockIdentityProvider) Name() string {
return "mock-idp"
}
func (m *e2eMockIdentityProvider) GetMetadata() (*domain.IDPMetadata, error) {
return nil, nil
}
func (m *e2eMockIdentityProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
return "", nil
}
func (m *e2eMockIdentityProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
args := m.Called(loginID, password)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.AuthInfo), args.Error(1)
}
func (m *e2eMockIdentityProvider) UserExists(loginID string) (bool, error) {
return true, nil
}
func (m *e2eMockIdentityProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *e2eMockIdentityProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
return nil, nil
}
func (m *e2eMockIdentityProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *e2eMockIdentityProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
return nil, nil
}
func (m *e2eMockIdentityProvider) InitiatePasswordReset(loginID, redirectURL string) error {
return nil
}
func (m *e2eMockIdentityProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *e2eMockIdentityProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
return nil
}
type e2eMockKratosAdminService struct {
mock.Mock
}
func (m *e2eMockKratosAdminService) FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error) {
args := m.Called(ctx, identifier)
return args.String(0), args.Error(1)
}
func (m *e2eMockKratosAdminService) GetIdentity(ctx context.Context, id string) (*service.KratosIdentity, error) {
args := m.Called(ctx, id)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*service.KratosIdentity), args.Error(1)
}
func (m *e2eMockKratosAdminService) ListIdentities(ctx context.Context) ([]service.KratosIdentity, error) {
return nil, nil
}
func (m *e2eMockKratosAdminService) UpdateIdentity(ctx context.Context, identityID string, traits map[string]interface{}, state string) (*service.KratosIdentity, error) {
return nil, nil
}
func (m *e2eMockKratosAdminService) UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error {
return nil
}
func (m *e2eMockKratosAdminService) DeleteIdentity(ctx context.Context, identityID string) error {
return nil
}
func newHeadlessLoginE2EApp(h *authhandler.AuthHandler, appEnv string) *fiber.App {
app := fiber.New(fiber.Config{
DisableStartupMessage: true,
ErrorHandler: newErrorHandler(appEnv),
})
app.Use(requestid.New(requestid.Config{
Generator: func() string {
return "req-e2e-headless"
},
}))
app.Use(func(c *fiber.Ctx) error {
start := time.Now()
err := c.Next()
status := c.Response().StatusCode()
if status < 400 {
return err
}
msg := "http_request"
if err != nil {
msg = "http_request_error"
}
slog.Info(msg,
"status", status,
"method", c.Method(),
"path", c.Path(),
"latency", time.Since(start).String(),
"ip", c.IP(),
"req_id", c.GetRespHeader(fiber.HeaderXRequestID),
)
return err
})
app.Use(recover.New(recover.Config{EnableStackTrace: true}))
app.Use(middleware.ErrorCodeEnricher())
api := app.Group("/api/v1")
auth := api.Group("/auth")
auth.Post("/headless/password/login", h.HeadlessPasswordLogin)
return app
}
func mustE2EHeadlessRSAJWK(t *testing.T) (*rsa.PrivateKey, map[string]any) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("failed to generate rsa key: %v", err)
}
keySet := jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
{
Key: &privateKey.PublicKey,
KeyID: "test-kid",
Use: "sig",
Algorithm: string(jose.RS256),
},
},
}
raw, err := json.Marshal(keySet)
if err != nil {
t.Fatalf("failed to marshal jwks: %v", err)
}
var jwks map[string]any
if err := json.Unmarshal(raw, &jwks); err != nil {
t.Fatalf("failed to decode jwks map: %v", err)
}
return privateKey, jwks
}
func mustE2EHeadlessClientAssertion(t *testing.T, privateKey *rsa.PrivateKey, clientID, audience string) string {
t.Helper()
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.RS256,
Key: jose.JSONWebKey{
Key: privateKey,
KeyID: "test-kid",
Use: "sig",
Algorithm: string(jose.RS256),
},
}, nil)
if err != nil {
t.Fatalf("failed to create signer: %v", err)
}
now := time.Now()
raw, err := josejwt.Signed(signer).Claims(josejwt.Claims{
Issuer: clientID,
Subject: clientID,
Audience: josejwt.Audience{audience},
Expiry: josejwt.NewNumericDate(now.Add(5 * time.Minute)),
IssuedAt: josejwt.NewNumericDate(now),
NotBefore: josejwt.NewNumericDate(now.Add(-1 * time.Minute)),
ID: "assertion-e2e",
}).Serialize()
if err != nil {
t.Fatalf("failed to sign client assertion: %v", err)
}
return raw
}
func mockHydraTransportForE2E(handler http.Handler) http.RoundTripper {
return roundTripFunc(func(req *http.Request) (*http.Response, error) {
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
return w.Result(), nil
})
}
func runHeadlessPasswordLoginE2E(
t *testing.T,
logger *slog.Logger,
appEnv string,
jwks map[string]any,
clientAssertion string,
) (*http.Response, string) {
t.Helper()
logBuffer := &bytes.Buffer{}
if logger == nil {
logger = slog.New(slog.NewJSONHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelInfo}))
}
previous := slog.Default()
slog.SetDefault(logger)
t.Cleanup(func() {
slog.SetDefault(previous)
})
mockIDP := new(e2eMockIdentityProvider)
mockIDP.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{
SessionToken: &domain.Token{JWT: "valid-jwt"},
Subject: "kratos-identity-id",
}, nil)
mockKratos := new(e2eMockKratosAdminService)
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil)
jwksBody, err := json.Marshal(jwks)
if err != nil {
t.Fatalf("failed to marshal jwks body: %v", err)
}
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(jwksBody)
}))
t.Cleanup(jwksServer.Close)
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "headless-login-client",
TokenEndpointAuthMethod: "none",
Metadata: map[string]interface{}{
"status": "active",
"headless_login_enabled": true,
"headless_token_endpoint_auth_method": "private_key_jwt",
"headless_jwks_uri": jwksServer.URL + "/.well-known/jwks.json",
},
},
})
return
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
return
}
http.NotFound(w, r)
})
h := &authhandler.AuthHandler{
IdpProvider: mockIDP,
KratosAdmin: mockKratos,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransportForE2E(hydraHandler)},
},
}
app := newHeadlessLoginE2EApp(h, appEnv)
body, _ := json.Marshal(map[string]string{
"client_id": "headless-login-client",
"client_assertion": clientAssertion,
"loginId": "employee001",
"password": "password",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
return resp, logBuffer.String()
}
func TestHeadlessPasswordLogin_E2E_ResponseIncludesDetailedCodeAndLogs(t *testing.T) {
privateKey, jwks := mustE2EHeadlessRSAJWK(t)
clientAssertion := mustE2EHeadlessClientAssertion(
t,
privateKey,
"headless-login-client",
"https://rp.example.com/oidc/token",
)
logBuffer := &bytes.Buffer{}
logger := slog.New(slog.NewJSONHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelInfo}))
resp, _ := runHeadlessPasswordLoginE2E(t, logger, "production", jwks, clientAssertion)
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 401, got %d, body=%s", resp.StatusCode, string(bodyBytes))
}
var got map[string]any
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
t.Fatalf("failed to decode response body: %v", err)
}
if got["code"] != "invalid_client_assertion_audience" {
t.Fatalf("expected detailed code, got=%v", got["code"])
}
if got["error"] != "Client assertion audience mismatch" {
t.Fatalf("expected detailed error message, got=%v", got["error"])
}
output := logBuffer.String()
if !strings.Contains(output, "\"reason_code\":\"invalid_client_assertion_audience\"") {
t.Fatalf("expected headless failure log to include detailed reason code, got=%s", output)
}
if !strings.Contains(output, "\"req_id\":\"req-e2e-headless\"") {
t.Fatalf("expected logs to include request id, got=%s", output)
}
if !strings.Contains(output, "\"path\":\"/api/v1/auth/headless/password/login\"") {
t.Fatalf("expected request path in logs, got=%s", output)
}
}
func TestHeadlessPasswordLogin_E2E_DebugLogsIncludeDiagnostics(t *testing.T) {
privateKey, jwks := mustE2EHeadlessRSAJWK(t)
const receivedAudience = "https://sso.hmac.kr/api/v1/auth/headless/password/login"
clientAssertion := mustE2EHeadlessClientAssertion(
t,
privateKey,
"headless-login-client",
receivedAudience,
)
logBuffer := &bytes.Buffer{}
logger := slog.New(slog.NewJSONHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
resp, _ := runHeadlessPasswordLoginE2E(t, logger, "production", jwks, clientAssertion)
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 401, got %d, body=%s", resp.StatusCode, string(bodyBytes))
}
output := logBuffer.String()
if !strings.Contains(output, "\"expected_audiences\"") {
t.Fatalf("expected debug logs to include expected_audiences, got=%s", output)
}
if !strings.Contains(output, "\"received_audiences\"") {
t.Fatalf("expected debug logs to include received_audiences, got=%s", output)
}
if !strings.Contains(output, "\"received_audiences_text\":\""+receivedAudience+"\"") {
t.Fatalf("expected debug logs to include received_audiences_text with full URL, got=%s", output)
}
if !strings.Contains(output, "\"expected_audiences_text\":\"http://example.com/api/v1/auth/headless/password/login, /api/v1/auth/headless/password/login\"") {
t.Fatalf("expected debug logs to include expected_audiences_text, got=%s", output)
}
if !strings.Contains(output, "\"login_challenge_prefix\":\"challenge-12\"") {
t.Fatalf("expected debug logs to include login challenge prefix, got=%s", output)
}
}

View File

@@ -50,6 +50,10 @@ func normalizeDocsPrefix(prefix string) string {
return strings.TrimRight(trimmed, "/")
}
func shouldEnableDocs(appEnv string) bool {
return !logger.IsProductionLikeEnv(appEnv)
}
func registerDocsRoutes(app *fiber.App, prefix string) {
base := normalizeDocsPrefix(prefix)
docsPath := base + "/docs"
@@ -90,9 +94,11 @@ func main() {
}
// 0. Initialize Logger
appEnvForLogger := getEnv("APP_ENV", getEnv("GO_ENV", "dev"))
logger.Init(logger.Config{
ServiceName: "baron-sso",
Environment: getEnv("GO_ENV", "dev"),
ServiceName: "baron-sso",
Environment: appEnvForLogger,
LevelOverride: getEnv("BACKEND_LOG_LEVEL", ""),
})
// Initialize Snowflake Node (Node 2 for Baron)
node, err := snowflake.NewNode(2)
@@ -407,7 +413,7 @@ func main() {
}))
// [Security] Disable Swagger/ReDoc in Production
if appEnv != "production" {
if shouldEnableDocs(appEnv) {
docsPrefix := getEnv("DOCS_BASE_PATH", "/api")
registerDocsRoutes(app, "")
if normalized := normalizeDocsPrefix(docsPrefix); normalized != "" {
@@ -415,7 +421,7 @@ func main() {
}
slog.Info("📚 API Docs enabled", "swagger", "/docs", "redoc", "/redoc", "docs_prefix", docsPrefix)
} else {
slog.Info("🔒 API Docs disabled in production")
slog.Info("🔒 API Docs disabled in production-like environment", "app_env", appEnv)
}
slog.Info("Client log policy configured",
"app_env", appEnv,

View File

@@ -1770,6 +1770,21 @@ func containsHeadlessAudience(expected []string, actual headlessAssertionAud) bo
return false
}
func joinHeadlessAudiences(values []string) string {
if len(values) == 0 {
return ""
}
trimmed := make([]string, 0, len(values))
for _, value := range values {
value = strings.TrimSpace(value)
if value == "" {
continue
}
trimmed = append(trimmed, value)
}
return strings.Join(trimmed, ", ")
}
func headlessRequestID(c *fiber.Ctx) string {
if c == nil {
return ""
@@ -1894,14 +1909,18 @@ func (h *AuthHandler) loadHeadlessJWKS(ctx context.Context, client domain.HydraC
func validateHeadlessClientAssertionClaims(c *fiber.Ctx, claims headlessClientAssertionClaims, clientID string) *headlessLoginFailure {
now := time.Now().Unix()
expectedAudiences := headlessAssertionAudiences(c)
receivedAudiences := []string(claims.Audience)
debugFields := map[string]any{
"claim_issuer": claims.Issuer,
"claim_subject": claims.Subject,
"claim_expires_at": claims.ExpiresAt,
"claim_not_before": claims.NotBefore,
"claim_issued_at": claims.IssuedAt,
"received_audiences": []string(claims.Audience),
"expected_audiences": headlessAssertionAudiences(c),
"claim_issuer": claims.Issuer,
"claim_subject": claims.Subject,
"claim_expires_at": claims.ExpiresAt,
"claim_not_before": claims.NotBefore,
"claim_issued_at": claims.IssuedAt,
"received_audiences": receivedAudiences,
"expected_audiences": expectedAudiences,
"received_audiences_text": joinHeadlessAudiences(receivedAudiences),
"expected_audiences_text": joinHeadlessAudiences(expectedAudiences),
}
if claims.Issuer != clientID || claims.Subject != clientID {
return newHeadlessLoginFailure(
@@ -1939,7 +1958,7 @@ func validateHeadlessClientAssertionClaims(c *fiber.Ctx, claims headlessClientAs
debugFields,
)
}
if !containsHeadlessAudience(headlessAssertionAudiences(c), claims.Audience) {
if !containsHeadlessAudience(expectedAudiences, claims.Audience) {
return newHeadlessLoginFailure(
fiber.StatusUnauthorized,
"invalid_client_assertion_audience",

View File

@@ -34,8 +34,7 @@ var (
)
func IsProductionEnv(appEnv string) bool {
env := strings.ToLower(strings.TrimSpace(appEnv))
return env == "prod" || env == "production"
return IsProductionLikeEnv(appEnv)
}
func parseBoolFlag(raw string) bool {

View File

@@ -15,6 +15,8 @@ func TestClientDebugEnabled(t *testing.T) {
t.Run("production disables debug by default", func(t *testing.T) {
assert.False(t, ClientDebugEnabled("production", ""))
assert.False(t, ClientDebugEnabled("prod", "false"))
assert.False(t, ClientDebugEnabled("stage", ""))
assert.False(t, ClientDebugEnabled("staging", "false"))
})
t.Run("production accepts explicit debug override", func(t *testing.T) {
@@ -27,14 +29,19 @@ func TestClientDebugEnabled(t *testing.T) {
func TestShouldAcceptClientLog(t *testing.T) {
assert.False(t, ShouldAcceptClientLog("production", "", "INFO"))
assert.False(t, ShouldAcceptClientLog("production", "", "DEBUG"))
assert.False(t, ShouldAcceptClientLog("stage", "", "INFO"))
assert.False(t, ShouldAcceptClientLog("stage", "", "DEBUG"))
assert.True(t, ShouldAcceptClientLog("production", "", "WARN"))
assert.True(t, ShouldAcceptClientLog("production", "", "ERROR"))
assert.True(t, ShouldAcceptClientLog("stage", "", "WARN"))
assert.True(t, ShouldAcceptClientLog("stage", "", "ERROR"))
assert.True(t, ShouldAcceptClientLog("production", "true", "INFO"))
assert.True(t, ShouldAcceptClientLog("dev", "", "INFO"))
}
func TestShouldFilterNoisyClientInfo(t *testing.T) {
assert.True(t, ShouldFilterNoisyClientInfo("production", "", "Navigating to /ko/signin"))
assert.True(t, ShouldFilterNoisyClientInfo("stage", "", "Navigating to /ko/signin"))
assert.False(t, ShouldFilterNoisyClientInfo("production", "true", "Navigating to /ko/signin"))
assert.False(t, ShouldFilterNoisyClientInfo("dev", "", "Navigating to /ko/signin"))
}

View File

@@ -1,6 +1,7 @@
package logger
import (
"io"
"log/slog"
"os"
"strings"
@@ -8,18 +9,28 @@ import (
// Config holds the logger configuration
type Config struct {
ServiceName string
Environment string // "dev", "local", "production"
ServiceName string
Environment string // APP_ENV 기준
LevelOverride string
Output io.Writer
}
func IsProductionLikeEnv(appEnv string) bool {
env := strings.ToLower(strings.TrimSpace(appEnv))
return env == "prod" || env == "production" || env == "stage" || env == "staging"
}
// Init initializes the global logger with slog.
// It detects the environment to switch between TextHandler (dev) and JSONHandler (prod).
func Init(cfg Config) {
var handler slog.Handler
output := cfg.Output
if output == nil {
output = os.Stdout
}
opts := &slog.HandlerOptions{
// Default level
Level: slog.LevelInfo,
Level: ResolveBackendLogLevel(cfg.Environment, cfg.LevelOverride),
// Customize attributes (Time format)
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
if a.Key == slog.TimeKey {
@@ -32,11 +43,10 @@ func Init(cfg Config) {
// Adjust level and format based on environment
env := strings.ToLower(cfg.Environment)
if env == "dev" || env == "local" || env == "development" {
opts.Level = slog.LevelDebug
handler = slog.NewTextHandler(os.Stdout, opts)
handler = slog.NewTextHandler(output, opts)
} else {
// Production defaults to JSON
handler = slog.NewJSONHandler(os.Stdout, opts)
handler = slog.NewJSONHandler(output, opts)
}
// Create logger with common attributes
@@ -47,3 +57,22 @@ func Init(cfg Config) {
// Set as global default logger
slog.SetDefault(logger)
}
func ResolveBackendLogLevel(appEnv, override string) slog.Level {
switch strings.ToLower(strings.TrimSpace(override)) {
case "debug":
return slog.LevelDebug
case "info":
return slog.LevelInfo
case "warn", "warning":
return slog.LevelWarn
case "error":
return slog.LevelError
}
env := strings.ToLower(strings.TrimSpace(appEnv))
if env == "dev" || env == "local" || env == "development" {
return slog.LevelDebug
}
return slog.LevelInfo
}

View File

@@ -0,0 +1,69 @@
package logger
import (
"bytes"
"log/slog"
"strings"
"testing"
)
func TestResolveBackendLogLevel_DefaultsByAppEnv(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
appEnv string
wantLevel slog.Level
}{
{name: "dev uses debug", appEnv: "dev", wantLevel: slog.LevelDebug},
{name: "local uses debug", appEnv: "local", wantLevel: slog.LevelDebug},
{name: "development uses debug", appEnv: "development", wantLevel: slog.LevelDebug},
{name: "stage uses info", appEnv: "stage", wantLevel: slog.LevelInfo},
{name: "production uses info", appEnv: "production", wantLevel: slog.LevelInfo},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := ResolveBackendLogLevel(tc.appEnv, "")
if got != tc.wantLevel {
t.Fatalf("expected level %v, got %v", tc.wantLevel, got)
}
})
}
}
func TestResolveBackendLogLevel_OverrideWins(t *testing.T) {
t.Parallel()
got := ResolveBackendLogLevel("production", "debug")
if got != slog.LevelDebug {
t.Fatalf("expected debug override, got %v", got)
}
got = ResolveBackendLogLevel("dev", "warn")
if got != slog.LevelWarn {
t.Fatalf("expected warn override, got %v", got)
}
}
func TestInit_UsesResolvedBackendLogLevel(t *testing.T) {
var buf bytes.Buffer
previous := slog.Default()
defer slog.SetDefault(previous)
Init(Config{
ServiceName: "baron-sso",
Environment: "stage",
LevelOverride: "debug",
Output: &buf,
})
slog.Debug("debug message should be visible")
if !strings.Contains(buf.String(), "debug message should be visible") {
t.Fatalf("expected debug log to be written, got=%s", buf.String())
}
}

View File

@@ -1,6 +1,7 @@
package service
import (
"baron-sso-backend/internal/logger"
"os"
"strings"
)
@@ -10,7 +11,7 @@ func IsProductionEnv() bool {
if env == "" {
env = strings.ToLower(os.Getenv("GO_ENV"))
}
return env == "prod" || env == "production"
return logger.IsProductionLikeEnv(env)
}
func IsDryRunAllowed() bool {

View File

@@ -0,0 +1,43 @@
package service
import (
"os"
"testing"
)
func TestIsProductionEnv_StageIsProductionLike(t *testing.T) {
t.Setenv("APP_ENV", "stage")
t.Setenv("GO_ENV", "")
if !IsProductionEnv() {
t.Fatalf("expected stage to be treated as production-like")
}
}
func TestIsDryRunAllowed_DisabledInStage(t *testing.T) {
t.Setenv("APP_ENV", "stage")
t.Setenv("GO_ENV", "")
if IsDryRunAllowed() {
t.Fatalf("expected dry-run to be disabled in stage")
}
}
func TestIsProductionEnv_FallsBackToGoEnv(t *testing.T) {
originalAppEnv, hadAppEnv := os.LookupEnv("APP_ENV")
if hadAppEnv {
t.Cleanup(func() {
_ = os.Setenv("APP_ENV", originalAppEnv)
})
} else {
t.Cleanup(func() {
_ = os.Unsetenv("APP_ENV")
})
}
_ = os.Unsetenv("APP_ENV")
t.Setenv("GO_ENV", "production")
if !IsProductionEnv() {
t.Fatalf("expected GO_ENV=production fallback to be production-like")
}
}