forked from baron/baron-sso
chore(headless-login): add request correlation logs
This commit is contained in:
@@ -126,6 +126,34 @@ type headlessClientAssertionClaims struct {
|
|||||||
|
|
||||||
type headlessAssertionAud []string
|
type headlessAssertionAud []string
|
||||||
|
|
||||||
|
type headlessLoginFailure struct {
|
||||||
|
status int
|
||||||
|
code string
|
||||||
|
safeMessage string
|
||||||
|
logMessage string
|
||||||
|
debugFields map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *headlessLoginFailure) Error() string {
|
||||||
|
if e == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if e.code != "" {
|
||||||
|
return e.code
|
||||||
|
}
|
||||||
|
return e.safeMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHeadlessLoginFailure(status int, code, safeMessage, logMessage string, debugFields map[string]any) *headlessLoginFailure {
|
||||||
|
return &headlessLoginFailure{
|
||||||
|
status: status,
|
||||||
|
code: code,
|
||||||
|
safeMessage: safeMessage,
|
||||||
|
logMessage: logMessage,
|
||||||
|
debugFields: debugFields,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (a *headlessAssertionAud) UnmarshalJSON(data []byte) error {
|
func (a *headlessAssertionAud) UnmarshalJSON(data []byte) error {
|
||||||
var single string
|
var single string
|
||||||
if err := json.Unmarshal(data, &single); err == nil {
|
if err := json.Unmarshal(data, &single); err == nil {
|
||||||
@@ -1742,6 +1770,117 @@ func containsHeadlessAudience(expected []string, actual headlessAssertionAud) bo
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func headlessRequestID(c *fiber.Ctx) string {
|
||||||
|
if c == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
reqID := strings.TrimSpace(c.GetRespHeader(fiber.HeaderXRequestID))
|
||||||
|
if reqID != "" {
|
||||||
|
return reqID
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(c.Get(fiber.HeaderXRequestID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func isHeadlessDebugLoggingEnabled() bool {
|
||||||
|
return slog.Default().Enabled(context.Background(), slog.LevelDebug)
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateHeadlessLogValue(value string, limit int) string {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if limit <= 0 || len(value) <= limit {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return value[:limit]
|
||||||
|
}
|
||||||
|
|
||||||
|
func logHeadlessLoginFailure(c *fiber.Ctx, message string, failure *headlessLoginFailure, clientID, loginChallenge string) {
|
||||||
|
if failure == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
args := []any{
|
||||||
|
"reason_code", failure.code,
|
||||||
|
"client_id", strings.TrimSpace(clientID),
|
||||||
|
"path", c.Path(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if reqID := headlessRequestID(c); reqID != "" {
|
||||||
|
args = append(args, "req_id", reqID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if trimmedChallenge := truncateHeadlessLogValue(loginChallenge, 12); trimmedChallenge != "" {
|
||||||
|
args = append(args, "login_challenge_prefix", trimmedChallenge)
|
||||||
|
}
|
||||||
|
|
||||||
|
if isHeadlessDebugLoggingEnabled() {
|
||||||
|
keys := make([]string, 0, len(failure.debugFields))
|
||||||
|
for key := range failure.debugFields {
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
for _, key := range keys {
|
||||||
|
args = append(args, key, failure.debugFields[key])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
level := slog.LevelWarn
|
||||||
|
if failure.status >= 500 {
|
||||||
|
level = slog.LevelError
|
||||||
|
}
|
||||||
|
slog.Log(context.Background(), level, message, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func logHeadlessLoginSuccess(c *fiber.Ctx, clientID, loginChallenge, redirectTo string) {
|
||||||
|
args := []any{
|
||||||
|
"client_id", strings.TrimSpace(clientID),
|
||||||
|
"path", c.Path(),
|
||||||
|
"response_status", fiber.StatusOK,
|
||||||
|
}
|
||||||
|
|
||||||
|
if reqID := headlessRequestID(c); reqID != "" {
|
||||||
|
args = append(args, "req_id", reqID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if trimmedChallenge := truncateHeadlessLogValue(loginChallenge, 12); trimmedChallenge != "" {
|
||||||
|
args = append(args, "login_challenge_prefix", trimmedChallenge)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed, err := url.Parse(redirectTo)
|
||||||
|
if err != nil {
|
||||||
|
args = append(args, "redirect_to_length", len(redirectTo), "redirect_parse_error", err.Error())
|
||||||
|
slog.Info("headless password login succeeded", args...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
query := parsed.Query()
|
||||||
|
args = append(
|
||||||
|
args,
|
||||||
|
"redirect_to_length", len(redirectTo),
|
||||||
|
"redirect_to_host", parsed.Host,
|
||||||
|
"redirect_to_path", parsed.Path,
|
||||||
|
"redirect_has_login_verifier", query.Has("login_verifier"),
|
||||||
|
"redirect_has_redirect_uri", query.Has("redirect_uri"),
|
||||||
|
)
|
||||||
|
slog.Info("headless password login succeeded", args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func respondHeadlessLoginFailure(c *fiber.Ctx, failure *headlessLoginFailure) error {
|
||||||
|
if failure == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return errorJSONCode(c, failure.status, failure.code, failure.safeMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHeadlessCredentialFailure(status int, code, safeMessage string) *headlessLoginFailure {
|
||||||
|
return newHeadlessLoginFailure(
|
||||||
|
status,
|
||||||
|
code,
|
||||||
|
safeMessage,
|
||||||
|
"headless password login credential authentication failed",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *AuthHandler) loadHeadlessJWKS(ctx context.Context, client domain.HydraClient, expectedKid string) (*jose.JSONWebKeySet, bool, error) {
|
func (h *AuthHandler) loadHeadlessJWKS(ctx context.Context, client domain.HydraClient, expectedKid string) (*jose.JSONWebKeySet, bool, error) {
|
||||||
if h.HeadlessJWKS == nil {
|
if h.HeadlessJWKS == nil {
|
||||||
h.HeadlessJWKS = service.NewHeadlessJWKSCacheService(h.RedisService, nil)
|
h.HeadlessJWKS = service.NewHeadlessJWKSCacheService(h.RedisService, nil)
|
||||||
@@ -1753,30 +1892,75 @@ func (h *AuthHandler) loadHeadlessJWKS(ctx context.Context, client domain.HydraC
|
|||||||
return keySet, refreshed, nil
|
return keySet, refreshed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateHeadlessClientAssertionClaims(c *fiber.Ctx, claims headlessClientAssertionClaims, clientID string) error {
|
func validateHeadlessClientAssertionClaims(c *fiber.Ctx, claims headlessClientAssertionClaims, clientID string) *headlessLoginFailure {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
|
debugFields := map[string]any{
|
||||||
|
"claim_issuer": claims.Issuer,
|
||||||
|
"claim_subject": claims.Subject,
|
||||||
|
"claim_expires_at": claims.ExpiresAt,
|
||||||
|
"claim_not_before": claims.NotBefore,
|
||||||
|
"claim_issued_at": claims.IssuedAt,
|
||||||
|
"received_audiences": []string(claims.Audience),
|
||||||
|
"expected_audiences": headlessAssertionAudiences(c),
|
||||||
|
}
|
||||||
if claims.Issuer != clientID || claims.Subject != clientID {
|
if claims.Issuer != clientID || claims.Subject != clientID {
|
||||||
return fmt.Errorf("client assertion iss/sub mismatch")
|
return newHeadlessLoginFailure(
|
||||||
|
fiber.StatusUnauthorized,
|
||||||
|
"invalid_client_assertion_iss_sub",
|
||||||
|
"Client assertion issuer or subject mismatch",
|
||||||
|
"headless password login client assertion claims mismatch",
|
||||||
|
debugFields,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if claims.ExpiresAt == 0 || claims.ExpiresAt <= now {
|
if claims.ExpiresAt == 0 || claims.ExpiresAt <= now {
|
||||||
return fmt.Errorf("client assertion expired")
|
return newHeadlessLoginFailure(
|
||||||
|
fiber.StatusUnauthorized,
|
||||||
|
"invalid_client_assertion_expired",
|
||||||
|
"Client assertion has expired",
|
||||||
|
"headless password login client assertion expired",
|
||||||
|
debugFields,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if claims.NotBefore != 0 && claims.NotBefore > now {
|
if claims.NotBefore != 0 && claims.NotBefore > now {
|
||||||
return fmt.Errorf("client assertion not active yet")
|
return newHeadlessLoginFailure(
|
||||||
|
fiber.StatusUnauthorized,
|
||||||
|
"invalid_client_assertion_not_before",
|
||||||
|
"Client assertion is not active yet",
|
||||||
|
"headless password login client assertion not active yet",
|
||||||
|
debugFields,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if claims.IssuedAt != 0 && claims.IssuedAt > now+60 {
|
if claims.IssuedAt != 0 && claims.IssuedAt > now+60 {
|
||||||
return fmt.Errorf("client assertion issued in the future")
|
return newHeadlessLoginFailure(
|
||||||
|
fiber.StatusUnauthorized,
|
||||||
|
"invalid_client_assertion_iat_future",
|
||||||
|
"Client assertion issued-at time is invalid",
|
||||||
|
"headless password login client assertion issued in the future",
|
||||||
|
debugFields,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if !containsHeadlessAudience(headlessAssertionAudiences(c), claims.Audience) {
|
if !containsHeadlessAudience(headlessAssertionAudiences(c), claims.Audience) {
|
||||||
return fmt.Errorf("client assertion audience mismatch")
|
return newHeadlessLoginFailure(
|
||||||
|
fiber.StatusUnauthorized,
|
||||||
|
"invalid_client_assertion_audience",
|
||||||
|
"Client assertion audience mismatch",
|
||||||
|
"headless password login client assertion audience mismatch",
|
||||||
|
debugFields,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.HydraClient, clientID, clientAssertion string) error {
|
func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.HydraClient, clientID, clientAssertion string) *headlessLoginFailure {
|
||||||
assertion := strings.TrimSpace(clientAssertion)
|
assertion := strings.TrimSpace(clientAssertion)
|
||||||
if assertion == "" {
|
if assertion == "" {
|
||||||
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_assertion is required")
|
return newHeadlessLoginFailure(
|
||||||
|
fiber.StatusBadRequest,
|
||||||
|
"bad_request",
|
||||||
|
"client_assertion is required",
|
||||||
|
"headless password login client assertion missing",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := josejwt.ParseSigned(assertion, []jose.SignatureAlgorithm{
|
token, err := josejwt.ParseSigned(assertion, []jose.SignatureAlgorithm{
|
||||||
@@ -1786,7 +1970,13 @@ func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.
|
|||||||
jose.EdDSA,
|
jose.EdDSA,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
|
return newHeadlessLoginFailure(
|
||||||
|
fiber.StatusUnauthorized,
|
||||||
|
"invalid_client_assertion_parse",
|
||||||
|
"Client assertion format is invalid",
|
||||||
|
"headless password login client assertion parse failed",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedKid := ""
|
expectedKid := ""
|
||||||
@@ -1797,7 +1987,15 @@ func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.
|
|||||||
keySet, refreshed, err := h.loadHeadlessJWKS(c.Context(), client, expectedKid)
|
keySet, refreshed, err := h.loadHeadlessJWKS(c.Context(), client, expectedKid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("failed to load jwks for headless client assertion", "clientID", clientID, "error", err)
|
slog.Error("failed to load jwks for headless client assertion", "clientID", clientID, "error", err)
|
||||||
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", headlessClientAssertionErrorMessage(err))
|
return newHeadlessLoginFailure(
|
||||||
|
fiber.StatusUnauthorized,
|
||||||
|
"invalid_client_assertion_jwks_load",
|
||||||
|
headlessClientAssertionErrorMessage(err),
|
||||||
|
"headless password login client assertion jwks load failed",
|
||||||
|
map[string]any{
|
||||||
|
"received_kid": expectedKid,
|
||||||
|
},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
matchingKidPresent := expectedKid != "" && containsHeadlessKeyID(keySet, expectedKid)
|
matchingKidPresent := expectedKid != "" && containsHeadlessKeyID(keySet, expectedKid)
|
||||||
@@ -1810,8 +2008,13 @@ func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.
|
|||||||
if err := token.Claims(key.Key, &claims); err != nil {
|
if err := token.Claims(key.Key, &claims); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err := validateHeadlessClientAssertionClaims(c, claims, clientID); err != nil {
|
if failure := validateHeadlessClientAssertionClaims(c, claims, clientID); failure != nil {
|
||||||
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion claims")
|
if failure.debugFields == nil {
|
||||||
|
failure.debugFields = map[string]any{}
|
||||||
|
}
|
||||||
|
failure.debugFields["received_kid"] = expectedKid
|
||||||
|
failure.debugFields["jwks_refreshed"] = refreshed
|
||||||
|
return failure
|
||||||
}
|
}
|
||||||
_ = h.HeadlessJWKS.MarkVerificationSuccess(clientID)
|
_ = h.HeadlessJWKS.MarkVerificationSuccess(clientID)
|
||||||
return nil
|
return nil
|
||||||
@@ -1829,8 +2032,13 @@ func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.
|
|||||||
if err := token.Claims(key.Key, &claims); err != nil {
|
if err := token.Claims(key.Key, &claims); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err := validateHeadlessClientAssertionClaims(c, claims, clientID); err != nil {
|
if failure := validateHeadlessClientAssertionClaims(c, claims, clientID); failure != nil {
|
||||||
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion claims")
|
if failure.debugFields == nil {
|
||||||
|
failure.debugFields = map[string]any{}
|
||||||
|
}
|
||||||
|
failure.debugFields["received_kid"] = expectedKid
|
||||||
|
failure.debugFields["jwks_refreshed"] = true
|
||||||
|
return failure
|
||||||
}
|
}
|
||||||
_ = h.HeadlessJWKS.MarkVerificationSuccess(clientID)
|
_ = h.HeadlessJWKS.MarkVerificationSuccess(clientID)
|
||||||
return nil
|
return nil
|
||||||
@@ -1838,7 +2046,16 @@ func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion signature with jwksUri")
|
return newHeadlessLoginFailure(
|
||||||
|
fiber.StatusUnauthorized,
|
||||||
|
"invalid_client_assertion_signature",
|
||||||
|
"Client assertion signature verification failed",
|
||||||
|
"headless password login client assertion signature verification failed",
|
||||||
|
map[string]any{
|
||||||
|
"received_kid": expectedKid,
|
||||||
|
"jwks_refreshed": refreshed,
|
||||||
|
},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func headlessClientAssertionErrorMessage(err error) string {
|
func headlessClientAssertionErrorMessage(err error) string {
|
||||||
@@ -2045,13 +2262,21 @@ func (h *AuthHandler) HeadlessPasswordLogin(c *fiber.Ctx) error {
|
|||||||
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
|
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil {
|
if failure := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); failure != nil {
|
||||||
return err
|
logHeadlessLoginFailure(c, failure.logMessage, failure, clientID, loginChallenge)
|
||||||
|
return respondHeadlessLoginFailure(c, failure)
|
||||||
}
|
}
|
||||||
|
|
||||||
authInfo, authErr := h.authenticatePasswordLogin(c.Context(), loginID, req.Password)
|
authInfo, authErr := h.authenticatePasswordLogin(c.Context(), loginID, req.Password)
|
||||||
if authErr != nil {
|
if authErr != nil {
|
||||||
status, code, message := passwordLoginErrorSpec(authErr)
|
status, code, message := passwordLoginErrorSpec(authErr)
|
||||||
|
logHeadlessLoginFailure(
|
||||||
|
c,
|
||||||
|
"headless password login credential authentication failed",
|
||||||
|
newHeadlessCredentialFailure(status, code, message),
|
||||||
|
clientID,
|
||||||
|
loginChallenge,
|
||||||
|
)
|
||||||
return errorJSONCode(c, status, code, message)
|
return errorJSONCode(c, status, code, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2064,11 +2289,15 @@ func (h *AuthHandler) HeadlessPasswordLogin(c *fiber.Ctx) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logOidcRedirectSummary("headless_password_login", acceptResp.RedirectTo)
|
logOidcRedirectSummary("headless_password_login", acceptResp.RedirectTo)
|
||||||
return c.JSON(fiber.Map{
|
if err := c.Status(fiber.StatusOK).JSON(fiber.Map{
|
||||||
"redirectTo": acceptResp.RedirectTo,
|
"redirectTo": acceptResp.RedirectTo,
|
||||||
"status": "ok",
|
"status": "ok",
|
||||||
"provider": h.IdpProvider.Name(),
|
"provider": h.IdpProvider.Name(),
|
||||||
})
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logHeadlessLoginSuccess(c, clientID, loginChallenge, acceptResp.RedirectTo)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *AuthHandler) startHeadlessPhoneLink(c *fiber.Ctx, loginID string) (fiber.Map, string, string, time.Duration, error) {
|
func (h *AuthHandler) startHeadlessPhoneLink(c *fiber.Ctx, loginID string) (fiber.Map, string, string, time.Duration, error) {
|
||||||
@@ -2194,8 +2423,9 @@ func (h *AuthHandler) HeadlessLinkInit(c *fiber.Ctx) error {
|
|||||||
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
|
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil {
|
if failure := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); failure != nil {
|
||||||
return err
|
logHeadlessLoginFailure(c, failure.logMessage, failure, clientID, loginChallenge)
|
||||||
|
return respondHeadlessLoginFailure(c, failure)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, pendingRef, resolvedLoginID, ttl, err := h.startHeadlessPhoneLink(c, loginID)
|
resp, pendingRef, resolvedLoginID, ttl, err := h.startHeadlessPhoneLink(c, loginID)
|
||||||
@@ -2255,8 +2485,9 @@ func (h *AuthHandler) HeadlessLinkPoll(c *fiber.Ctx) error {
|
|||||||
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
|
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil {
|
if failure := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); failure != nil {
|
||||||
return err
|
logHeadlessLoginFailure(c, failure.logMessage, failure, clientID, state.LoginChallenge)
|
||||||
|
return respondHeadlessLoginFailure(c, failure)
|
||||||
}
|
}
|
||||||
|
|
||||||
val, err := h.RedisService.Get(prefixSession + pendingRef)
|
val, err := h.RedisService.Get(prefixSession + pendingRef)
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -202,6 +203,35 @@ func mustHeadlessClientAssertion(t *testing.T, privateKey *rsa.PrivateKey, clien
|
|||||||
return raw
|
return raw
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mustHeadlessClientAssertionWithCustomClaims(
|
||||||
|
t *testing.T,
|
||||||
|
privateKey *rsa.PrivateKey,
|
||||||
|
clientID string,
|
||||||
|
claims josejwt.Claims,
|
||||||
|
) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
signer, err := jose.NewSigner(jose.SigningKey{
|
||||||
|
Algorithm: jose.RS256,
|
||||||
|
Key: jose.JSONWebKey{
|
||||||
|
Key: privateKey,
|
||||||
|
KeyID: "test-kid",
|
||||||
|
Use: "sig",
|
||||||
|
Algorithm: string(jose.RS256),
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create signer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
raw, err := josejwt.Signed(signer).Claims(claims).Serialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to sign client assertion: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
|
||||||
func mustHeadlessJWKForAlgorithm(t *testing.T, alg jose.SignatureAlgorithm) (any, map[string]any) {
|
func mustHeadlessJWKForAlgorithm(t *testing.T, alg jose.SignatureAlgorithm) (any, map[string]any) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
@@ -379,6 +409,94 @@ func runHeadlessPasswordLoginWithAssertion(t *testing.T, jwks map[string]any, cl
|
|||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func runHeadlessPasswordLoginWithAssertionAndLogger(
|
||||||
|
t *testing.T,
|
||||||
|
jwks map[string]any,
|
||||||
|
clientAssertion string,
|
||||||
|
logger *slog.Logger,
|
||||||
|
) *http.Response {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
mockIdp := new(MockIdentityProvider)
|
||||||
|
mockIdp.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{
|
||||||
|
SessionToken: &domain.Token{JWT: "valid-jwt"},
|
||||||
|
Subject: "kratos-identity-id",
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
mockKratos := new(MockKratosAdminService)
|
||||||
|
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil)
|
||||||
|
|
||||||
|
jwksBody, err := json.Marshal(jwks)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to marshal jwks body: %v", err)
|
||||||
|
}
|
||||||
|
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write(jwksBody)
|
||||||
|
}))
|
||||||
|
t.Cleanup(jwksServer.Close)
|
||||||
|
|
||||||
|
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
|
||||||
|
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||||
|
Challenge: "challenge-123",
|
||||||
|
Client: domain.HydraClient{
|
||||||
|
ClientID: "headless-login-client",
|
||||||
|
TokenEndpointAuthMethod: "none",
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"status": "active",
|
||||||
|
"headless_login_enabled": true,
|
||||||
|
"headless_token_endpoint_auth_method": "private_key_jwt",
|
||||||
|
"headless_jwks_uri": jwksServer.URL + "/.well-known/jwks.json",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.NotFound(w, r)
|
||||||
|
})
|
||||||
|
|
||||||
|
h := &AuthHandler{
|
||||||
|
IdpProvider: mockIdp,
|
||||||
|
KratosAdmin: mockKratos,
|
||||||
|
Hydra: &service.HydraAdminService{
|
||||||
|
AdminURL: "http://hydra.test",
|
||||||
|
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
app := newHeadlessPasswordLoginTestApp(h)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]string{
|
||||||
|
"client_id": "headless-login-client",
|
||||||
|
"client_assertion": clientAssertion,
|
||||||
|
"loginId": "employee001",
|
||||||
|
"password": "password",
|
||||||
|
"login_challenge": "challenge-123",
|
||||||
|
})
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("X-Request-Id", "req-headless-test-123")
|
||||||
|
|
||||||
|
if logger != nil {
|
||||||
|
previous := slog.Default()
|
||||||
|
slog.SetDefault(logger.With())
|
||||||
|
t.Cleanup(func() {
|
||||||
|
slog.SetDefault(previous)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
// mockHydraTransport simulates Hydra API responses
|
// mockHydraTransport simulates Hydra API responses
|
||||||
func mockHydraTransport(handler http.Handler) http.RoundTripper {
|
func mockHydraTransport(handler http.Handler) http.RoundTripper {
|
||||||
return roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
return roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
@@ -900,6 +1018,222 @@ func TestHeadlessPasswordLogin_InvalidClientAssertionRejected(t *testing.T) {
|
|||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
t.Fatalf("expected 401, got %d, body: %s", resp.StatusCode, string(bodyBytes))
|
t.Fatalf("expected 401, got %d, body: %s", resp.StatusCode, string(bodyBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var got map[string]any
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||||
|
t.Fatalf("failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
if got["code"] != "invalid_client_assertion_signature" {
|
||||||
|
t.Fatalf("expected code=invalid_client_assertion_signature, got=%v", got["code"])
|
||||||
|
}
|
||||||
|
if got["error"] != "Client assertion signature verification failed" {
|
||||||
|
t.Fatalf("expected detailed signature error, got=%v", got["error"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeadlessPasswordLogin_AudienceMismatchReturnsDetailedCode(t *testing.T) {
|
||||||
|
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||||
|
clientAssertion := mustHeadlessClientAssertion(
|
||||||
|
t,
|
||||||
|
privateKey,
|
||||||
|
"headless-login-client",
|
||||||
|
"https://rp.example.com/oidc/token",
|
||||||
|
)
|
||||||
|
|
||||||
|
resp := runHeadlessPasswordLoginWithAssertion(t, jwks, clientAssertion)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("expected 401, got %d, body: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
var got map[string]any
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||||
|
t.Fatalf("failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
if got["code"] != "invalid_client_assertion_audience" {
|
||||||
|
t.Fatalf("expected code=invalid_client_assertion_audience, got=%v", got["code"])
|
||||||
|
}
|
||||||
|
if got["error"] != "Client assertion audience mismatch" {
|
||||||
|
t.Fatalf("expected audience mismatch error, got=%v", got["error"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeadlessPasswordLogin_IssSubMismatchReturnsDetailedCode(t *testing.T) {
|
||||||
|
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||||
|
now := time.Now()
|
||||||
|
clientAssertion := mustHeadlessClientAssertionWithCustomClaims(
|
||||||
|
t,
|
||||||
|
privateKey,
|
||||||
|
"headless-login-client",
|
||||||
|
josejwt.Claims{
|
||||||
|
Issuer: "other-client",
|
||||||
|
Subject: "headless-login-client",
|
||||||
|
Audience: josejwt.Audience{"http://example.com/api/v1/auth/headless/password/login"},
|
||||||
|
Expiry: josejwt.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
IssuedAt: josejwt.NewNumericDate(now),
|
||||||
|
NotBefore: josejwt.NewNumericDate(
|
||||||
|
now.Add(-1 * time.Minute),
|
||||||
|
),
|
||||||
|
ID: "assertion-iss-mismatch",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
resp := runHeadlessPasswordLoginWithAssertion(t, jwks, clientAssertion)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("expected 401, got %d, body: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
var got map[string]any
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||||
|
t.Fatalf("failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
if got["code"] != "invalid_client_assertion_iss_sub" {
|
||||||
|
t.Fatalf("expected code=invalid_client_assertion_iss_sub, got=%v", got["code"])
|
||||||
|
}
|
||||||
|
if got["error"] != "Client assertion issuer or subject mismatch" {
|
||||||
|
t.Fatalf("expected iss/sub mismatch error, got=%v", got["error"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeadlessPasswordLogin_ExpiredAssertionReturnsDetailedCode(t *testing.T) {
|
||||||
|
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||||
|
now := time.Now()
|
||||||
|
clientAssertion := mustHeadlessClientAssertionWithCustomClaims(
|
||||||
|
t,
|
||||||
|
privateKey,
|
||||||
|
"headless-login-client",
|
||||||
|
josejwt.Claims{
|
||||||
|
Issuer: "headless-login-client",
|
||||||
|
Subject: "headless-login-client",
|
||||||
|
Audience: josejwt.Audience{"http://example.com/api/v1/auth/headless/password/login"},
|
||||||
|
Expiry: josejwt.NewNumericDate(now.Add(-1 * time.Minute)),
|
||||||
|
IssuedAt: josejwt.NewNumericDate(now.Add(-10 * time.Minute)),
|
||||||
|
NotBefore: josejwt.NewNumericDate(
|
||||||
|
now.Add(-11 * time.Minute),
|
||||||
|
),
|
||||||
|
ID: "assertion-expired",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
resp := runHeadlessPasswordLoginWithAssertion(t, jwks, clientAssertion)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("expected 401, got %d, body: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
var got map[string]any
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||||
|
t.Fatalf("failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
if got["code"] != "invalid_client_assertion_expired" {
|
||||||
|
t.Fatalf("expected code=invalid_client_assertion_expired, got=%v", got["code"])
|
||||||
|
}
|
||||||
|
if got["error"] != "Client assertion has expired" {
|
||||||
|
t.Fatalf("expected expired assertion error, got=%v", got["error"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeadlessPasswordLogin_DebugLogIncludesDiagnostics(t *testing.T) {
|
||||||
|
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||||
|
clientAssertion := mustHeadlessClientAssertion(
|
||||||
|
t,
|
||||||
|
privateKey,
|
||||||
|
"headless-login-client",
|
||||||
|
"https://rp.example.com/oidc/token",
|
||||||
|
)
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
logger := slog.New(slog.NewJSONHandler(buf, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||||
|
|
||||||
|
resp := runHeadlessPasswordLoginWithAssertionAndLogger(t, jwks, clientAssertion, logger)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("expected 401, got %d, body: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
logOutput := buf.String()
|
||||||
|
if !strings.Contains(logOutput, "expected_audiences") {
|
||||||
|
t.Fatalf("expected debug log to include expected_audiences, got=%s", logOutput)
|
||||||
|
}
|
||||||
|
if !strings.Contains(logOutput, "received_audiences") {
|
||||||
|
t.Fatalf("expected debug log to include received_audiences, got=%s", logOutput)
|
||||||
|
}
|
||||||
|
if !strings.Contains(logOutput, "invalid_client_assertion_audience") {
|
||||||
|
t.Fatalf("expected debug log to include reason code, got=%s", logOutput)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeadlessPasswordLogin_InfoLogOmitsDebugDiagnostics(t *testing.T) {
|
||||||
|
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||||
|
clientAssertion := mustHeadlessClientAssertion(
|
||||||
|
t,
|
||||||
|
privateKey,
|
||||||
|
"headless-login-client",
|
||||||
|
"https://rp.example.com/oidc/token",
|
||||||
|
)
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
logger := slog.New(slog.NewJSONHandler(buf, &slog.HandlerOptions{Level: slog.LevelInfo}))
|
||||||
|
|
||||||
|
resp := runHeadlessPasswordLoginWithAssertionAndLogger(t, jwks, clientAssertion, logger)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("expected 401, got %d, body: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
logOutput := buf.String()
|
||||||
|
if strings.Contains(logOutput, "expected_audiences") {
|
||||||
|
t.Fatalf("expected info log to omit expected_audiences, got=%s", logOutput)
|
||||||
|
}
|
||||||
|
if strings.Contains(logOutput, "received_audiences") {
|
||||||
|
t.Fatalf("expected info log to omit received_audiences, got=%s", logOutput)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeadlessPasswordLogin_SuccessLogIncludesCorrelationFields(t *testing.T) {
|
||||||
|
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||||
|
clientAssertion := mustHeadlessClientAssertion(
|
||||||
|
t,
|
||||||
|
privateKey,
|
||||||
|
"headless-login-client",
|
||||||
|
"http://example.com/api/v1/auth/headless/password/login",
|
||||||
|
)
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
logger := slog.New(slog.NewJSONHandler(buf, &slog.HandlerOptions{Level: slog.LevelInfo}))
|
||||||
|
|
||||||
|
resp := runHeadlessPasswordLoginWithAssertionAndLogger(t, jwks, clientAssertion, logger)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("expected 200, got %d, body: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
logOutput := buf.String()
|
||||||
|
for _, needle := range []string{
|
||||||
|
"headless password login succeeded",
|
||||||
|
`"req_id":"req-headless-test-123"`,
|
||||||
|
`"path":"/api/v1/auth/headless/password/login"`,
|
||||||
|
`"client_id":"headless-login-client"`,
|
||||||
|
`"login_challenge_prefix":"challenge-12"`,
|
||||||
|
`"response_status":200`,
|
||||||
|
} {
|
||||||
|
if !strings.Contains(logOutput, needle) {
|
||||||
|
t.Fatalf("expected success log to include %s, got=%s", needle, logOutput)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHeadlessPasswordLogin_AcceptsConfiguredClientAssertionAlgorithms(t *testing.T) {
|
func TestHeadlessPasswordLogin_AcceptsConfiguredClientAssertionAlgorithms(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user