1
0
forked from baron/baron-sso

feat(auth): add trusted rp headless login flows

This commit is contained in:
Lectom C Han
2026-03-30 21:46:15 +09:00
parent 26890dfabb
commit b4342b355f
4 changed files with 1244 additions and 98 deletions

View File

@@ -526,6 +526,9 @@ func main() {
auth.Post("/login/code/verify", authHandler.VerifyLoginCode) auth.Post("/login/code/verify", authHandler.VerifyLoginCode)
auth.Post("/login/code/verify-short", authHandler.VerifyLoginShortCode) auth.Post("/login/code/verify-short", authHandler.VerifyLoginShortCode)
auth.Post("/password/login", authHandler.PasswordLogin) auth.Post("/password/login", authHandler.PasswordLogin)
auth.Post("/headless/password/login", authHandler.HeadlessPasswordLogin)
auth.Post("/headless/link/init", authHandler.HeadlessLinkInit)
auth.Post("/headless/link/poll", authHandler.HeadlessLinkPoll)
auth.Get("/tenant-info", authHandler.GetTenantInfo) auth.Get("/tenant-info", authHandler.GetTenantInfo)
auth.Get("/consent", authHandler.GetConsentRequest) auth.Get("/consent", authHandler.GetConsentRequest)
auth.Post("/consent/accept", authHandler.AcceptConsentRequest) auth.Post("/consent/accept", authHandler.AcceptConsentRequest)

View File

@@ -25,6 +25,8 @@ import (
"strings" "strings"
"time" "time"
"github.com/go-jose/go-jose/v4"
josejwt "github.com/go-jose/go-jose/v4/jwt"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
) )
@@ -46,6 +48,7 @@ const (
prefixLoginCodeSmsOnly = "login_code_sms_only:" prefixLoginCodeSmsOnly = "login_code_sms_only:"
prefixLoginCodeQrPending = "login_code_qr_pending:" prefixLoginCodeQrPending = "login_code_qr_pending:"
prefixLoginCodeQr = "login_code_qr:" prefixLoginCodeQr = "login_code_qr:"
prefixHeadlessLinkState = "headless_link_state:"
prefixPollMeta = "poll_meta:" prefixPollMeta = "poll_meta:"
prefixQrRef = "qr_ref:" prefixQrRef = "qr_ref:"
prefixQrMeta = "qr_meta:" prefixQrMeta = "qr_meta:"
@@ -75,6 +78,7 @@ const (
loginCodeExpiration = 10 * time.Minute loginCodeExpiration = 10 * time.Minute
linkResendCooldown = 60 * time.Second linkResendCooldown = 60 * time.Second
prefixDrySend = "dry_send:" prefixDrySend = "dry_send:"
headlessJWKSFetchTTL = 5 * time.Second
) )
type AuthHandler struct { type AuthHandler struct {
@@ -100,6 +104,40 @@ type signupState struct {
ExpiresAt int64 `json:"expires_at"` // Unix timestamp ExpiresAt int64 `json:"expires_at"` // Unix timestamp
} }
type headlessLinkState struct {
ClientID string `json:"clientId"`
LoginChallenge string `json:"loginChallenge"`
LoginID string `json:"loginId"`
RedirectTo string `json:"redirectTo,omitempty"`
}
type headlessClientAssertionClaims struct {
Issuer string `json:"iss"`
Subject string `json:"sub"`
Audience headlessAssertionAud `json:"aud"`
ExpiresAt int64 `json:"exp"`
IssuedAt int64 `json:"iat,omitempty"`
NotBefore int64 `json:"nbf,omitempty"`
ID string `json:"jti,omitempty"`
}
type headlessAssertionAud []string
func (a *headlessAssertionAud) UnmarshalJSON(data []byte) error {
var single string
if err := json.Unmarshal(data, &single); err == nil {
*a = []string{single}
return nil
}
var list []string
if err := json.Unmarshal(data, &list); err != nil {
return err
}
*a = list
return nil
}
// GenerateSecureToken - Helper to generate secure random strings // GenerateSecureToken - Helper to generate secure random strings
func GenerateSecureToken(length int) string { func GenerateSecureToken(length int) string {
b := make([]byte, length) b := make([]byte, length)
@@ -1224,87 +1262,9 @@ func (h *AuthHandler) PollEnchantedLink(c *fiber.Ctx) error {
} }
if data["status"] == "approved" { if data["status"] == "approved" {
loginID := data["loginId"] _, authInfo, err := h.completeApprovedLinkLogin(c, req.PendingRef)
if loginID == "" { if err != nil {
loginID = data["login_id"] return err
}
if loginID == "" {
slog.Warn("[Poll] Approved but missing loginId", "pendingRef", req.PendingRef)
return errorJSON(c, fiber.StatusBadRequest, "Invalid session reference")
}
if h.IdpProvider == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
}
loginStrategy := h.loadLoginStrategy(req.PendingRef)
if loginStrategy == "" {
loginStrategy = loginFlowLink
}
var authInfo *domain.AuthInfo
var err error
if loginStrategy == loginFlowCode {
code, _ := h.RedisService.Get(prefixLoginCodeValue + req.PendingRef)
code = normalizeLoginCode(code)
if code == "" {
slog.Warn("[Poll] Missing login code for approved flow", "pendingRef", req.PendingRef)
return errorJSON(c, fiber.StatusBadRequest, "Login code expired")
}
flowID, _ := h.RedisService.Get(prefixLoginCode + loginID)
if flowID == "" {
return errorJSON(c, fiber.StatusNotFound, "Login flow expired")
}
authInfo, err = h.IdpProvider.VerifyLoginCode(loginID, flowID, code)
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
return errorJSON(c, fiber.StatusNotImplemented, "Login method not supported")
}
slog.Error("[Poll] IDP code verify failed", "error", err)
return errorJSON(c, fiber.StatusInternalServerError, "Failed to verify login code")
}
} else {
authInfo, err = h.IdpProvider.IssueSession(loginID)
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
return errorJSON(c, fiber.StatusNotImplemented, "Login method not supported")
}
slog.Error("[Poll] IDP session issue failed", "error", err)
return errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session")
}
}
if authInfo == nil || authInfo.SessionToken == nil || authInfo.SessionToken.JWT == "" {
return errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session")
}
c.Locals("login_id", loginID)
setSessionIDLocal(c, authInfo.SessionToken)
sessionID := extractSessionIDFromToken(authInfo.SessionToken)
if sessionID == "" && authInfo.SessionToken != nil && authInfo.SessionToken.JWT != "" {
if resolved, err := h.getKratosSessionID(authInfo.SessionToken.JWT); err == nil && resolved != "" {
sessionID = resolved
authInfo.SessionToken.SessionID = resolved
setSessionIDLocal(c, authInfo.SessionToken)
}
}
sessionData := map[string]string{
"status": statusSuccess,
"jwt": authInfo.SessionToken.JWT,
}
if sessionID != "" {
sessionData["session_id"] = sessionID
}
sessionDataJSON, _ := json.Marshal(sessionData)
h.RedisService.Set(prefixSession+req.PendingRef, string(sessionDataJSON), defaultExpiration)
h.writeLinkAuditLog(loginID, req.PendingRef, authInfo.SessionToken, c)
h.clearLoginMeta(req.PendingRef)
if loginStrategy == loginFlowCode {
h.RedisService.Delete(prefixLoginCode + loginID)
h.RedisService.Delete(prefixLoginCodePending + loginID)
h.RedisService.Delete(prefixLoginCodeSmsTarget + loginID)
h.RedisService.Delete(prefixLoginCodeSmsLookup + loginID)
h.RedisService.Delete(prefixLoginCodeValue + req.PendingRef)
} }
return c.JSON(fiber.Map{ return c.JSON(fiber.Map{
@@ -1671,6 +1631,617 @@ func logOidcRedirectSummary(source, redirectTo string) {
) )
} }
func (h *AuthHandler) authenticatePasswordLogin(ctx context.Context, loginID, password string) (*domain.AuthInfo, error) {
if h.IdpProvider == nil {
return nil, fmt.Errorf("authentication service not configured")
}
authInfo, err := h.IdpProvider.SignIn(loginID, password)
if err != nil {
return nil, err
}
subject, resolveErr := h.resolveKratosIdentityIDFromLoginID(ctx, loginID)
if resolveErr != nil || subject == "" {
slog.Error("Failed to resolve kratos identity after login", "loginID", loginID, "error", resolveErr)
return nil, fmt.Errorf("failed to resolve user identity")
}
authInfo.Subject = subject
return authInfo, nil
}
func passwordLoginErrorSpec(err error) (int, string, string) {
if err == nil {
return fiber.StatusOK, "", ""
}
if errors.Is(err, domain.ErrNotSupported) {
return fiber.StatusNotImplemented, "not_supported", "Login method not supported"
}
if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "identity") {
return fiber.StatusNotFound, "not_found", "User not registered"
}
if strings.Contains(err.Error(), "failed to resolve user identity") {
return fiber.StatusInternalServerError, "internal_error", "Failed to resolve user identity"
}
return fiber.StatusUnauthorized, "password_or_email_mismatch", "Invalid credentials"
}
func headlessAssertionAudiences(c *fiber.Ctx) []string {
if c == nil {
return nil
}
path := strings.TrimSpace(c.Path())
if path == "" {
return nil
}
base := strings.TrimRight(strings.TrimSpace(c.BaseURL()), "/")
if base == "" {
return []string{path}
}
return []string{base + path, path}
}
func containsHeadlessAudience(expected []string, actual headlessAssertionAud) bool {
for _, audience := range actual {
for _, candidate := range expected {
if strings.TrimSpace(audience) == strings.TrimSpace(candidate) {
return true
}
}
}
return false
}
func (h *AuthHandler) loadHeadlessJWKS(ctx context.Context, client domain.HydraClient) (*jose.JSONWebKeySet, error) {
var raw []byte
switch {
case client.JWKS != nil:
data, err := json.Marshal(client.JWKS)
if err != nil {
return nil, fmt.Errorf("failed to encode jwks: %w", err)
}
raw = data
case strings.TrimSpace(client.JWKSUri) != "":
req, err := http.NewRequestWithContext(ctx, http.MethodGet, strings.TrimSpace(client.JWKSUri), nil)
if err != nil {
return nil, fmt.Errorf("failed to build jwks request: %w", err)
}
client := &http.Client{Timeout: headlessJWKSFetchTTL}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to fetch jwks: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return nil, fmt.Errorf("failed to fetch jwks status=%d body=%s", resp.StatusCode, string(body))
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
if err != nil {
return nil, fmt.Errorf("failed to read jwks response: %w", err)
}
raw = body
default:
return nil, fmt.Errorf("trusted rp public key is not configured")
}
var keySet jose.JSONWebKeySet
if err := json.Unmarshal(raw, &keySet); err != nil {
return nil, fmt.Errorf("failed to decode jwks: %w", err)
}
if len(keySet.Keys) == 0 {
return nil, fmt.Errorf("trusted rp jwks has no keys")
}
return &keySet, nil
}
func validateHeadlessClientAssertionClaims(c *fiber.Ctx, claims headlessClientAssertionClaims, clientID string) error {
now := time.Now().Unix()
if claims.Issuer != clientID || claims.Subject != clientID {
return fmt.Errorf("client assertion iss/sub mismatch")
}
if claims.ExpiresAt == 0 || claims.ExpiresAt <= now {
return fmt.Errorf("client assertion expired")
}
if claims.NotBefore != 0 && claims.NotBefore > now {
return fmt.Errorf("client assertion not active yet")
}
if claims.IssuedAt != 0 && claims.IssuedAt > now+60 {
return fmt.Errorf("client assertion issued in the future")
}
if !containsHeadlessAudience(headlessAssertionAudiences(c), claims.Audience) {
return fmt.Errorf("client assertion audience mismatch")
}
return nil
}
func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.HydraClient, clientID, clientAssertion string) error {
assertion := strings.TrimSpace(clientAssertion)
if assertion == "" {
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_assertion is required")
}
keySet, err := h.loadHeadlessJWKS(c.Context(), client)
if err != nil {
slog.Error("failed to load jwks for headless client assertion", "clientID", clientID, "error", err)
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
}
token, err := josejwt.ParseSigned(assertion, []jose.SignatureAlgorithm{
jose.RS256, jose.RS384, jose.RS512,
jose.PS256, jose.PS384, jose.PS512,
jose.ES256, jose.ES384, jose.ES512,
jose.EdDSA,
})
if err != nil {
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
}
expectedKid := ""
if len(token.Headers) > 0 {
expectedKid = strings.TrimSpace(token.Headers[0].KeyID)
}
for _, key := range keySet.Keys {
if expectedKid != "" && key.KeyID != "" && key.KeyID != expectedKid {
continue
}
var claims headlessClientAssertionClaims
if err := token.Claims(key.Key, &claims); err != nil {
continue
}
if err := validateHeadlessClientAssertionClaims(c, claims, clientID); err != nil {
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
}
return nil
}
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
}
func (h *AuthHandler) storeHeadlessLinkState(pendingRef string, state headlessLinkState, ttl time.Duration) {
if h.RedisService == nil || pendingRef == "" {
return
}
raw, err := json.Marshal(state)
if err != nil {
return
}
_ = h.RedisService.Set(prefixHeadlessLinkState+pendingRef, string(raw), ttl)
}
func (h *AuthHandler) loadHeadlessLinkState(pendingRef string) (headlessLinkState, bool) {
if h.RedisService == nil || pendingRef == "" {
return headlessLinkState{}, false
}
raw, err := h.RedisService.Get(prefixHeadlessLinkState + pendingRef)
if err != nil || raw == "" {
return headlessLinkState{}, false
}
var state headlessLinkState
if err := json.Unmarshal([]byte(raw), &state); err != nil {
return headlessLinkState{}, false
}
return state, true
}
func (h *AuthHandler) completeApprovedLinkLogin(c *fiber.Ctx, pendingRef string) (string, *domain.AuthInfo, error) {
val, err := h.RedisService.Get(prefixSession + pendingRef)
if err != nil || val == "" {
return "", nil, errorJSON(c, fiber.StatusBadRequest, "Invalid session reference")
}
var data map[string]string
_ = json.Unmarshal([]byte(val), &data)
loginID := data["loginId"]
if loginID == "" {
loginID = data["login_id"]
}
if loginID == "" {
slog.Warn("[Poll] Approved but missing loginId", "pendingRef", pendingRef)
return "", nil, errorJSON(c, fiber.StatusBadRequest, "Invalid session reference")
}
if h.IdpProvider == nil {
return "", nil, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
}
loginStrategy := h.loadLoginStrategy(pendingRef)
if loginStrategy == "" {
loginStrategy = loginFlowLink
}
var authInfo *domain.AuthInfo
if loginStrategy == loginFlowCode {
code, _ := h.RedisService.Get(prefixLoginCodeValue + pendingRef)
code = normalizeLoginCode(code)
if code == "" {
slog.Warn("[Poll] Missing login code for approved flow", "pendingRef", pendingRef)
return "", nil, errorJSON(c, fiber.StatusBadRequest, "Login code expired")
}
flowID, _ := h.RedisService.Get(prefixLoginCode + loginID)
if flowID == "" {
return "", nil, errorJSON(c, fiber.StatusNotFound, "Login flow expired")
}
authInfo, err = h.IdpProvider.VerifyLoginCode(loginID, flowID, code)
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
return "", nil, errorJSON(c, fiber.StatusNotImplemented, "Login method not supported")
}
slog.Error("[Poll] IDP code verify failed", "error", err)
return "", nil, errorJSON(c, fiber.StatusInternalServerError, "Failed to verify login code")
}
} else {
authInfo, err = h.IdpProvider.IssueSession(loginID)
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
return "", nil, errorJSON(c, fiber.StatusNotImplemented, "Login method not supported")
}
slog.Error("[Poll] IDP session issue failed", "error", err)
return "", nil, errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session")
}
}
if authInfo == nil || authInfo.SessionToken == nil || authInfo.SessionToken.JWT == "" {
return "", nil, errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session")
}
c.Locals("login_id", loginID)
setSessionIDLocal(c, authInfo.SessionToken)
sessionID := extractSessionIDFromToken(authInfo.SessionToken)
if sessionID == "" && authInfo.SessionToken != nil && authInfo.SessionToken.JWT != "" {
if resolved, err := h.getKratosSessionID(authInfo.SessionToken.JWT); err == nil && resolved != "" {
sessionID = resolved
authInfo.SessionToken.SessionID = resolved
setSessionIDLocal(c, authInfo.SessionToken)
}
}
sessionData := map[string]string{
"status": statusSuccess,
"jwt": authInfo.SessionToken.JWT,
}
if sessionID != "" {
sessionData["session_id"] = sessionID
}
sessionDataJSON, _ := json.Marshal(sessionData)
_ = h.RedisService.Set(prefixSession+pendingRef, string(sessionDataJSON), defaultExpiration)
h.writeLinkAuditLog(loginID, pendingRef, authInfo.SessionToken, c)
h.clearLoginMeta(pendingRef)
if loginStrategy == loginFlowCode {
_ = h.RedisService.Delete(prefixLoginCode + loginID)
_ = h.RedisService.Delete(prefixLoginCodePending + loginID)
_ = h.RedisService.Delete(prefixLoginCodeSmsTarget + loginID)
_ = h.RedisService.Delete(prefixLoginCodeSmsLookup + loginID)
_ = h.RedisService.Delete(prefixLoginCodeValue + pendingRef)
}
return loginID, authInfo, nil
}
func (h *AuthHandler) validateHeadlessPasswordLoginClient(loginReq *domain.HydraLoginRequest, clientID string) error {
if loginReq == nil {
return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request")
}
if strings.TrimSpace(loginReq.Client.ClientID) != strings.TrimSpace(clientID) {
return fiber.NewError(fiber.StatusForbidden, "The client application is not allowed to use this login request.")
}
if metadata := loginReq.Client.Metadata; metadata != nil {
if status, ok := metadata["status"].(string); ok && strings.ToLower(status) == "inactive" {
return fiber.NewError(fiber.StatusForbidden, "The client application is disabled.")
}
}
if !loginReq.Client.IsHeadlessLoginEnabled() {
return fiber.NewError(fiber.StatusForbidden, "The client application is not allowed to use headless password login.")
}
return nil
}
func (h *AuthHandler) HeadlessPasswordLogin(c *fiber.Ctx) error {
var req struct {
ClientID string `json:"client_id"`
ClientAssertion string `json:"client_assertion"`
LoginID string `json:"loginId"`
Password string `json:"password"`
LoginChallenge string `json:"login_challenge"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "Invalid request body")
}
clientID := strings.TrimSpace(req.ClientID)
loginID := strings.TrimSpace(req.LoginID)
loginChallenge := strings.TrimSpace(req.LoginChallenge)
if clientID == "" || loginID == "" || strings.TrimSpace(req.Password) == "" || loginChallenge == "" {
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_id, loginId, password and login_challenge are required")
}
if h.IdpProvider == nil || h.Hydra == nil {
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
}
loginReq, err := h.Hydra.GetLoginRequest(c.Context(), loginChallenge)
if err != nil {
slog.Error("failed to get hydra login request for headless password login", "error", err)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request")
}
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
return err
}
if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil {
return err
}
authInfo, authErr := h.authenticatePasswordLogin(c.Context(), loginID, req.Password)
if authErr != nil {
status, code, message := passwordLoginErrorSpec(authErr)
return errorJSONCode(c, status, code, message)
}
setSessionIDLocal(c, authInfo.SessionToken)
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), loginChallenge, authInfo.Subject)
if err != nil {
slog.Error("failed to accept hydra login request in headless password login", "error", err)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request")
}
logOidcRedirectSummary("headless_password_login", acceptResp.RedirectTo)
return c.JSON(fiber.Map{
"redirectTo": acceptResp.RedirectTo,
"status": "ok",
"provider": h.IdpProvider.Name(),
})
}
func (h *AuthHandler) startHeadlessPhoneLink(c *fiber.Ctx, loginID string) (fiber.Map, string, string, time.Duration, error) {
rawLoginID := strings.ReplaceAll(loginID, "-", "")
rawLoginID = strings.ReplaceAll(rawLoginID, " ", "")
if rawLoginID == "" || strings.Contains(rawLoginID, "@") {
return nil, "", "", 0, errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "phone-based loginId is required")
}
lookupLoginID := normalizePhoneForLoginID(rawLoginID)
if h.IdpProvider == nil {
return nil, "", "", 0, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
}
exists, err := h.IdpProvider.UserExists(lookupLoginID)
if err != nil {
slog.Warn("[HeadlessLink] IDP user lookup failed", "loginID", rawLoginID, "error", err)
return nil, "", "", 0, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
}
if !exists {
slog.Warn("[HeadlessLink] User not found", "loginID", rawLoginID)
return nil, "", "", 0, errorJSON(c, fiber.StatusNotFound, "User not registered")
}
userfrontURL := h.resolveUserfrontURL(c)
if init, err := h.IdpProvider.InitiateLinkLogin(lookupLoginID, userfrontURL); err == nil && init != nil && init.Mode != "" {
keyLoginID := lookupLoginID
if init.LoginID != "" {
keyLoginID = init.LoginID
}
if init.FlowID != "" {
_ = h.RedisService.Set(prefixLoginCode+keyLoginID, init.FlowID, loginCodeExpiration)
}
pendingRef := GenerateSecureToken(3)
sessionData, _ := json.Marshal(map[string]string{
"status": statusPending,
"loginId": keyLoginID,
})
_ = h.RedisService.Set(prefixSession+pendingRef, string(sessionData), loginCodeExpiration)
h.storeLoginMeta(pendingRef, rawLoginID, "sms", loginFlowLink, loginFlowCode, loginCodeExpiration)
_ = h.RedisService.Set(prefixLoginCodePending+keyLoginID, pendingRef, loginCodeExpiration)
if keyLoginID != lookupLoginID {
_ = h.RedisService.Set(prefixLoginCodeSmsTarget+keyLoginID, lookupLoginID, loginCodeExpiration)
_ = h.RedisService.Set(prefixLoginCodeSmsLookup+lookupLoginID, keyLoginID, loginCodeExpiration)
}
expiresIn := int(loginCodeExpiration.Seconds())
if !init.ExpiresAt.IsZero() {
if seconds := int(time.Until(init.ExpiresAt).Seconds()); seconds > 0 {
expiresIn = seconds
}
}
return fiber.Map{
"pendingRef": pendingRef,
"status": "pending",
"mode": init.Mode,
"provider": h.IdpProvider.Name(),
"expiresIn": expiresIn,
"interval": int(minPollInterval.Seconds()),
"resendAfter": int(linkResendCooldown.Seconds()),
}, pendingRef, keyLoginID, loginCodeExpiration, nil
} else if err != nil && !errors.Is(err, domain.ErrNotSupported) {
slog.Error("[HeadlessLink] Link login init failed", "provider", h.IdpProvider.Name(), "error", err)
return nil, "", "", 0, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
}
if h.SmsService == nil {
return nil, "", "", 0, errorJSON(c, fiber.StatusInternalServerError, "SMS service not configured")
}
token := GenerateSecureToken(3)
pendingRef := GenerateSecureToken(3)
sessionData, _ := json.Marshal(map[string]string{
"status": statusPending,
"loginId": lookupLoginID,
})
_ = h.RedisService.Set(prefixSession+pendingRef, string(sessionData), defaultExpiration)
_ = h.RedisService.Set(prefixToken+token, fmt.Sprintf(`{"pendingRef":"%s","loginId":"%s"}`, pendingRef, lookupLoginID), defaultExpiration)
h.storeLoginMeta(pendingRef, rawLoginID, "sms", loginFlowLink, loginFlowLink, defaultExpiration)
link := fmt.Sprintf("%s/verify/%s", strings.TrimRight(userfrontURL, "/"), token)
content := fmt.Sprintf("[Baron 로그인] 로그인 링크: %s", link)
if err := h.SmsService.SendSms(rawLoginID, content); err != nil {
slog.Error("[HeadlessLink] SMS send failed", "error", err)
return nil, "", "", 0, errorJSON(c, fiber.StatusInternalServerError, "Failed to send SMS")
}
return fiber.Map{
"pendingRef": pendingRef,
"status": "pending",
"provider": h.IdpProvider.Name(),
"expiresIn": int(defaultExpiration.Seconds()),
"interval": int(minPollInterval.Seconds()),
"resendAfter": int(linkResendCooldown.Seconds()),
}, pendingRef, lookupLoginID, defaultExpiration, nil
}
func (h *AuthHandler) HeadlessLinkInit(c *fiber.Ctx) error {
var req struct {
ClientID string `json:"client_id"`
ClientAssertion string `json:"client_assertion"`
LoginID string `json:"loginId"`
LoginChallenge string `json:"login_challenge"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "Invalid request body")
}
clientID := strings.TrimSpace(req.ClientID)
loginChallenge := strings.TrimSpace(req.LoginChallenge)
loginID := strings.TrimSpace(req.LoginID)
if clientID == "" || loginChallenge == "" || loginID == "" {
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_id, client_assertion, loginId and login_challenge are required")
}
if h.Hydra == nil {
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
}
loginReq, err := h.Hydra.GetLoginRequest(c.Context(), loginChallenge)
if err != nil {
slog.Error("failed to get hydra login request for headless link init", "error", err)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request")
}
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
return err
}
if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil {
return err
}
resp, pendingRef, resolvedLoginID, ttl, err := h.startHeadlessPhoneLink(c, loginID)
if err != nil {
return err
}
h.storeHeadlessLinkState(pendingRef, headlessLinkState{
ClientID: clientID,
LoginChallenge: loginChallenge,
LoginID: resolvedLoginID,
}, ttl)
return c.JSON(resp)
}
func (h *AuthHandler) HeadlessLinkPoll(c *fiber.Ctx) error {
var req struct {
ClientID string `json:"client_id"`
ClientAssertion string `json:"client_assertion"`
PendingRef string `json:"pendingRef"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "Invalid request body")
}
clientID := strings.TrimSpace(req.ClientID)
pendingRef := strings.TrimSpace(req.PendingRef)
if clientID == "" || pendingRef == "" {
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_id, client_assertion and pendingRef are required")
}
state, ok := h.loadHeadlessLinkState(pendingRef)
if !ok {
return c.JSON(fiber.Map{
"error": "expired_token",
"code": "expired_token",
})
}
if state.ClientID != clientID {
return fiber.NewError(fiber.StatusForbidden, "The client application is not allowed to use this pending login.")
}
if state.RedirectTo != "" {
return c.JSON(fiber.Map{
"redirectTo": state.RedirectTo,
"status": "ok",
})
}
if h.Hydra == nil {
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
}
loginReq, err := h.Hydra.GetLoginRequest(c.Context(), state.LoginChallenge)
if err != nil {
slog.Error("failed to get hydra login request for headless link poll", "error", err)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request")
}
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
return err
}
if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil {
return err
}
val, err := h.RedisService.Get(prefixSession + pendingRef)
if err != nil || val == "" {
return c.JSON(fiber.Map{
"error": "expired_token",
"code": "expired_token",
})
}
var session map[string]string
_ = json.Unmarshal([]byte(val), &session)
if session["status"] == statusPending {
return c.JSON(fiber.Map{
"error": "authorization_pending",
"code": "authorization_pending",
"interval": int(minPollInterval.Seconds()),
})
}
loginID := state.LoginID
if session["status"] == "approved" {
completedLoginID, _, err := h.completeApprovedLinkLogin(c, pendingRef)
if err != nil {
return err
}
loginID = completedLoginID
}
if loginID == "" {
return errorJSON(c, fiber.StatusInternalServerError, "Failed to resolve approved user identity")
}
subject, err := h.resolveKratosIdentityIDFromLoginID(c.Context(), loginID)
if err != nil || subject == "" {
slog.Error("failed to resolve kratos identity for headless link poll", "loginID", loginID, "error", err)
return errorJSON(c, fiber.StatusInternalServerError, "Failed to resolve user identity")
}
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), state.LoginChallenge, subject)
if err != nil {
slog.Error("failed to accept hydra login request in headless link poll", "error", err)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request")
}
state.RedirectTo = acceptResp.RedirectTo
h.storeHeadlessLinkState(pendingRef, state, defaultExpiration)
logOidcRedirectSummary("headless_link_poll", acceptResp.RedirectTo)
return c.JSON(fiber.Map{
"redirectTo": acceptResp.RedirectTo,
"status": "ok",
})
}
func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error { func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
startTime := time.Now() startTime := time.Now()
ale := logger.NewAuditLogEntry(c, "login") ale := logger.NewAuditLogEntry(c, "login")
@@ -1704,28 +2275,16 @@ func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured") return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
} }
authInfo, err := h.IdpProvider.SignIn(loginID, req.Password) authInfo, err := h.authenticatePasswordLogin(c.Context(), loginID, req.Password)
if err != nil { if err != nil {
if errors.Is(err, domain.ErrNotSupported) { status, code, message := passwordLoginErrorSpec(err)
return errorJSONCode(c, fiber.StatusNotImplemented, "not_supported", "Login method not supported") ale.Status = status
}
ale.Status = fiber.StatusUnauthorized
ale.LatencyMs = time.Since(startTime) ale.LatencyMs = time.Since(startTime)
ale.ProviderError = err.Error() ale.ProviderError = err.Error()
ale.Log(slog.LevelWarn, "IDP sign-in failed", slog.String("provider", h.IdpProvider.Name())) ale.Log(slog.LevelWarn, "IDP sign-in failed", slog.String("provider", h.IdpProvider.Name()))
if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "identity") { return errorJSONCode(c, status, code, message)
return errorJSONCode(c, fiber.StatusNotFound, "not_found", "User not registered")
}
return errorJSONCode(c, fiber.StatusUnauthorized, "password_or_email_mismatch", "Invalid credentials")
} }
subject, resolveErr := h.resolveKratosIdentityIDFromLoginID(c.Context(), loginID)
if resolveErr != nil || subject == "" {
slog.Error("Failed to resolve kratos identity after login", "loginID", loginID, "error", resolveErr)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to resolve user identity")
}
authInfo.Subject = subject
ale.Status = fiber.StatusOK ale.Status = fiber.StatusOK
ale.LatencyMs = time.Since(startTime) ale.LatencyMs = time.Since(startTime)
setSessionIDLocal(c, authInfo.SessionToken) setSessionIDLocal(c, authInfo.SessionToken)
@@ -1746,7 +2305,7 @@ func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
} }
} }
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), req.LoginChallenge, subject) acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), req.LoginChallenge, authInfo.Subject)
if err != nil { if err != nil {
slog.Error("failed to accept hydra login request", "error", err) slog.Error("failed to accept hydra login request", "error", err)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request") return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request")

View File

@@ -2,14 +2,17 @@ package handler
import ( import (
"baron-sso-backend/internal/domain" "baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes" "bytes"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
) )
// Mock services // Mock services
@@ -21,6 +24,14 @@ type mockSmsService struct{}
func (m *mockSmsService) SendSms(to, content string) error { return nil } func (m *mockSmsService) SendSms(to, content string) error { return nil }
func newHeadlessLinkTestApp(h *AuthHandler) *fiber.App {
app := fiber.New()
app.Post("/api/v1/auth/headless/link/init", h.HeadlessLinkInit)
app.Post("/api/v1/auth/headless/link/poll", h.HeadlessLinkPoll)
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
return app
}
func TestEnchantedLinkFlow_Email_Success(t *testing.T) { func TestEnchantedLinkFlow_Email_Success(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)} redis := &mockRedisRepo{data: make(map[string]string)}
// Force "Not Supported" for InitiateLinkLogin only to trigger custom Enchanted Link logic // Force "Not Supported" for InitiateLinkLogin only to trigger custom Enchanted Link logic
@@ -144,3 +155,162 @@ func TestPollEnchantedLink_ExpiredToken_ReturnsCode(t *testing.T) {
assert.Equal(t, "expired_token", got["error"]) assert.Equal(t, "expired_token", got["error"])
assert.Equal(t, "expired_token", got["code"]) assert.Equal(t, "expired_token", got["code"])
} }
func TestHeadlessLinkInit_TrustedClientSuccess(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
privateKey, jwks := mustHeadlessRSAJWK(t)
idp := &mockIdpProvider{
userExists: true,
initiateLinkErr: domain.ErrNotSupported,
}
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet {
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "trusted-rp",
TokenEndpointAuthMethod: "private_key_jwt",
JWKS: jwks,
Metadata: map[string]interface{}{
"status": "active",
"headless_login_enabled": true,
},
},
})
return
}
http.NotFound(w, r)
})
h := &AuthHandler{
RedisService: redis,
IdpProvider: idp,
SmsService: &mockSmsService{},
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessLinkTestApp(h)
t.Setenv("USERFRONT_URL", "http://userfront.test")
body, _ := json.Marshal(map[string]string{
"client_id": "trusted-rp",
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "trusted-rp", "http://example.com/api/v1/auth/headless/link/init"),
"loginId": "010-1234-5678",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var got map[string]interface{}
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.NotEmpty(t, got["pendingRef"])
_, hasUserCode := got["userCode"]
assert.False(t, hasUserCode)
}
func TestHeadlessLinkPoll_AfterApprovalReturnsRedirect(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
privateKey, jwks := mustHeadlessRSAJWK(t)
idp := &mockIdpProvider{
userExists: true,
initiateLinkErr: domain.ErrNotSupported,
}
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "trusted-rp",
TokenEndpointAuthMethod: "private_key_jwt",
JWKS: jwks,
Metadata: map[string]interface{}{
"status": "active",
"headless_login_enabled": true,
},
},
})
return
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
return
}
http.NotFound(w, r)
})
mockKratos := new(MockKratosAdminService)
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "+821012345678").Return("kratos-identity-id", nil)
h := &AuthHandler{
RedisService: redis,
IdpProvider: idp,
SmsService: &mockSmsService{},
KratosAdmin: mockKratos,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessLinkTestApp(h)
t.Setenv("USERFRONT_URL", "http://userfront.test")
initBody, _ := json.Marshal(map[string]string{
"client_id": "trusted-rp",
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "trusted-rp", "http://example.com/api/v1/auth/headless/link/init"),
"loginId": "010-1234-5678",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(initBody))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var initResp map[string]interface{}
_ = json.NewDecoder(resp.Body).Decode(&initResp)
pendingRef := initResp["pendingRef"].(string)
assert.NotEmpty(t, pendingRef)
var token string
for k := range redis.data {
if len(k) > 16 && k[:16] == "enchanted_token:" {
token = k[16:]
break
}
}
assert.NotEmpty(t, token)
verifyBody, _ := json.Marshal(map[string]interface{}{
"token": token,
"verifyOnly": true,
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody))
req.Header.Set("Content-Type", "application/json")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
pollBody, _ := json.Marshal(map[string]string{
"client_id": "trusted-rp",
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "trusted-rp", "http://example.com/api/v1/auth/headless/link/poll"),
"pendingRef": pendingRef,
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/poll", bytes.NewReader(pollBody))
req.Header.Set("Content-Type", "application/json")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var pollResp map[string]interface{}
_ = json.NewDecoder(resp.Body).Decode(&pollResp)
assert.Equal(t, "http://rp/cb", pollResp["redirectTo"])
assert.Equal(t, "ok", pollResp["status"])
}

View File

@@ -10,6 +10,8 @@ import (
"baron-sso-backend/internal/service" "baron-sso-backend/internal/service"
"bytes" "bytes"
"context" "context"
"crypto/rand"
"crypto/rsa"
"encoding/json" "encoding/json"
"errors" "errors"
"io" "io"
@@ -17,7 +19,10 @@ import (
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
"time"
"github.com/go-jose/go-jose/v4"
josejwt "github.com/go-jose/go-jose/v4/jwt"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
) )
@@ -121,6 +126,79 @@ func newAuthLoginTestApp(h *AuthHandler) *fiber.App {
return app return app
} }
func newHeadlessPasswordLoginTestApp(h *AuthHandler) *fiber.App {
app := fiber.New()
app.Post("/api/v1/auth/headless/password/login", h.HeadlessPasswordLogin)
return app
}
func mustHeadlessRSAJWK(t *testing.T) (*rsa.PrivateKey, map[string]any) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("failed to generate rsa key: %v", err)
}
keySet := jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
{
Key: &privateKey.PublicKey,
KeyID: "test-kid",
Use: "sig",
Algorithm: string(jose.RS256),
},
},
}
raw, err := json.Marshal(keySet)
if err != nil {
t.Fatalf("failed to marshal jwks: %v", err)
}
var jwks map[string]any
if err := json.Unmarshal(raw, &jwks); err != nil {
t.Fatalf("failed to decode jwks map: %v", err)
}
return privateKey, jwks
}
func mustHeadlessClientAssertion(t *testing.T, privateKey *rsa.PrivateKey, clientID, audience string) string {
t.Helper()
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.RS256,
Key: jose.JSONWebKey{
Key: privateKey,
KeyID: "test-kid",
Use: "sig",
Algorithm: string(jose.RS256),
},
}, nil)
if err != nil {
t.Fatalf("failed to create signer: %v", err)
}
now := time.Now()
raw, err := josejwt.Signed(signer).Claims(josejwt.Claims{
Issuer: clientID,
Subject: clientID,
Audience: josejwt.Audience{audience},
Expiry: josejwt.NewNumericDate(now.Add(5 * time.Minute)),
IssuedAt: josejwt.NewNumericDate(now),
NotBefore: josejwt.NewNumericDate(
now.Add(-1 * time.Minute),
),
ID: "assertion-1",
}).Serialize()
if err != nil {
t.Fatalf("failed to sign client assertion: %v", err)
}
return raw
}
// mockHydraTransport simulates Hydra API responses // mockHydraTransport simulates Hydra API responses
func mockHydraTransport(handler http.Handler) http.RoundTripper { func mockHydraTransport(handler http.Handler) http.RoundTripper {
return roundTripFunc(func(req *http.Request) (*http.Response, error) { return roundTripFunc(func(req *http.Request) (*http.Response, error) {
@@ -206,6 +284,342 @@ func TestPasswordLogin_OIDC_Success(t *testing.T) {
} }
} }
func TestHeadlessPasswordLogin_TrustedClientSuccess(t *testing.T) {
mockIdp := new(MockIdentityProvider)
mockIdp.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{
SessionToken: &domain.Token{JWT: "valid-jwt"},
Subject: "kratos-identity-id",
}, nil)
privateKey, jwks := mustHeadlessRSAJWK(t)
jwksBody, _ := json.Marshal(jwks)
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(jwksBody)
}))
defer jwksServer.Close()
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "trusted-rp",
TokenEndpointAuthMethod: "private_key_jwt",
JWKSUri: jwksServer.URL + "/.well-known/jwks.json",
Metadata: map[string]interface{}{
"status": "active",
"headless_login_enabled": true,
},
},
})
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
default:
http.NotFound(w, r)
}
})
mockKratos := new(MockKratosAdminService)
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil)
h := &AuthHandler{
IdpProvider: mockIdp,
KratosAdmin: mockKratos,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessPasswordLoginTestApp(h)
clientAssertion := mustHeadlessClientAssertion(
t,
privateKey,
"trusted-rp",
"http://example.com/api/v1/auth/headless/password/login",
)
body, _ := json.Marshal(map[string]string{
"client_id": "trusted-rp",
"client_assertion": clientAssertion,
"loginId": "employee001",
"password": "password",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 200, got %d, body: %s", resp.StatusCode, string(bodyBytes))
}
var got map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if got["redirectTo"] != "http://rp/cb" {
t.Fatalf("expected redirectTo http://rp/cb, got %v", got["redirectTo"])
}
if _, ok := got["sessionJwt"]; ok {
t.Fatalf("expected headless response to omit sessionJwt, got %v", got["sessionJwt"])
}
}
func TestHeadlessPasswordLogin_MissingClientAssertionRejected(t *testing.T) {
mockIdp := new(MockIdentityProvider)
mockIdp.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{
SessionToken: &domain.Token{JWT: "valid-jwt"},
Subject: "kratos-identity-id",
}, nil)
mockKratos := new(MockKratosAdminService)
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil)
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "trusted-rp",
TokenEndpointAuthMethod: "private_key_jwt",
JWKS: map[string]any{
"keys": []map[string]any{},
},
Metadata: map[string]interface{}{
"status": "active",
"headless_login_enabled": true,
},
},
})
return
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
return
}
http.NotFound(w, r)
})
h := &AuthHandler{
IdpProvider: mockIdp,
KratosAdmin: mockKratos,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessPasswordLoginTestApp(h)
body, _ := json.Marshal(map[string]string{
"client_id": "trusted-rp",
"loginId": "employee001",
"password": "password",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 400, got %d, body: %s", resp.StatusCode, string(bodyBytes))
}
}
func TestHeadlessPasswordLogin_InvalidClientAssertionRejected(t *testing.T) {
mockIdp := new(MockIdentityProvider)
mockIdp.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{
SessionToken: &domain.Token{JWT: "valid-jwt"},
Subject: "kratos-identity-id",
}, nil)
mockKratos := new(MockKratosAdminService)
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil)
validKey, jwks := mustHeadlessRSAJWK(t)
invalidKey, _ := mustHeadlessRSAJWK(t)
_ = validKey
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "trusted-rp",
TokenEndpointAuthMethod: "private_key_jwt",
JWKS: jwks,
Metadata: map[string]interface{}{
"status": "active",
"headless_login_enabled": true,
},
},
})
return
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
return
}
http.NotFound(w, r)
})
h := &AuthHandler{
IdpProvider: mockIdp,
KratosAdmin: mockKratos,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessPasswordLoginTestApp(h)
clientAssertion := mustHeadlessClientAssertion(
t,
invalidKey,
"trusted-rp",
"http://example.com/api/v1/auth/headless/password/login",
)
body, _ := json.Marshal(map[string]string{
"client_id": "trusted-rp",
"client_assertion": clientAssertion,
"loginId": "employee001",
"password": "password",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 401, got %d, body: %s", resp.StatusCode, string(bodyBytes))
}
}
func TestHeadlessPasswordLogin_HeadlessDisabledRejected(t *testing.T) {
mockIdp := new(MockIdentityProvider)
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet {
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "trusted-rp",
TokenEndpointAuthMethod: "private_key_jwt",
JWKSUri: "https://rp.example.com/.well-known/jwks.json",
Metadata: map[string]interface{}{
"status": "active",
},
},
})
return
}
http.NotFound(w, r)
})
h := &AuthHandler{
IdpProvider: mockIdp,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessPasswordLoginTestApp(h)
body, _ := json.Marshal(map[string]string{
"client_id": "trusted-rp",
"loginId": "employee001",
"password": "password",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusForbidden {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 403, got %d, body: %s", resp.StatusCode, string(bodyBytes))
}
}
func TestHeadlessPasswordLogin_ClientIDMismatchRejected(t *testing.T) {
mockIdp := new(MockIdentityProvider)
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet {
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "other-rp",
TokenEndpointAuthMethod: "private_key_jwt",
JWKSUri: "https://rp.example.com/.well-known/jwks.json",
Metadata: map[string]interface{}{
"status": "active",
"headless_login_enabled": true,
},
},
})
return
}
http.NotFound(w, r)
})
h := &AuthHandler{
IdpProvider: mockIdp,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessPasswordLoginTestApp(h)
body, _ := json.Marshal(map[string]string{
"client_id": "trusted-rp",
"loginId": "employee001",
"password": "password",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusForbidden {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 403, got %d, body: %s", resp.StatusCode, string(bodyBytes))
}
}
func TestPasswordLogin_OIDC_InactiveClient(t *testing.T) { func TestPasswordLogin_OIDC_InactiveClient(t *testing.T) {
mockIdp := new(MockIdentityProvider) mockIdp := new(MockIdentityProvider)
mockIdp.On("SignIn", "user@example.com", "password").Return(&domain.AuthInfo{ mockIdp.On("SignIn", "user@example.com", "password").Return(&domain.AuthInfo{