1
0
forked from baron/baron-sso

feat(auth): add trusted rp headless login flows

This commit is contained in:
Lectom C Han
2026-03-30 21:46:15 +09:00
parent 26890dfabb
commit b4342b355f
4 changed files with 1244 additions and 98 deletions

View File

@@ -526,6 +526,9 @@ func main() {
auth.Post("/login/code/verify", authHandler.VerifyLoginCode)
auth.Post("/login/code/verify-short", authHandler.VerifyLoginShortCode)
auth.Post("/password/login", authHandler.PasswordLogin)
auth.Post("/headless/password/login", authHandler.HeadlessPasswordLogin)
auth.Post("/headless/link/init", authHandler.HeadlessLinkInit)
auth.Post("/headless/link/poll", authHandler.HeadlessLinkPoll)
auth.Get("/tenant-info", authHandler.GetTenantInfo)
auth.Get("/consent", authHandler.GetConsentRequest)
auth.Post("/consent/accept", authHandler.AcceptConsentRequest)

View File

@@ -25,6 +25,8 @@ import (
"strings"
"time"
"github.com/go-jose/go-jose/v4"
josejwt "github.com/go-jose/go-jose/v4/jwt"
"github.com/gofiber/fiber/v2"
)
@@ -46,6 +48,7 @@ const (
prefixLoginCodeSmsOnly = "login_code_sms_only:"
prefixLoginCodeQrPending = "login_code_qr_pending:"
prefixLoginCodeQr = "login_code_qr:"
prefixHeadlessLinkState = "headless_link_state:"
prefixPollMeta = "poll_meta:"
prefixQrRef = "qr_ref:"
prefixQrMeta = "qr_meta:"
@@ -75,6 +78,7 @@ const (
loginCodeExpiration = 10 * time.Minute
linkResendCooldown = 60 * time.Second
prefixDrySend = "dry_send:"
headlessJWKSFetchTTL = 5 * time.Second
)
type AuthHandler struct {
@@ -100,6 +104,40 @@ type signupState struct {
ExpiresAt int64 `json:"expires_at"` // Unix timestamp
}
type headlessLinkState struct {
ClientID string `json:"clientId"`
LoginChallenge string `json:"loginChallenge"`
LoginID string `json:"loginId"`
RedirectTo string `json:"redirectTo,omitempty"`
}
type headlessClientAssertionClaims struct {
Issuer string `json:"iss"`
Subject string `json:"sub"`
Audience headlessAssertionAud `json:"aud"`
ExpiresAt int64 `json:"exp"`
IssuedAt int64 `json:"iat,omitempty"`
NotBefore int64 `json:"nbf,omitempty"`
ID string `json:"jti,omitempty"`
}
type headlessAssertionAud []string
func (a *headlessAssertionAud) UnmarshalJSON(data []byte) error {
var single string
if err := json.Unmarshal(data, &single); err == nil {
*a = []string{single}
return nil
}
var list []string
if err := json.Unmarshal(data, &list); err != nil {
return err
}
*a = list
return nil
}
// GenerateSecureToken - Helper to generate secure random strings
func GenerateSecureToken(length int) string {
b := make([]byte, length)
@@ -1224,87 +1262,9 @@ func (h *AuthHandler) PollEnchantedLink(c *fiber.Ctx) error {
}
if data["status"] == "approved" {
loginID := data["loginId"]
if loginID == "" {
loginID = data["login_id"]
}
if loginID == "" {
slog.Warn("[Poll] Approved but missing loginId", "pendingRef", req.PendingRef)
return errorJSON(c, fiber.StatusBadRequest, "Invalid session reference")
}
if h.IdpProvider == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
}
loginStrategy := h.loadLoginStrategy(req.PendingRef)
if loginStrategy == "" {
loginStrategy = loginFlowLink
}
var authInfo *domain.AuthInfo
var err error
if loginStrategy == loginFlowCode {
code, _ := h.RedisService.Get(prefixLoginCodeValue + req.PendingRef)
code = normalizeLoginCode(code)
if code == "" {
slog.Warn("[Poll] Missing login code for approved flow", "pendingRef", req.PendingRef)
return errorJSON(c, fiber.StatusBadRequest, "Login code expired")
}
flowID, _ := h.RedisService.Get(prefixLoginCode + loginID)
if flowID == "" {
return errorJSON(c, fiber.StatusNotFound, "Login flow expired")
}
authInfo, err = h.IdpProvider.VerifyLoginCode(loginID, flowID, code)
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
return errorJSON(c, fiber.StatusNotImplemented, "Login method not supported")
}
slog.Error("[Poll] IDP code verify failed", "error", err)
return errorJSON(c, fiber.StatusInternalServerError, "Failed to verify login code")
}
} else {
authInfo, err = h.IdpProvider.IssueSession(loginID)
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
return errorJSON(c, fiber.StatusNotImplemented, "Login method not supported")
}
slog.Error("[Poll] IDP session issue failed", "error", err)
return errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session")
}
}
if authInfo == nil || authInfo.SessionToken == nil || authInfo.SessionToken.JWT == "" {
return errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session")
}
c.Locals("login_id", loginID)
setSessionIDLocal(c, authInfo.SessionToken)
sessionID := extractSessionIDFromToken(authInfo.SessionToken)
if sessionID == "" && authInfo.SessionToken != nil && authInfo.SessionToken.JWT != "" {
if resolved, err := h.getKratosSessionID(authInfo.SessionToken.JWT); err == nil && resolved != "" {
sessionID = resolved
authInfo.SessionToken.SessionID = resolved
setSessionIDLocal(c, authInfo.SessionToken)
}
}
sessionData := map[string]string{
"status": statusSuccess,
"jwt": authInfo.SessionToken.JWT,
}
if sessionID != "" {
sessionData["session_id"] = sessionID
}
sessionDataJSON, _ := json.Marshal(sessionData)
h.RedisService.Set(prefixSession+req.PendingRef, string(sessionDataJSON), defaultExpiration)
h.writeLinkAuditLog(loginID, req.PendingRef, authInfo.SessionToken, c)
h.clearLoginMeta(req.PendingRef)
if loginStrategy == loginFlowCode {
h.RedisService.Delete(prefixLoginCode + loginID)
h.RedisService.Delete(prefixLoginCodePending + loginID)
h.RedisService.Delete(prefixLoginCodeSmsTarget + loginID)
h.RedisService.Delete(prefixLoginCodeSmsLookup + loginID)
h.RedisService.Delete(prefixLoginCodeValue + req.PendingRef)
_, authInfo, err := h.completeApprovedLinkLogin(c, req.PendingRef)
if err != nil {
return err
}
return c.JSON(fiber.Map{
@@ -1671,6 +1631,617 @@ func logOidcRedirectSummary(source, redirectTo string) {
)
}
func (h *AuthHandler) authenticatePasswordLogin(ctx context.Context, loginID, password string) (*domain.AuthInfo, error) {
if h.IdpProvider == nil {
return nil, fmt.Errorf("authentication service not configured")
}
authInfo, err := h.IdpProvider.SignIn(loginID, password)
if err != nil {
return nil, err
}
subject, resolveErr := h.resolveKratosIdentityIDFromLoginID(ctx, loginID)
if resolveErr != nil || subject == "" {
slog.Error("Failed to resolve kratos identity after login", "loginID", loginID, "error", resolveErr)
return nil, fmt.Errorf("failed to resolve user identity")
}
authInfo.Subject = subject
return authInfo, nil
}
func passwordLoginErrorSpec(err error) (int, string, string) {
if err == nil {
return fiber.StatusOK, "", ""
}
if errors.Is(err, domain.ErrNotSupported) {
return fiber.StatusNotImplemented, "not_supported", "Login method not supported"
}
if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "identity") {
return fiber.StatusNotFound, "not_found", "User not registered"
}
if strings.Contains(err.Error(), "failed to resolve user identity") {
return fiber.StatusInternalServerError, "internal_error", "Failed to resolve user identity"
}
return fiber.StatusUnauthorized, "password_or_email_mismatch", "Invalid credentials"
}
func headlessAssertionAudiences(c *fiber.Ctx) []string {
if c == nil {
return nil
}
path := strings.TrimSpace(c.Path())
if path == "" {
return nil
}
base := strings.TrimRight(strings.TrimSpace(c.BaseURL()), "/")
if base == "" {
return []string{path}
}
return []string{base + path, path}
}
func containsHeadlessAudience(expected []string, actual headlessAssertionAud) bool {
for _, audience := range actual {
for _, candidate := range expected {
if strings.TrimSpace(audience) == strings.TrimSpace(candidate) {
return true
}
}
}
return false
}
func (h *AuthHandler) loadHeadlessJWKS(ctx context.Context, client domain.HydraClient) (*jose.JSONWebKeySet, error) {
var raw []byte
switch {
case client.JWKS != nil:
data, err := json.Marshal(client.JWKS)
if err != nil {
return nil, fmt.Errorf("failed to encode jwks: %w", err)
}
raw = data
case strings.TrimSpace(client.JWKSUri) != "":
req, err := http.NewRequestWithContext(ctx, http.MethodGet, strings.TrimSpace(client.JWKSUri), nil)
if err != nil {
return nil, fmt.Errorf("failed to build jwks request: %w", err)
}
client := &http.Client{Timeout: headlessJWKSFetchTTL}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to fetch jwks: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return nil, fmt.Errorf("failed to fetch jwks status=%d body=%s", resp.StatusCode, string(body))
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
if err != nil {
return nil, fmt.Errorf("failed to read jwks response: %w", err)
}
raw = body
default:
return nil, fmt.Errorf("trusted rp public key is not configured")
}
var keySet jose.JSONWebKeySet
if err := json.Unmarshal(raw, &keySet); err != nil {
return nil, fmt.Errorf("failed to decode jwks: %w", err)
}
if len(keySet.Keys) == 0 {
return nil, fmt.Errorf("trusted rp jwks has no keys")
}
return &keySet, nil
}
func validateHeadlessClientAssertionClaims(c *fiber.Ctx, claims headlessClientAssertionClaims, clientID string) error {
now := time.Now().Unix()
if claims.Issuer != clientID || claims.Subject != clientID {
return fmt.Errorf("client assertion iss/sub mismatch")
}
if claims.ExpiresAt == 0 || claims.ExpiresAt <= now {
return fmt.Errorf("client assertion expired")
}
if claims.NotBefore != 0 && claims.NotBefore > now {
return fmt.Errorf("client assertion not active yet")
}
if claims.IssuedAt != 0 && claims.IssuedAt > now+60 {
return fmt.Errorf("client assertion issued in the future")
}
if !containsHeadlessAudience(headlessAssertionAudiences(c), claims.Audience) {
return fmt.Errorf("client assertion audience mismatch")
}
return nil
}
func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.HydraClient, clientID, clientAssertion string) error {
assertion := strings.TrimSpace(clientAssertion)
if assertion == "" {
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_assertion is required")
}
keySet, err := h.loadHeadlessJWKS(c.Context(), client)
if err != nil {
slog.Error("failed to load jwks for headless client assertion", "clientID", clientID, "error", err)
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
}
token, err := josejwt.ParseSigned(assertion, []jose.SignatureAlgorithm{
jose.RS256, jose.RS384, jose.RS512,
jose.PS256, jose.PS384, jose.PS512,
jose.ES256, jose.ES384, jose.ES512,
jose.EdDSA,
})
if err != nil {
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
}
expectedKid := ""
if len(token.Headers) > 0 {
expectedKid = strings.TrimSpace(token.Headers[0].KeyID)
}
for _, key := range keySet.Keys {
if expectedKid != "" && key.KeyID != "" && key.KeyID != expectedKid {
continue
}
var claims headlessClientAssertionClaims
if err := token.Claims(key.Key, &claims); err != nil {
continue
}
if err := validateHeadlessClientAssertionClaims(c, claims, clientID); err != nil {
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
}
return nil
}
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
}
func (h *AuthHandler) storeHeadlessLinkState(pendingRef string, state headlessLinkState, ttl time.Duration) {
if h.RedisService == nil || pendingRef == "" {
return
}
raw, err := json.Marshal(state)
if err != nil {
return
}
_ = h.RedisService.Set(prefixHeadlessLinkState+pendingRef, string(raw), ttl)
}
func (h *AuthHandler) loadHeadlessLinkState(pendingRef string) (headlessLinkState, bool) {
if h.RedisService == nil || pendingRef == "" {
return headlessLinkState{}, false
}
raw, err := h.RedisService.Get(prefixHeadlessLinkState + pendingRef)
if err != nil || raw == "" {
return headlessLinkState{}, false
}
var state headlessLinkState
if err := json.Unmarshal([]byte(raw), &state); err != nil {
return headlessLinkState{}, false
}
return state, true
}
func (h *AuthHandler) completeApprovedLinkLogin(c *fiber.Ctx, pendingRef string) (string, *domain.AuthInfo, error) {
val, err := h.RedisService.Get(prefixSession + pendingRef)
if err != nil || val == "" {
return "", nil, errorJSON(c, fiber.StatusBadRequest, "Invalid session reference")
}
var data map[string]string
_ = json.Unmarshal([]byte(val), &data)
loginID := data["loginId"]
if loginID == "" {
loginID = data["login_id"]
}
if loginID == "" {
slog.Warn("[Poll] Approved but missing loginId", "pendingRef", pendingRef)
return "", nil, errorJSON(c, fiber.StatusBadRequest, "Invalid session reference")
}
if h.IdpProvider == nil {
return "", nil, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
}
loginStrategy := h.loadLoginStrategy(pendingRef)
if loginStrategy == "" {
loginStrategy = loginFlowLink
}
var authInfo *domain.AuthInfo
if loginStrategy == loginFlowCode {
code, _ := h.RedisService.Get(prefixLoginCodeValue + pendingRef)
code = normalizeLoginCode(code)
if code == "" {
slog.Warn("[Poll] Missing login code for approved flow", "pendingRef", pendingRef)
return "", nil, errorJSON(c, fiber.StatusBadRequest, "Login code expired")
}
flowID, _ := h.RedisService.Get(prefixLoginCode + loginID)
if flowID == "" {
return "", nil, errorJSON(c, fiber.StatusNotFound, "Login flow expired")
}
authInfo, err = h.IdpProvider.VerifyLoginCode(loginID, flowID, code)
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
return "", nil, errorJSON(c, fiber.StatusNotImplemented, "Login method not supported")
}
slog.Error("[Poll] IDP code verify failed", "error", err)
return "", nil, errorJSON(c, fiber.StatusInternalServerError, "Failed to verify login code")
}
} else {
authInfo, err = h.IdpProvider.IssueSession(loginID)
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
return "", nil, errorJSON(c, fiber.StatusNotImplemented, "Login method not supported")
}
slog.Error("[Poll] IDP session issue failed", "error", err)
return "", nil, errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session")
}
}
if authInfo == nil || authInfo.SessionToken == nil || authInfo.SessionToken.JWT == "" {
return "", nil, errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session")
}
c.Locals("login_id", loginID)
setSessionIDLocal(c, authInfo.SessionToken)
sessionID := extractSessionIDFromToken(authInfo.SessionToken)
if sessionID == "" && authInfo.SessionToken != nil && authInfo.SessionToken.JWT != "" {
if resolved, err := h.getKratosSessionID(authInfo.SessionToken.JWT); err == nil && resolved != "" {
sessionID = resolved
authInfo.SessionToken.SessionID = resolved
setSessionIDLocal(c, authInfo.SessionToken)
}
}
sessionData := map[string]string{
"status": statusSuccess,
"jwt": authInfo.SessionToken.JWT,
}
if sessionID != "" {
sessionData["session_id"] = sessionID
}
sessionDataJSON, _ := json.Marshal(sessionData)
_ = h.RedisService.Set(prefixSession+pendingRef, string(sessionDataJSON), defaultExpiration)
h.writeLinkAuditLog(loginID, pendingRef, authInfo.SessionToken, c)
h.clearLoginMeta(pendingRef)
if loginStrategy == loginFlowCode {
_ = h.RedisService.Delete(prefixLoginCode + loginID)
_ = h.RedisService.Delete(prefixLoginCodePending + loginID)
_ = h.RedisService.Delete(prefixLoginCodeSmsTarget + loginID)
_ = h.RedisService.Delete(prefixLoginCodeSmsLookup + loginID)
_ = h.RedisService.Delete(prefixLoginCodeValue + pendingRef)
}
return loginID, authInfo, nil
}
func (h *AuthHandler) validateHeadlessPasswordLoginClient(loginReq *domain.HydraLoginRequest, clientID string) error {
if loginReq == nil {
return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request")
}
if strings.TrimSpace(loginReq.Client.ClientID) != strings.TrimSpace(clientID) {
return fiber.NewError(fiber.StatusForbidden, "The client application is not allowed to use this login request.")
}
if metadata := loginReq.Client.Metadata; metadata != nil {
if status, ok := metadata["status"].(string); ok && strings.ToLower(status) == "inactive" {
return fiber.NewError(fiber.StatusForbidden, "The client application is disabled.")
}
}
if !loginReq.Client.IsHeadlessLoginEnabled() {
return fiber.NewError(fiber.StatusForbidden, "The client application is not allowed to use headless password login.")
}
return nil
}
func (h *AuthHandler) HeadlessPasswordLogin(c *fiber.Ctx) error {
var req struct {
ClientID string `json:"client_id"`
ClientAssertion string `json:"client_assertion"`
LoginID string `json:"loginId"`
Password string `json:"password"`
LoginChallenge string `json:"login_challenge"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "Invalid request body")
}
clientID := strings.TrimSpace(req.ClientID)
loginID := strings.TrimSpace(req.LoginID)
loginChallenge := strings.TrimSpace(req.LoginChallenge)
if clientID == "" || loginID == "" || strings.TrimSpace(req.Password) == "" || loginChallenge == "" {
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_id, loginId, password and login_challenge are required")
}
if h.IdpProvider == nil || h.Hydra == nil {
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
}
loginReq, err := h.Hydra.GetLoginRequest(c.Context(), loginChallenge)
if err != nil {
slog.Error("failed to get hydra login request for headless password login", "error", err)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request")
}
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
return err
}
if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil {
return err
}
authInfo, authErr := h.authenticatePasswordLogin(c.Context(), loginID, req.Password)
if authErr != nil {
status, code, message := passwordLoginErrorSpec(authErr)
return errorJSONCode(c, status, code, message)
}
setSessionIDLocal(c, authInfo.SessionToken)
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), loginChallenge, authInfo.Subject)
if err != nil {
slog.Error("failed to accept hydra login request in headless password login", "error", err)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request")
}
logOidcRedirectSummary("headless_password_login", acceptResp.RedirectTo)
return c.JSON(fiber.Map{
"redirectTo": acceptResp.RedirectTo,
"status": "ok",
"provider": h.IdpProvider.Name(),
})
}
func (h *AuthHandler) startHeadlessPhoneLink(c *fiber.Ctx, loginID string) (fiber.Map, string, string, time.Duration, error) {
rawLoginID := strings.ReplaceAll(loginID, "-", "")
rawLoginID = strings.ReplaceAll(rawLoginID, " ", "")
if rawLoginID == "" || strings.Contains(rawLoginID, "@") {
return nil, "", "", 0, errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "phone-based loginId is required")
}
lookupLoginID := normalizePhoneForLoginID(rawLoginID)
if h.IdpProvider == nil {
return nil, "", "", 0, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
}
exists, err := h.IdpProvider.UserExists(lookupLoginID)
if err != nil {
slog.Warn("[HeadlessLink] IDP user lookup failed", "loginID", rawLoginID, "error", err)
return nil, "", "", 0, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
}
if !exists {
slog.Warn("[HeadlessLink] User not found", "loginID", rawLoginID)
return nil, "", "", 0, errorJSON(c, fiber.StatusNotFound, "User not registered")
}
userfrontURL := h.resolveUserfrontURL(c)
if init, err := h.IdpProvider.InitiateLinkLogin(lookupLoginID, userfrontURL); err == nil && init != nil && init.Mode != "" {
keyLoginID := lookupLoginID
if init.LoginID != "" {
keyLoginID = init.LoginID
}
if init.FlowID != "" {
_ = h.RedisService.Set(prefixLoginCode+keyLoginID, init.FlowID, loginCodeExpiration)
}
pendingRef := GenerateSecureToken(3)
sessionData, _ := json.Marshal(map[string]string{
"status": statusPending,
"loginId": keyLoginID,
})
_ = h.RedisService.Set(prefixSession+pendingRef, string(sessionData), loginCodeExpiration)
h.storeLoginMeta(pendingRef, rawLoginID, "sms", loginFlowLink, loginFlowCode, loginCodeExpiration)
_ = h.RedisService.Set(prefixLoginCodePending+keyLoginID, pendingRef, loginCodeExpiration)
if keyLoginID != lookupLoginID {
_ = h.RedisService.Set(prefixLoginCodeSmsTarget+keyLoginID, lookupLoginID, loginCodeExpiration)
_ = h.RedisService.Set(prefixLoginCodeSmsLookup+lookupLoginID, keyLoginID, loginCodeExpiration)
}
expiresIn := int(loginCodeExpiration.Seconds())
if !init.ExpiresAt.IsZero() {
if seconds := int(time.Until(init.ExpiresAt).Seconds()); seconds > 0 {
expiresIn = seconds
}
}
return fiber.Map{
"pendingRef": pendingRef,
"status": "pending",
"mode": init.Mode,
"provider": h.IdpProvider.Name(),
"expiresIn": expiresIn,
"interval": int(minPollInterval.Seconds()),
"resendAfter": int(linkResendCooldown.Seconds()),
}, pendingRef, keyLoginID, loginCodeExpiration, nil
} else if err != nil && !errors.Is(err, domain.ErrNotSupported) {
slog.Error("[HeadlessLink] Link login init failed", "provider", h.IdpProvider.Name(), "error", err)
return nil, "", "", 0, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
}
if h.SmsService == nil {
return nil, "", "", 0, errorJSON(c, fiber.StatusInternalServerError, "SMS service not configured")
}
token := GenerateSecureToken(3)
pendingRef := GenerateSecureToken(3)
sessionData, _ := json.Marshal(map[string]string{
"status": statusPending,
"loginId": lookupLoginID,
})
_ = h.RedisService.Set(prefixSession+pendingRef, string(sessionData), defaultExpiration)
_ = h.RedisService.Set(prefixToken+token, fmt.Sprintf(`{"pendingRef":"%s","loginId":"%s"}`, pendingRef, lookupLoginID), defaultExpiration)
h.storeLoginMeta(pendingRef, rawLoginID, "sms", loginFlowLink, loginFlowLink, defaultExpiration)
link := fmt.Sprintf("%s/verify/%s", strings.TrimRight(userfrontURL, "/"), token)
content := fmt.Sprintf("[Baron 로그인] 로그인 링크: %s", link)
if err := h.SmsService.SendSms(rawLoginID, content); err != nil {
slog.Error("[HeadlessLink] SMS send failed", "error", err)
return nil, "", "", 0, errorJSON(c, fiber.StatusInternalServerError, "Failed to send SMS")
}
return fiber.Map{
"pendingRef": pendingRef,
"status": "pending",
"provider": h.IdpProvider.Name(),
"expiresIn": int(defaultExpiration.Seconds()),
"interval": int(minPollInterval.Seconds()),
"resendAfter": int(linkResendCooldown.Seconds()),
}, pendingRef, lookupLoginID, defaultExpiration, nil
}
func (h *AuthHandler) HeadlessLinkInit(c *fiber.Ctx) error {
var req struct {
ClientID string `json:"client_id"`
ClientAssertion string `json:"client_assertion"`
LoginID string `json:"loginId"`
LoginChallenge string `json:"login_challenge"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "Invalid request body")
}
clientID := strings.TrimSpace(req.ClientID)
loginChallenge := strings.TrimSpace(req.LoginChallenge)
loginID := strings.TrimSpace(req.LoginID)
if clientID == "" || loginChallenge == "" || loginID == "" {
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_id, client_assertion, loginId and login_challenge are required")
}
if h.Hydra == nil {
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
}
loginReq, err := h.Hydra.GetLoginRequest(c.Context(), loginChallenge)
if err != nil {
slog.Error("failed to get hydra login request for headless link init", "error", err)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request")
}
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
return err
}
if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil {
return err
}
resp, pendingRef, resolvedLoginID, ttl, err := h.startHeadlessPhoneLink(c, loginID)
if err != nil {
return err
}
h.storeHeadlessLinkState(pendingRef, headlessLinkState{
ClientID: clientID,
LoginChallenge: loginChallenge,
LoginID: resolvedLoginID,
}, ttl)
return c.JSON(resp)
}
func (h *AuthHandler) HeadlessLinkPoll(c *fiber.Ctx) error {
var req struct {
ClientID string `json:"client_id"`
ClientAssertion string `json:"client_assertion"`
PendingRef string `json:"pendingRef"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "Invalid request body")
}
clientID := strings.TrimSpace(req.ClientID)
pendingRef := strings.TrimSpace(req.PendingRef)
if clientID == "" || pendingRef == "" {
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_id, client_assertion and pendingRef are required")
}
state, ok := h.loadHeadlessLinkState(pendingRef)
if !ok {
return c.JSON(fiber.Map{
"error": "expired_token",
"code": "expired_token",
})
}
if state.ClientID != clientID {
return fiber.NewError(fiber.StatusForbidden, "The client application is not allowed to use this pending login.")
}
if state.RedirectTo != "" {
return c.JSON(fiber.Map{
"redirectTo": state.RedirectTo,
"status": "ok",
})
}
if h.Hydra == nil {
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
}
loginReq, err := h.Hydra.GetLoginRequest(c.Context(), state.LoginChallenge)
if err != nil {
slog.Error("failed to get hydra login request for headless link poll", "error", err)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request")
}
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
return err
}
if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil {
return err
}
val, err := h.RedisService.Get(prefixSession + pendingRef)
if err != nil || val == "" {
return c.JSON(fiber.Map{
"error": "expired_token",
"code": "expired_token",
})
}
var session map[string]string
_ = json.Unmarshal([]byte(val), &session)
if session["status"] == statusPending {
return c.JSON(fiber.Map{
"error": "authorization_pending",
"code": "authorization_pending",
"interval": int(minPollInterval.Seconds()),
})
}
loginID := state.LoginID
if session["status"] == "approved" {
completedLoginID, _, err := h.completeApprovedLinkLogin(c, pendingRef)
if err != nil {
return err
}
loginID = completedLoginID
}
if loginID == "" {
return errorJSON(c, fiber.StatusInternalServerError, "Failed to resolve approved user identity")
}
subject, err := h.resolveKratosIdentityIDFromLoginID(c.Context(), loginID)
if err != nil || subject == "" {
slog.Error("failed to resolve kratos identity for headless link poll", "loginID", loginID, "error", err)
return errorJSON(c, fiber.StatusInternalServerError, "Failed to resolve user identity")
}
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), state.LoginChallenge, subject)
if err != nil {
slog.Error("failed to accept hydra login request in headless link poll", "error", err)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request")
}
state.RedirectTo = acceptResp.RedirectTo
h.storeHeadlessLinkState(pendingRef, state, defaultExpiration)
logOidcRedirectSummary("headless_link_poll", acceptResp.RedirectTo)
return c.JSON(fiber.Map{
"redirectTo": acceptResp.RedirectTo,
"status": "ok",
})
}
func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
startTime := time.Now()
ale := logger.NewAuditLogEntry(c, "login")
@@ -1704,28 +2275,16 @@ func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
}
authInfo, err := h.IdpProvider.SignIn(loginID, req.Password)
authInfo, err := h.authenticatePasswordLogin(c.Context(), loginID, req.Password)
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
return errorJSONCode(c, fiber.StatusNotImplemented, "not_supported", "Login method not supported")
}
ale.Status = fiber.StatusUnauthorized
status, code, message := passwordLoginErrorSpec(err)
ale.Status = status
ale.LatencyMs = time.Since(startTime)
ale.ProviderError = err.Error()
ale.Log(slog.LevelWarn, "IDP sign-in failed", slog.String("provider", h.IdpProvider.Name()))
if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "identity") {
return errorJSONCode(c, fiber.StatusNotFound, "not_found", "User not registered")
}
return errorJSONCode(c, fiber.StatusUnauthorized, "password_or_email_mismatch", "Invalid credentials")
return errorJSONCode(c, status, code, message)
}
subject, resolveErr := h.resolveKratosIdentityIDFromLoginID(c.Context(), loginID)
if resolveErr != nil || subject == "" {
slog.Error("Failed to resolve kratos identity after login", "loginID", loginID, "error", resolveErr)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to resolve user identity")
}
authInfo.Subject = subject
ale.Status = fiber.StatusOK
ale.LatencyMs = time.Since(startTime)
setSessionIDLocal(c, authInfo.SessionToken)
@@ -1746,7 +2305,7 @@ func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
}
}
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), req.LoginChallenge, subject)
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), req.LoginChallenge, authInfo.Subject)
if err != nil {
slog.Error("failed to accept hydra login request", "error", err)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request")

View File

@@ -2,14 +2,17 @@ package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// Mock services
@@ -21,6 +24,14 @@ type mockSmsService struct{}
func (m *mockSmsService) SendSms(to, content string) error { return nil }
func newHeadlessLinkTestApp(h *AuthHandler) *fiber.App {
app := fiber.New()
app.Post("/api/v1/auth/headless/link/init", h.HeadlessLinkInit)
app.Post("/api/v1/auth/headless/link/poll", h.HeadlessLinkPoll)
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
return app
}
func TestEnchantedLinkFlow_Email_Success(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
// Force "Not Supported" for InitiateLinkLogin only to trigger custom Enchanted Link logic
@@ -144,3 +155,162 @@ func TestPollEnchantedLink_ExpiredToken_ReturnsCode(t *testing.T) {
assert.Equal(t, "expired_token", got["error"])
assert.Equal(t, "expired_token", got["code"])
}
func TestHeadlessLinkInit_TrustedClientSuccess(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
privateKey, jwks := mustHeadlessRSAJWK(t)
idp := &mockIdpProvider{
userExists: true,
initiateLinkErr: domain.ErrNotSupported,
}
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet {
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "trusted-rp",
TokenEndpointAuthMethod: "private_key_jwt",
JWKS: jwks,
Metadata: map[string]interface{}{
"status": "active",
"headless_login_enabled": true,
},
},
})
return
}
http.NotFound(w, r)
})
h := &AuthHandler{
RedisService: redis,
IdpProvider: idp,
SmsService: &mockSmsService{},
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessLinkTestApp(h)
t.Setenv("USERFRONT_URL", "http://userfront.test")
body, _ := json.Marshal(map[string]string{
"client_id": "trusted-rp",
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "trusted-rp", "http://example.com/api/v1/auth/headless/link/init"),
"loginId": "010-1234-5678",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var got map[string]interface{}
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.NotEmpty(t, got["pendingRef"])
_, hasUserCode := got["userCode"]
assert.False(t, hasUserCode)
}
func TestHeadlessLinkPoll_AfterApprovalReturnsRedirect(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
privateKey, jwks := mustHeadlessRSAJWK(t)
idp := &mockIdpProvider{
userExists: true,
initiateLinkErr: domain.ErrNotSupported,
}
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "trusted-rp",
TokenEndpointAuthMethod: "private_key_jwt",
JWKS: jwks,
Metadata: map[string]interface{}{
"status": "active",
"headless_login_enabled": true,
},
},
})
return
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
return
}
http.NotFound(w, r)
})
mockKratos := new(MockKratosAdminService)
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "+821012345678").Return("kratos-identity-id", nil)
h := &AuthHandler{
RedisService: redis,
IdpProvider: idp,
SmsService: &mockSmsService{},
KratosAdmin: mockKratos,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessLinkTestApp(h)
t.Setenv("USERFRONT_URL", "http://userfront.test")
initBody, _ := json.Marshal(map[string]string{
"client_id": "trusted-rp",
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "trusted-rp", "http://example.com/api/v1/auth/headless/link/init"),
"loginId": "010-1234-5678",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(initBody))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var initResp map[string]interface{}
_ = json.NewDecoder(resp.Body).Decode(&initResp)
pendingRef := initResp["pendingRef"].(string)
assert.NotEmpty(t, pendingRef)
var token string
for k := range redis.data {
if len(k) > 16 && k[:16] == "enchanted_token:" {
token = k[16:]
break
}
}
assert.NotEmpty(t, token)
verifyBody, _ := json.Marshal(map[string]interface{}{
"token": token,
"verifyOnly": true,
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody))
req.Header.Set("Content-Type", "application/json")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
pollBody, _ := json.Marshal(map[string]string{
"client_id": "trusted-rp",
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "trusted-rp", "http://example.com/api/v1/auth/headless/link/poll"),
"pendingRef": pendingRef,
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/poll", bytes.NewReader(pollBody))
req.Header.Set("Content-Type", "application/json")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var pollResp map[string]interface{}
_ = json.NewDecoder(resp.Body).Decode(&pollResp)
assert.Equal(t, "http://rp/cb", pollResp["redirectTo"])
assert.Equal(t, "ok", pollResp["status"])
}

View File

@@ -10,6 +10,8 @@ import (
"baron-sso-backend/internal/service"
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"errors"
"io"
@@ -17,7 +19,10 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/go-jose/go-jose/v4"
josejwt "github.com/go-jose/go-jose/v4/jwt"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/mock"
)
@@ -121,6 +126,79 @@ func newAuthLoginTestApp(h *AuthHandler) *fiber.App {
return app
}
func newHeadlessPasswordLoginTestApp(h *AuthHandler) *fiber.App {
app := fiber.New()
app.Post("/api/v1/auth/headless/password/login", h.HeadlessPasswordLogin)
return app
}
func mustHeadlessRSAJWK(t *testing.T) (*rsa.PrivateKey, map[string]any) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("failed to generate rsa key: %v", err)
}
keySet := jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
{
Key: &privateKey.PublicKey,
KeyID: "test-kid",
Use: "sig",
Algorithm: string(jose.RS256),
},
},
}
raw, err := json.Marshal(keySet)
if err != nil {
t.Fatalf("failed to marshal jwks: %v", err)
}
var jwks map[string]any
if err := json.Unmarshal(raw, &jwks); err != nil {
t.Fatalf("failed to decode jwks map: %v", err)
}
return privateKey, jwks
}
func mustHeadlessClientAssertion(t *testing.T, privateKey *rsa.PrivateKey, clientID, audience string) string {
t.Helper()
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.RS256,
Key: jose.JSONWebKey{
Key: privateKey,
KeyID: "test-kid",
Use: "sig",
Algorithm: string(jose.RS256),
},
}, nil)
if err != nil {
t.Fatalf("failed to create signer: %v", err)
}
now := time.Now()
raw, err := josejwt.Signed(signer).Claims(josejwt.Claims{
Issuer: clientID,
Subject: clientID,
Audience: josejwt.Audience{audience},
Expiry: josejwt.NewNumericDate(now.Add(5 * time.Minute)),
IssuedAt: josejwt.NewNumericDate(now),
NotBefore: josejwt.NewNumericDate(
now.Add(-1 * time.Minute),
),
ID: "assertion-1",
}).Serialize()
if err != nil {
t.Fatalf("failed to sign client assertion: %v", err)
}
return raw
}
// mockHydraTransport simulates Hydra API responses
func mockHydraTransport(handler http.Handler) http.RoundTripper {
return roundTripFunc(func(req *http.Request) (*http.Response, error) {
@@ -206,6 +284,342 @@ func TestPasswordLogin_OIDC_Success(t *testing.T) {
}
}
func TestHeadlessPasswordLogin_TrustedClientSuccess(t *testing.T) {
mockIdp := new(MockIdentityProvider)
mockIdp.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{
SessionToken: &domain.Token{JWT: "valid-jwt"},
Subject: "kratos-identity-id",
}, nil)
privateKey, jwks := mustHeadlessRSAJWK(t)
jwksBody, _ := json.Marshal(jwks)
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(jwksBody)
}))
defer jwksServer.Close()
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "trusted-rp",
TokenEndpointAuthMethod: "private_key_jwt",
JWKSUri: jwksServer.URL + "/.well-known/jwks.json",
Metadata: map[string]interface{}{
"status": "active",
"headless_login_enabled": true,
},
},
})
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
default:
http.NotFound(w, r)
}
})
mockKratos := new(MockKratosAdminService)
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil)
h := &AuthHandler{
IdpProvider: mockIdp,
KratosAdmin: mockKratos,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessPasswordLoginTestApp(h)
clientAssertion := mustHeadlessClientAssertion(
t,
privateKey,
"trusted-rp",
"http://example.com/api/v1/auth/headless/password/login",
)
body, _ := json.Marshal(map[string]string{
"client_id": "trusted-rp",
"client_assertion": clientAssertion,
"loginId": "employee001",
"password": "password",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 200, got %d, body: %s", resp.StatusCode, string(bodyBytes))
}
var got map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if got["redirectTo"] != "http://rp/cb" {
t.Fatalf("expected redirectTo http://rp/cb, got %v", got["redirectTo"])
}
if _, ok := got["sessionJwt"]; ok {
t.Fatalf("expected headless response to omit sessionJwt, got %v", got["sessionJwt"])
}
}
func TestHeadlessPasswordLogin_MissingClientAssertionRejected(t *testing.T) {
mockIdp := new(MockIdentityProvider)
mockIdp.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{
SessionToken: &domain.Token{JWT: "valid-jwt"},
Subject: "kratos-identity-id",
}, nil)
mockKratos := new(MockKratosAdminService)
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil)
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "trusted-rp",
TokenEndpointAuthMethod: "private_key_jwt",
JWKS: map[string]any{
"keys": []map[string]any{},
},
Metadata: map[string]interface{}{
"status": "active",
"headless_login_enabled": true,
},
},
})
return
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
return
}
http.NotFound(w, r)
})
h := &AuthHandler{
IdpProvider: mockIdp,
KratosAdmin: mockKratos,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessPasswordLoginTestApp(h)
body, _ := json.Marshal(map[string]string{
"client_id": "trusted-rp",
"loginId": "employee001",
"password": "password",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 400, got %d, body: %s", resp.StatusCode, string(bodyBytes))
}
}
func TestHeadlessPasswordLogin_InvalidClientAssertionRejected(t *testing.T) {
mockIdp := new(MockIdentityProvider)
mockIdp.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{
SessionToken: &domain.Token{JWT: "valid-jwt"},
Subject: "kratos-identity-id",
}, nil)
mockKratos := new(MockKratosAdminService)
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil)
validKey, jwks := mustHeadlessRSAJWK(t)
invalidKey, _ := mustHeadlessRSAJWK(t)
_ = validKey
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "trusted-rp",
TokenEndpointAuthMethod: "private_key_jwt",
JWKS: jwks,
Metadata: map[string]interface{}{
"status": "active",
"headless_login_enabled": true,
},
},
})
return
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
return
}
http.NotFound(w, r)
})
h := &AuthHandler{
IdpProvider: mockIdp,
KratosAdmin: mockKratos,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessPasswordLoginTestApp(h)
clientAssertion := mustHeadlessClientAssertion(
t,
invalidKey,
"trusted-rp",
"http://example.com/api/v1/auth/headless/password/login",
)
body, _ := json.Marshal(map[string]string{
"client_id": "trusted-rp",
"client_assertion": clientAssertion,
"loginId": "employee001",
"password": "password",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 401, got %d, body: %s", resp.StatusCode, string(bodyBytes))
}
}
func TestHeadlessPasswordLogin_HeadlessDisabledRejected(t *testing.T) {
mockIdp := new(MockIdentityProvider)
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet {
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "trusted-rp",
TokenEndpointAuthMethod: "private_key_jwt",
JWKSUri: "https://rp.example.com/.well-known/jwks.json",
Metadata: map[string]interface{}{
"status": "active",
},
},
})
return
}
http.NotFound(w, r)
})
h := &AuthHandler{
IdpProvider: mockIdp,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessPasswordLoginTestApp(h)
body, _ := json.Marshal(map[string]string{
"client_id": "trusted-rp",
"loginId": "employee001",
"password": "password",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusForbidden {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 403, got %d, body: %s", resp.StatusCode, string(bodyBytes))
}
}
func TestHeadlessPasswordLogin_ClientIDMismatchRejected(t *testing.T) {
mockIdp := new(MockIdentityProvider)
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet {
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "other-rp",
TokenEndpointAuthMethod: "private_key_jwt",
JWKSUri: "https://rp.example.com/.well-known/jwks.json",
Metadata: map[string]interface{}{
"status": "active",
"headless_login_enabled": true,
},
},
})
return
}
http.NotFound(w, r)
})
h := &AuthHandler{
IdpProvider: mockIdp,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessPasswordLoginTestApp(h)
body, _ := json.Marshal(map[string]string{
"client_id": "trusted-rp",
"loginId": "employee001",
"password": "password",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusForbidden {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 403, got %d, body: %s", resp.StatusCode, string(bodyBytes))
}
}
func TestPasswordLogin_OIDC_InactiveClient(t *testing.T) {
mockIdp := new(MockIdentityProvider)
mockIdp.On("SignIn", "user@example.com", "password").Return(&domain.AuthInfo{