forked from baron/baron-sso
420 lines
12 KiB
Go
420 lines
12 KiB
Go
package main
|
|
|
|
import (
|
|
"baron-sso-backend/internal/domain"
|
|
authhandler "baron-sso-backend/internal/handler"
|
|
"baron-sso-backend/internal/middleware"
|
|
"baron-sso-backend/internal/service"
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"encoding/json"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-jose/go-jose/v4"
|
|
josejwt "github.com/go-jose/go-jose/v4/jwt"
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/gofiber/fiber/v2/middleware/recover"
|
|
"github.com/gofiber/fiber/v2/middleware/requestid"
|
|
"github.com/stretchr/testify/mock"
|
|
)
|
|
|
|
// roundTripFunc adapts an ordinary function to the http.RoundTripper
// interface so tests can stub outbound HTTP calls without a network.
type roundTripFunc func(req *http.Request) (*http.Response, error)

// RoundTrip satisfies http.RoundTripper by delegating to the wrapped function.
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	fn := (func(*http.Request) (*http.Response, error))(f)
	return fn(r)
}
|
|
|
|
type e2eMockIdentityProvider struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *e2eMockIdentityProvider) Name() string {
|
|
return "mock-idp"
|
|
}
|
|
|
|
func (m *e2eMockIdentityProvider) GetMetadata() (*domain.IDPMetadata, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *e2eMockIdentityProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
|
|
return "", nil
|
|
}
|
|
|
|
func (m *e2eMockIdentityProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
|
|
args := m.Called(loginID, password)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*domain.AuthInfo), args.Error(1)
|
|
}
|
|
|
|
func (m *e2eMockIdentityProvider) UserExists(loginID string) (bool, error) {
|
|
return true, nil
|
|
}
|
|
|
|
func (m *e2eMockIdentityProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *e2eMockIdentityProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *e2eMockIdentityProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *e2eMockIdentityProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *e2eMockIdentityProvider) InitiatePasswordReset(loginID, redirectURL string) error {
|
|
return nil
|
|
}
|
|
|
|
func (m *e2eMockIdentityProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *e2eMockIdentityProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
|
|
return nil
|
|
}
|
|
|
|
type e2eMockKratosAdminService struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *e2eMockKratosAdminService) FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error) {
|
|
args := m.Called(ctx, identifier)
|
|
return args.String(0), args.Error(1)
|
|
}
|
|
|
|
func (m *e2eMockKratosAdminService) GetIdentity(ctx context.Context, id string) (*service.KratosIdentity, error) {
|
|
args := m.Called(ctx, id)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*service.KratosIdentity), args.Error(1)
|
|
}
|
|
|
|
func (m *e2eMockKratosAdminService) ListIdentities(ctx context.Context) ([]service.KratosIdentity, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *e2eMockKratosAdminService) UpdateIdentity(ctx context.Context, identityID string, traits map[string]interface{}, state string) (*service.KratosIdentity, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *e2eMockKratosAdminService) UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error {
|
|
return nil
|
|
}
|
|
|
|
func (m *e2eMockKratosAdminService) DeleteIdentity(ctx context.Context, identityID string) error {
|
|
return nil
|
|
}
|
|
|
|
func newHeadlessLoginE2EApp(h *authhandler.AuthHandler, appEnv string) *fiber.App {
|
|
app := fiber.New(fiber.Config{
|
|
DisableStartupMessage: true,
|
|
ErrorHandler: newErrorHandler(appEnv),
|
|
})
|
|
|
|
app.Use(requestid.New(requestid.Config{
|
|
Generator: func() string {
|
|
return "req-e2e-headless"
|
|
},
|
|
}))
|
|
|
|
app.Use(func(c *fiber.Ctx) error {
|
|
start := time.Now()
|
|
err := c.Next()
|
|
|
|
status := c.Response().StatusCode()
|
|
if status < 400 {
|
|
return err
|
|
}
|
|
|
|
msg := "http_request"
|
|
if err != nil {
|
|
msg = "http_request_error"
|
|
}
|
|
|
|
slog.Info(msg,
|
|
"status", status,
|
|
"method", c.Method(),
|
|
"path", c.Path(),
|
|
"latency", time.Since(start).String(),
|
|
"ip", c.IP(),
|
|
"req_id", c.GetRespHeader(fiber.HeaderXRequestID),
|
|
)
|
|
return err
|
|
})
|
|
|
|
app.Use(recover.New(recover.Config{EnableStackTrace: true}))
|
|
app.Use(middleware.ErrorCodeEnricher())
|
|
|
|
api := app.Group("/api/v1")
|
|
auth := api.Group("/auth")
|
|
auth.Post("/headless/password/login", h.HeadlessPasswordLogin)
|
|
|
|
return app
|
|
}
|
|
|
|
func mustE2EHeadlessRSAJWK(t *testing.T) (*rsa.PrivateKey, map[string]any) {
|
|
t.Helper()
|
|
|
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("failed to generate rsa key: %v", err)
|
|
}
|
|
|
|
keySet := jose.JSONWebKeySet{
|
|
Keys: []jose.JSONWebKey{
|
|
{
|
|
Key: &privateKey.PublicKey,
|
|
KeyID: "test-kid",
|
|
Use: "sig",
|
|
Algorithm: string(jose.RS256),
|
|
},
|
|
},
|
|
}
|
|
|
|
raw, err := json.Marshal(keySet)
|
|
if err != nil {
|
|
t.Fatalf("failed to marshal jwks: %v", err)
|
|
}
|
|
|
|
var jwks map[string]any
|
|
if err := json.Unmarshal(raw, &jwks); err != nil {
|
|
t.Fatalf("failed to decode jwks map: %v", err)
|
|
}
|
|
|
|
return privateKey, jwks
|
|
}
|
|
|
|
func mustE2EHeadlessClientAssertion(t *testing.T, privateKey *rsa.PrivateKey, clientID, audience string) string {
|
|
t.Helper()
|
|
|
|
signer, err := jose.NewSigner(jose.SigningKey{
|
|
Algorithm: jose.RS256,
|
|
Key: jose.JSONWebKey{
|
|
Key: privateKey,
|
|
KeyID: "test-kid",
|
|
Use: "sig",
|
|
Algorithm: string(jose.RS256),
|
|
},
|
|
}, nil)
|
|
if err != nil {
|
|
t.Fatalf("failed to create signer: %v", err)
|
|
}
|
|
|
|
now := time.Now()
|
|
raw, err := josejwt.Signed(signer).Claims(josejwt.Claims{
|
|
Issuer: clientID,
|
|
Subject: clientID,
|
|
Audience: josejwt.Audience{audience},
|
|
Expiry: josejwt.NewNumericDate(now.Add(5 * time.Minute)),
|
|
IssuedAt: josejwt.NewNumericDate(now),
|
|
NotBefore: josejwt.NewNumericDate(now.Add(-1 * time.Minute)),
|
|
ID: "assertion-e2e",
|
|
}).Serialize()
|
|
if err != nil {
|
|
t.Fatalf("failed to sign client assertion: %v", err)
|
|
}
|
|
|
|
return raw
|
|
}
|
|
|
|
func mockHydraTransportForE2E(handler http.Handler) http.RoundTripper {
|
|
return roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
|
w := httptest.NewRecorder()
|
|
handler.ServeHTTP(w, req)
|
|
return w.Result(), nil
|
|
})
|
|
}
|
|
|
|
func runHeadlessPasswordLoginE2E(
|
|
t *testing.T,
|
|
logger *slog.Logger,
|
|
appEnv string,
|
|
jwks map[string]any,
|
|
clientAssertion string,
|
|
) (*http.Response, string) {
|
|
t.Helper()
|
|
|
|
logBuffer := &bytes.Buffer{}
|
|
if logger == nil {
|
|
logger = slog.New(slog.NewJSONHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelInfo}))
|
|
}
|
|
|
|
previous := slog.Default()
|
|
slog.SetDefault(logger)
|
|
t.Cleanup(func() {
|
|
slog.SetDefault(previous)
|
|
})
|
|
|
|
mockIDP := new(e2eMockIdentityProvider)
|
|
mockIDP.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{
|
|
SessionToken: &domain.Token{JWT: "valid-jwt"},
|
|
Subject: "kratos-identity-id",
|
|
}, nil)
|
|
|
|
mockKratos := new(e2eMockKratosAdminService)
|
|
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil)
|
|
|
|
jwksBody, err := json.Marshal(jwks)
|
|
if err != nil {
|
|
t.Fatalf("failed to marshal jwks body: %v", err)
|
|
}
|
|
|
|
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write(jwksBody)
|
|
}))
|
|
t.Cleanup(jwksServer.Close)
|
|
|
|
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
|
|
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
|
Challenge: "challenge-123",
|
|
Client: domain.HydraClient{
|
|
ClientID: "headless-login-client",
|
|
TokenEndpointAuthMethod: "none",
|
|
Metadata: map[string]interface{}{
|
|
"status": "active",
|
|
"headless_login_enabled": true,
|
|
"headless_token_endpoint_auth_method": "private_key_jwt",
|
|
"headless_jwks_uri": jwksServer.URL + "/.well-known/jwks.json",
|
|
},
|
|
},
|
|
})
|
|
return
|
|
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
|
|
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
|
|
return
|
|
}
|
|
|
|
http.NotFound(w, r)
|
|
})
|
|
|
|
h := &authhandler.AuthHandler{
|
|
IdpProvider: mockIDP,
|
|
KratosAdmin: mockKratos,
|
|
Hydra: &service.HydraAdminService{
|
|
AdminURL: "http://hydra.test",
|
|
HTTPClient: &http.Client{Transport: mockHydraTransportForE2E(hydraHandler)},
|
|
},
|
|
}
|
|
|
|
app := newHeadlessLoginE2EApp(h, appEnv)
|
|
|
|
body, _ := json.Marshal(map[string]string{
|
|
"client_id": "headless-login-client",
|
|
"client_assertion": clientAssertion,
|
|
"loginId": "employee001",
|
|
"password": "password",
|
|
"login_challenge": "challenge-123",
|
|
})
|
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("request failed: %v", err)
|
|
}
|
|
|
|
return resp, logBuffer.String()
|
|
}
|
|
|
|
func TestHeadlessPasswordLogin_E2E_ResponseIncludesDetailedCodeAndLogs(t *testing.T) {
|
|
privateKey, jwks := mustE2EHeadlessRSAJWK(t)
|
|
clientAssertion := mustE2EHeadlessClientAssertion(
|
|
t,
|
|
privateKey,
|
|
"headless-login-client",
|
|
"https://rp.example.com/oidc/token",
|
|
)
|
|
|
|
logBuffer := &bytes.Buffer{}
|
|
logger := slog.New(slog.NewJSONHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelInfo}))
|
|
|
|
resp, _ := runHeadlessPasswordLoginE2E(t, logger, "production", jwks, clientAssertion)
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusUnauthorized {
|
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
|
t.Fatalf("expected 401, got %d, body=%s", resp.StatusCode, string(bodyBytes))
|
|
}
|
|
|
|
var got map[string]any
|
|
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
|
t.Fatalf("failed to decode response body: %v", err)
|
|
}
|
|
|
|
if got["code"] != "invalid_client_assertion_audience" {
|
|
t.Fatalf("expected detailed code, got=%v", got["code"])
|
|
}
|
|
if got["error"] != "Client assertion audience mismatch" {
|
|
t.Fatalf("expected detailed error message, got=%v", got["error"])
|
|
}
|
|
|
|
output := logBuffer.String()
|
|
if !strings.Contains(output, "\"reason_code\":\"invalid_client_assertion_audience\"") {
|
|
t.Fatalf("expected headless failure log to include detailed reason code, got=%s", output)
|
|
}
|
|
if !strings.Contains(output, "\"req_id\":\"req-e2e-headless\"") {
|
|
t.Fatalf("expected logs to include request id, got=%s", output)
|
|
}
|
|
if !strings.Contains(output, "\"path\":\"/api/v1/auth/headless/password/login\"") {
|
|
t.Fatalf("expected request path in logs, got=%s", output)
|
|
}
|
|
}
|
|
|
|
func TestHeadlessPasswordLogin_E2E_DebugLogsIncludeDiagnostics(t *testing.T) {
|
|
privateKey, jwks := mustE2EHeadlessRSAJWK(t)
|
|
const receivedAudience = "https://sso.hmac.kr/api/v1/auth/headless/password/login"
|
|
clientAssertion := mustE2EHeadlessClientAssertion(
|
|
t,
|
|
privateKey,
|
|
"headless-login-client",
|
|
receivedAudience,
|
|
)
|
|
|
|
logBuffer := &bytes.Buffer{}
|
|
logger := slog.New(slog.NewJSONHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
|
|
|
resp, _ := runHeadlessPasswordLoginE2E(t, logger, "production", jwks, clientAssertion)
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusUnauthorized {
|
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
|
t.Fatalf("expected 401, got %d, body=%s", resp.StatusCode, string(bodyBytes))
|
|
}
|
|
|
|
output := logBuffer.String()
|
|
if !strings.Contains(output, "\"expected_audiences\"") {
|
|
t.Fatalf("expected debug logs to include expected_audiences, got=%s", output)
|
|
}
|
|
if !strings.Contains(output, "\"received_audiences\"") {
|
|
t.Fatalf("expected debug logs to include received_audiences, got=%s", output)
|
|
}
|
|
if !strings.Contains(output, "\"received_audiences_text\":\""+receivedAudience+"\"") {
|
|
t.Fatalf("expected debug logs to include received_audiences_text with full URL, got=%s", output)
|
|
}
|
|
if !strings.Contains(output, "\"expected_audiences_text\":\"http://example.com/api/v1/auth/headless/password/login, /api/v1/auth/headless/password/login\"") {
|
|
t.Fatalf("expected debug logs to include expected_audiences_text, got=%s", output)
|
|
}
|
|
if !strings.Contains(output, "\"login_challenge_prefix\":\"challenge-12\"") {
|
|
t.Fatalf("expected debug logs to include login challenge prefix, got=%s", output)
|
|
}
|
|
}
|