diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 83f976f0..61ff20d2 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -526,6 +526,9 @@ func main() { auth.Post("/login/code/verify", authHandler.VerifyLoginCode) auth.Post("/login/code/verify-short", authHandler.VerifyLoginShortCode) auth.Post("/password/login", authHandler.PasswordLogin) + auth.Post("/headless/password/login", authHandler.HeadlessPasswordLogin) + auth.Post("/headless/link/init", authHandler.HeadlessLinkInit) + auth.Post("/headless/link/poll", authHandler.HeadlessLinkPoll) auth.Get("/tenant-info", authHandler.GetTenantInfo) auth.Get("/consent", authHandler.GetConsentRequest) auth.Post("/consent/accept", authHandler.AcceptConsentRequest) diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 401fd4a0..bd961dd8 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -25,6 +25,8 @@ import ( "strings" "time" + "github.com/go-jose/go-jose/v4" + josejwt "github.com/go-jose/go-jose/v4/jwt" "github.com/gofiber/fiber/v2" ) @@ -46,6 +48,7 @@ const ( prefixLoginCodeSmsOnly = "login_code_sms_only:" prefixLoginCodeQrPending = "login_code_qr_pending:" prefixLoginCodeQr = "login_code_qr:" + prefixHeadlessLinkState = "headless_link_state:" prefixPollMeta = "poll_meta:" prefixQrRef = "qr_ref:" prefixQrMeta = "qr_meta:" @@ -75,6 +78,7 @@ const ( loginCodeExpiration = 10 * time.Minute linkResendCooldown = 60 * time.Second prefixDrySend = "dry_send:" + headlessJWKSFetchTTL = 5 * time.Second ) type AuthHandler struct { @@ -100,6 +104,40 @@ type signupState struct { ExpiresAt int64 `json:"expires_at"` // Unix timestamp } +type headlessLinkState struct { + ClientID string `json:"clientId"` + LoginChallenge string `json:"loginChallenge"` + LoginID string `json:"loginId"` + RedirectTo string `json:"redirectTo,omitempty"` +} + +type headlessClientAssertionClaims struct { + Issuer string `json:"iss"` + Subject string `json:"sub"` + Audience headlessAssertionAud `json:"aud"` + ExpiresAt int64 `json:"exp"` + IssuedAt int64 `json:"iat,omitempty"` + NotBefore int64 `json:"nbf,omitempty"` + ID string `json:"jti,omitempty"` +} + +type headlessAssertionAud []string + +func (a *headlessAssertionAud) UnmarshalJSON(data []byte) error { + var single string + if err := json.Unmarshal(data, &single); err == nil { + *a = []string{single} + return nil + } + + var list []string + if err := json.Unmarshal(data, &list); err != nil { + return err + } + *a = list + return nil +} + // GenerateSecureToken - Helper to generate secure random strings func GenerateSecureToken(length int) string { b := make([]byte, length) @@ -1224,87 +1262,9 @@ func (h *AuthHandler) PollEnchantedLink(c *fiber.Ctx) error { } if data["status"] == "approved" { - loginID := data["loginId"] - if loginID == "" { - loginID = data["login_id"] - } - if loginID == "" { - slog.Warn("[Poll] Approved but missing loginId", "pendingRef", req.PendingRef) - return errorJSON(c, fiber.StatusBadRequest, "Invalid session reference") - } - if h.IdpProvider == nil { - return errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable") - } - - loginStrategy := h.loadLoginStrategy(req.PendingRef) - if loginStrategy == "" { - loginStrategy = loginFlowLink - } - - var authInfo *domain.AuthInfo - var err error - if loginStrategy == loginFlowCode { - code, _ := h.RedisService.Get(prefixLoginCodeValue + req.PendingRef) - code = normalizeLoginCode(code) - if code == "" { - slog.Warn("[Poll] Missing login code for approved flow", "pendingRef", req.PendingRef) - return errorJSON(c, fiber.StatusBadRequest, "Login code expired") - } - flowID, _ := h.RedisService.Get(prefixLoginCode + loginID) - if flowID == "" { - return errorJSON(c, fiber.StatusNotFound, "Login flow expired") - } - authInfo, err = h.IdpProvider.VerifyLoginCode(loginID, flowID, code) - if err != nil { - if errors.Is(err, domain.ErrNotSupported) { - return errorJSON(c, fiber.StatusNotImplemented, "Login method not supported") - } - slog.Error("[Poll] IDP code verify failed", "error", err) - return errorJSON(c, fiber.StatusInternalServerError, "Failed to verify login code") - } - } else { - authInfo, err = h.IdpProvider.IssueSession(loginID) - if err != nil { - if errors.Is(err, domain.ErrNotSupported) { - return errorJSON(c, fiber.StatusNotImplemented, "Login method not supported") - } - slog.Error("[Poll] IDP session issue failed", "error", err) - return errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session") - } - } - if authInfo == nil || authInfo.SessionToken == nil || authInfo.SessionToken.JWT == "" { - return errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session") - } - - c.Locals("login_id", loginID) - setSessionIDLocal(c, authInfo.SessionToken) - sessionID := extractSessionIDFromToken(authInfo.SessionToken) - if sessionID == "" && authInfo.SessionToken != nil && authInfo.SessionToken.JWT != "" { - if resolved, err := h.getKratosSessionID(authInfo.SessionToken.JWT); err == nil && resolved != "" { - sessionID = resolved - authInfo.SessionToken.SessionID = resolved - setSessionIDLocal(c, authInfo.SessionToken) - } - } - - sessionData := map[string]string{ - "status": statusSuccess, - "jwt": authInfo.SessionToken.JWT, - } - if sessionID != "" { - sessionData["session_id"] = sessionID - } - sessionDataJSON, _ := json.Marshal(sessionData) - h.RedisService.Set(prefixSession+req.PendingRef, string(sessionDataJSON), defaultExpiration) - - h.writeLinkAuditLog(loginID, req.PendingRef, authInfo.SessionToken, c) - h.clearLoginMeta(req.PendingRef) - if loginStrategy == loginFlowCode { - h.RedisService.Delete(prefixLoginCode + loginID) - h.RedisService.Delete(prefixLoginCodePending + loginID) - h.RedisService.Delete(prefixLoginCodeSmsTarget + loginID) - h.RedisService.Delete(prefixLoginCodeSmsLookup + loginID) - h.RedisService.Delete(prefixLoginCodeValue + req.PendingRef) + _, authInfo, err := h.completeApprovedLinkLogin(c, req.PendingRef) + if err != nil { + return err } return c.JSON(fiber.Map{ @@ -1671,6 +1631,617 @@ func logOidcRedirectSummary(source, redirectTo string) { ) } +func (h *AuthHandler) authenticatePasswordLogin(ctx context.Context, loginID, password string) (*domain.AuthInfo, error) { + if h.IdpProvider == nil { + return nil, fmt.Errorf("authentication service not configured") + } + + authInfo, err := h.IdpProvider.SignIn(loginID, password) + if err != nil { + return nil, err + } + + subject, resolveErr := h.resolveKratosIdentityIDFromLoginID(ctx, loginID) + if resolveErr != nil || subject == "" { + slog.Error("Failed to resolve kratos identity after login", "loginID", loginID, "error", resolveErr) + return nil, fmt.Errorf("failed to resolve user identity") + } + + authInfo.Subject = subject + return authInfo, nil +} + +func passwordLoginErrorSpec(err error) (int, string, string) { + if err == nil { + return fiber.StatusOK, "", "" + } + if errors.Is(err, domain.ErrNotSupported) { + return fiber.StatusNotImplemented, "not_supported", "Login method not supported" + } + if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "identity") { + return fiber.StatusNotFound, "not_found", "User not registered" + } + if strings.Contains(err.Error(), "failed to resolve user identity") { + return fiber.StatusInternalServerError, "internal_error", "Failed to resolve user identity" + } + return fiber.StatusUnauthorized, "password_or_email_mismatch", "Invalid credentials" +} + +func headlessAssertionAudiences(c *fiber.Ctx) []string { + if c == nil { + return nil + } + + path := strings.TrimSpace(c.Path()) + if path == "" { + return nil + } + + base := strings.TrimRight(strings.TrimSpace(c.BaseURL()), "/") + if base == "" { + return []string{path} + } + + return []string{base + path, path} +} + +func containsHeadlessAudience(expected []string, actual headlessAssertionAud) bool { + for _, audience := range actual { + for _, candidate := range expected { + if strings.TrimSpace(audience) == strings.TrimSpace(candidate) { + return true + } + } + } + return false +} + +func (h *AuthHandler) loadHeadlessJWKS(ctx context.Context, client domain.HydraClient) (*jose.JSONWebKeySet, error) { + var raw []byte + switch { + case client.JWKS != nil: + data, err := json.Marshal(client.JWKS) + if err != nil { + return nil, fmt.Errorf("failed to encode jwks: %w", err) + } + raw = data + case strings.TrimSpace(client.JWKSUri) != "": + req, err := http.NewRequestWithContext(ctx, http.MethodGet, strings.TrimSpace(client.JWKSUri), nil) + if err != nil { + return nil, fmt.Errorf("failed to build jwks request: %w", err) + } + client := &http.Client{Timeout: headlessJWKSFetchTTL} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to fetch jwks: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode >= 300 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048)) + return nil, fmt.Errorf("failed to fetch jwks status=%d body=%s", resp.StatusCode, string(body)) + } + body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024)) + if err != nil { + return nil, fmt.Errorf("failed to read jwks response: %w", err) + } + raw = body + default: + return nil, fmt.Errorf("trusted rp public key is not configured") + } + + var keySet jose.JSONWebKeySet + if err := json.Unmarshal(raw, &keySet); err != nil { + return nil, fmt.Errorf("failed to decode jwks: %w", err) + } + if len(keySet.Keys) == 0 { + return nil, fmt.Errorf("trusted rp jwks has no keys") + } + return &keySet, nil +} + +func validateHeadlessClientAssertionClaims(c *fiber.Ctx, claims headlessClientAssertionClaims, clientID string) error { + now := time.Now().Unix() + if claims.Issuer != clientID || claims.Subject != clientID { + return fmt.Errorf("client assertion iss/sub mismatch") + } + if claims.ExpiresAt == 0 || claims.ExpiresAt <= now { + return fmt.Errorf("client assertion expired") + } + if claims.NotBefore != 0 && claims.NotBefore > now { + return fmt.Errorf("client assertion not active yet") + } + if claims.IssuedAt != 0 && claims.IssuedAt > now+60 { + return fmt.Errorf("client assertion issued in the future") + } + if !containsHeadlessAudience(headlessAssertionAudiences(c), claims.Audience) { + return fmt.Errorf("client assertion audience mismatch") + } + return nil +} + +func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.HydraClient, clientID, clientAssertion string) error { + assertion := strings.TrimSpace(clientAssertion) + if assertion == "" { + return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_assertion is required") + } + + keySet, err := h.loadHeadlessJWKS(c.Context(), client) + if err != nil { + slog.Error("failed to load jwks for headless client assertion", "clientID", clientID, "error", err) + return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion") + } + + token, err := josejwt.ParseSigned(assertion, []jose.SignatureAlgorithm{ + jose.RS256, jose.RS384, jose.RS512, + jose.PS256, jose.PS384, jose.PS512, + jose.ES256, jose.ES384, jose.ES512, + jose.EdDSA, + }) + if err != nil { + return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion") + } + + expectedKid := "" + if len(token.Headers) > 0 { + expectedKid = strings.TrimSpace(token.Headers[0].KeyID) + } + + for _, key := range keySet.Keys { + if expectedKid != "" && key.KeyID != "" && key.KeyID != expectedKid { + continue + } + + var claims headlessClientAssertionClaims + if err := token.Claims(key.Key, &claims); err != nil { + continue + } + if err := validateHeadlessClientAssertionClaims(c, claims, clientID); err != nil { + return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion") + } + return nil + } + + return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion") +} + +func (h *AuthHandler) storeHeadlessLinkState(pendingRef string, state headlessLinkState, ttl time.Duration) { + if h.RedisService == nil || pendingRef == "" { + return + } + raw, err := json.Marshal(state) + if err != nil { + return + } + _ = h.RedisService.Set(prefixHeadlessLinkState+pendingRef, string(raw), ttl) +} + +func (h *AuthHandler) loadHeadlessLinkState(pendingRef string) (headlessLinkState, bool) { + if h.RedisService == nil || pendingRef == "" { + return headlessLinkState{}, false + } + raw, err := h.RedisService.Get(prefixHeadlessLinkState + pendingRef) + if err != nil || raw == "" { + return headlessLinkState{}, false + } + var state headlessLinkState + if err := json.Unmarshal([]byte(raw), &state); err != nil { + return headlessLinkState{}, false + } + return state, true +} + +func (h *AuthHandler) completeApprovedLinkLogin(c *fiber.Ctx, pendingRef string) (string, *domain.AuthInfo, error) { + val, err := h.RedisService.Get(prefixSession + pendingRef) + if err != nil || val == "" { + return "", nil, errorJSON(c, fiber.StatusBadRequest, "Invalid session reference") + } + + var data map[string]string + _ = json.Unmarshal([]byte(val), &data) + loginID := data["loginId"] + if loginID == "" { + loginID = data["login_id"] + } + if loginID == "" { + slog.Warn("[Poll] Approved but missing loginId", "pendingRef", pendingRef) + return "", nil, errorJSON(c, fiber.StatusBadRequest, "Invalid session reference") + } + if h.IdpProvider == nil { + return "", nil, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable") + } + + loginStrategy := h.loadLoginStrategy(pendingRef) + if loginStrategy == "" { + loginStrategy = loginFlowLink + } + + var authInfo *domain.AuthInfo + if loginStrategy == loginFlowCode { + code, _ := h.RedisService.Get(prefixLoginCodeValue + pendingRef) + code = normalizeLoginCode(code) + if code == "" { + slog.Warn("[Poll] Missing login code for approved flow", "pendingRef", pendingRef) + return "", nil, errorJSON(c, fiber.StatusBadRequest, "Login code expired") + } + flowID, _ := h.RedisService.Get(prefixLoginCode + loginID) + if flowID == "" { + return "", nil, errorJSON(c, fiber.StatusNotFound, "Login flow expired") + } + authInfo, err = h.IdpProvider.VerifyLoginCode(loginID, flowID, code) + if err != nil { + if errors.Is(err, domain.ErrNotSupported) { + return "", nil, errorJSON(c, fiber.StatusNotImplemented, "Login method not supported") + } + slog.Error("[Poll] IDP code verify failed", "error", err) + return "", nil, errorJSON(c, fiber.StatusInternalServerError, "Failed to verify login code") + } + } else { + authInfo, err = h.IdpProvider.IssueSession(loginID) + if err != nil { + if errors.Is(err, domain.ErrNotSupported) { + return "", nil, errorJSON(c, fiber.StatusNotImplemented, "Login method not supported") + } + slog.Error("[Poll] IDP session issue failed", "error", err) + return "", nil, errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session") + } + } + if authInfo == nil || authInfo.SessionToken == nil || authInfo.SessionToken.JWT == "" { + return "", nil, errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session") + } + + c.Locals("login_id", loginID) + setSessionIDLocal(c, authInfo.SessionToken) + sessionID := extractSessionIDFromToken(authInfo.SessionToken) + if sessionID == "" && authInfo.SessionToken != nil && authInfo.SessionToken.JWT != "" { + if resolved, err := h.getKratosSessionID(authInfo.SessionToken.JWT); err == nil && resolved != "" { + sessionID = resolved + authInfo.SessionToken.SessionID = resolved + setSessionIDLocal(c, authInfo.SessionToken) + } + } + + sessionData := map[string]string{ + "status": statusSuccess, + "jwt": authInfo.SessionToken.JWT, + } + if sessionID != "" { + sessionData["session_id"] = sessionID + } + sessionDataJSON, _ := json.Marshal(sessionData) + _ = h.RedisService.Set(prefixSession+pendingRef, string(sessionDataJSON), defaultExpiration) + + h.writeLinkAuditLog(loginID, pendingRef, authInfo.SessionToken, c) + h.clearLoginMeta(pendingRef) + if loginStrategy == loginFlowCode { + _ = h.RedisService.Delete(prefixLoginCode + loginID) + _ = h.RedisService.Delete(prefixLoginCodePending + loginID) + _ = h.RedisService.Delete(prefixLoginCodeSmsTarget + loginID) + _ = h.RedisService.Delete(prefixLoginCodeSmsLookup + loginID) + _ = h.RedisService.Delete(prefixLoginCodeValue + pendingRef) + } + + return loginID, authInfo, nil +} + +func (h *AuthHandler) validateHeadlessPasswordLoginClient(loginReq *domain.HydraLoginRequest, clientID string) error { + if loginReq == nil { + return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request") + } + + if strings.TrimSpace(loginReq.Client.ClientID) != strings.TrimSpace(clientID) { + return fiber.NewError(fiber.StatusForbidden, "The client application is not allowed to use this login request.") + } + + if metadata := loginReq.Client.Metadata; metadata != nil { + if status, ok := metadata["status"].(string); ok && strings.ToLower(status) == "inactive" { + return fiber.NewError(fiber.StatusForbidden, "The client application is disabled.") + } + } + + if !loginReq.Client.IsHeadlessLoginEnabled() { + return fiber.NewError(fiber.StatusForbidden, "The client application is not allowed to use headless password login.") + } + + return nil +} + +func (h *AuthHandler) HeadlessPasswordLogin(c *fiber.Ctx) error { + var req struct { + ClientID string `json:"client_id"` + ClientAssertion string `json:"client_assertion"` + LoginID string `json:"loginId"` + Password string `json:"password"` + LoginChallenge string `json:"login_challenge"` + } + + if err := c.BodyParser(&req); err != nil { + return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "Invalid request body") + } + + clientID := strings.TrimSpace(req.ClientID) + loginID := strings.TrimSpace(req.LoginID) + loginChallenge := strings.TrimSpace(req.LoginChallenge) + if clientID == "" || loginID == "" || strings.TrimSpace(req.Password) == "" || loginChallenge == "" { + return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_id, loginId, password and login_challenge are required") + } + + if h.IdpProvider == nil || h.Hydra == nil { + return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured") + } + + loginReq, err := h.Hydra.GetLoginRequest(c.Context(), loginChallenge) + if err != nil { + slog.Error("failed to get hydra login request for headless password login", "error", err) + return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request") + } + if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil { + return err + } + if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil { + return err + } + + authInfo, authErr := h.authenticatePasswordLogin(c.Context(), loginID, req.Password) + if authErr != nil { + status, code, message := passwordLoginErrorSpec(authErr) + return errorJSONCode(c, status, code, message) + } + + setSessionIDLocal(c, authInfo.SessionToken) + + acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), loginChallenge, authInfo.Subject) + if err != nil { + slog.Error("failed to accept hydra login request in headless password login", "error", err) + return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request") + } + + logOidcRedirectSummary("headless_password_login", acceptResp.RedirectTo) + return c.JSON(fiber.Map{ + "redirectTo": acceptResp.RedirectTo, + "status": "ok", + "provider": h.IdpProvider.Name(), + }) +} + +func (h *AuthHandler) startHeadlessPhoneLink(c *fiber.Ctx, loginID string) (fiber.Map, string, string, time.Duration, error) { + rawLoginID := strings.ReplaceAll(loginID, "-", "") + rawLoginID = strings.ReplaceAll(rawLoginID, " ", "") + if rawLoginID == "" || strings.Contains(rawLoginID, "@") { + return nil, "", "", 0, errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "phone-based loginId is required") + } + + lookupLoginID := normalizePhoneForLoginID(rawLoginID) + if h.IdpProvider == nil { + return nil, "", "", 0, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable") + } + exists, err := h.IdpProvider.UserExists(lookupLoginID) + if err != nil { + slog.Warn("[HeadlessLink] IDP user lookup failed", "loginID", rawLoginID, "error", err) + return nil, "", "", 0, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable") + } + if !exists { + slog.Warn("[HeadlessLink] User not found", "loginID", rawLoginID) + return nil, "", "", 0, errorJSON(c, fiber.StatusNotFound, "User not registered") + } + + userfrontURL := h.resolveUserfrontURL(c) + if init, err := h.IdpProvider.InitiateLinkLogin(lookupLoginID, userfrontURL); err == nil && init != nil && init.Mode != "" { + keyLoginID := lookupLoginID + if init.LoginID != "" { + keyLoginID = init.LoginID + } + if init.FlowID != "" { + _ = h.RedisService.Set(prefixLoginCode+keyLoginID, init.FlowID, loginCodeExpiration) + } + pendingRef := GenerateSecureToken(3) + sessionData, _ := json.Marshal(map[string]string{ + "status": statusPending, + "loginId": keyLoginID, + }) + _ = h.RedisService.Set(prefixSession+pendingRef, string(sessionData), loginCodeExpiration) + h.storeLoginMeta(pendingRef, rawLoginID, "sms", loginFlowLink, loginFlowCode, loginCodeExpiration) + _ = h.RedisService.Set(prefixLoginCodePending+keyLoginID, pendingRef, loginCodeExpiration) + if keyLoginID != lookupLoginID { + _ = h.RedisService.Set(prefixLoginCodeSmsTarget+keyLoginID, lookupLoginID, loginCodeExpiration) + _ = h.RedisService.Set(prefixLoginCodeSmsLookup+lookupLoginID, keyLoginID, loginCodeExpiration) + } + expiresIn := int(loginCodeExpiration.Seconds()) + if !init.ExpiresAt.IsZero() { + if seconds := int(time.Until(init.ExpiresAt).Seconds()); seconds > 0 { + expiresIn = seconds + } + } + return fiber.Map{ + "pendingRef": pendingRef, + "status": "pending", + "mode": init.Mode, + "provider": h.IdpProvider.Name(), + "expiresIn": expiresIn, + "interval": int(minPollInterval.Seconds()), + "resendAfter": int(linkResendCooldown.Seconds()), + }, pendingRef, keyLoginID, loginCodeExpiration, nil + } else if err != nil && !errors.Is(err, domain.ErrNotSupported) { + slog.Error("[HeadlessLink] Link login init failed", "provider", h.IdpProvider.Name(), "error", err) + return nil, "", "", 0, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable") + } + + if h.SmsService == nil { + return nil, "", "", 0, errorJSON(c, fiber.StatusInternalServerError, "SMS service not configured") + } + + token := GenerateSecureToken(3) + pendingRef := GenerateSecureToken(3) + sessionData, _ := json.Marshal(map[string]string{ + "status": statusPending, + "loginId": lookupLoginID, + }) + _ = h.RedisService.Set(prefixSession+pendingRef, string(sessionData), defaultExpiration) + _ = h.RedisService.Set(prefixToken+token, fmt.Sprintf(`{"pendingRef":"%s","loginId":"%s"}`, pendingRef, lookupLoginID), defaultExpiration) + h.storeLoginMeta(pendingRef, rawLoginID, "sms", loginFlowLink, loginFlowLink, defaultExpiration) + + link := fmt.Sprintf("%s/verify/%s", strings.TrimRight(userfrontURL, "/"), token) + content := fmt.Sprintf("[Baron 로그인] 로그인 링크: %s", link) + if err := h.SmsService.SendSms(rawLoginID, content); err != nil { + slog.Error("[HeadlessLink] SMS send failed", "error", err) + return nil, "", "", 0, errorJSON(c, fiber.StatusInternalServerError, "Failed to send SMS") + } + + return fiber.Map{ + "pendingRef": pendingRef, + "status": "pending", + "provider": h.IdpProvider.Name(), + "expiresIn": int(defaultExpiration.Seconds()), + "interval": int(minPollInterval.Seconds()), + "resendAfter": int(linkResendCooldown.Seconds()), + }, pendingRef, lookupLoginID, defaultExpiration, nil +} + +func (h *AuthHandler) HeadlessLinkInit(c *fiber.Ctx) error { + var req struct { + ClientID string `json:"client_id"` + ClientAssertion string `json:"client_assertion"` + LoginID string `json:"loginId"` + LoginChallenge string `json:"login_challenge"` + } + + if err := c.BodyParser(&req); err != nil { + return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "Invalid request body") + } + + clientID := strings.TrimSpace(req.ClientID) + loginChallenge := strings.TrimSpace(req.LoginChallenge) + loginID := strings.TrimSpace(req.LoginID) + if clientID == "" || loginChallenge == "" || loginID == "" { + return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_id, client_assertion, loginId and login_challenge are required") + } + if h.Hydra == nil { + return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured") + } + + loginReq, err := h.Hydra.GetLoginRequest(c.Context(), loginChallenge) + if err != nil { + slog.Error("failed to get hydra login request for headless link init", "error", err) + return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request") + } + if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil { + return err + } + if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil { + return err + } + + resp, pendingRef, resolvedLoginID, ttl, err := h.startHeadlessPhoneLink(c, loginID) + if err != nil { + return err + } + h.storeHeadlessLinkState(pendingRef, headlessLinkState{ + ClientID: clientID, + LoginChallenge: loginChallenge, + LoginID: resolvedLoginID, + }, ttl) + return c.JSON(resp) +} + +func (h *AuthHandler) HeadlessLinkPoll(c *fiber.Ctx) error { + var req struct { + ClientID string `json:"client_id"` + ClientAssertion string `json:"client_assertion"` + PendingRef string `json:"pendingRef"` + } + + if err := c.BodyParser(&req); err != nil { + return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "Invalid request body") + } + + clientID := strings.TrimSpace(req.ClientID) + pendingRef := strings.TrimSpace(req.PendingRef) + if clientID == "" || pendingRef == "" { + return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_id, client_assertion and pendingRef are required") + } + + state, ok := h.loadHeadlessLinkState(pendingRef) + if !ok { + return c.JSON(fiber.Map{ + "error": "expired_token", + "code": "expired_token", + }) + } + if state.ClientID != clientID { + return fiber.NewError(fiber.StatusForbidden, "The client application is not allowed to use this pending login.") + } + if state.RedirectTo != "" { + return c.JSON(fiber.Map{ + "redirectTo": state.RedirectTo, + "status": "ok", + }) + } + if h.Hydra == nil { + return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured") + } + + loginReq, err := h.Hydra.GetLoginRequest(c.Context(), state.LoginChallenge) + if err != nil { + slog.Error("failed to get hydra login request for headless link poll", "error", err) + return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request") + } + if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil { + return err + } + if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil { + return err + } + + val, err := h.RedisService.Get(prefixSession + pendingRef) + if err != nil || val == "" { + return c.JSON(fiber.Map{ + "error": "expired_token", + "code": "expired_token", + }) + } + + var session map[string]string + _ = json.Unmarshal([]byte(val), &session) + if session["status"] == statusPending { + return c.JSON(fiber.Map{ + "error": "authorization_pending", + "code": "authorization_pending", + "interval": int(minPollInterval.Seconds()), + }) + } + + loginID := state.LoginID + if session["status"] == "approved" { + completedLoginID, _, err := h.completeApprovedLinkLogin(c, pendingRef) + if err != nil { + return err + } + loginID = completedLoginID + } + + if loginID == "" { + return errorJSON(c, fiber.StatusInternalServerError, "Failed to resolve approved user identity") + } + subject, err := h.resolveKratosIdentityIDFromLoginID(c.Context(), loginID) + if err != nil || subject == "" { + slog.Error("failed to resolve kratos identity for headless link poll", "loginID", loginID, "error", err) + return errorJSON(c, fiber.StatusInternalServerError, "Failed to resolve user identity") + } + + acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), state.LoginChallenge, subject) + if err != nil { + slog.Error("failed to accept hydra login request in headless link poll", "error", err) + return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request") + } + + state.RedirectTo = acceptResp.RedirectTo + h.storeHeadlessLinkState(pendingRef, state, defaultExpiration) + logOidcRedirectSummary("headless_link_poll", acceptResp.RedirectTo) + return c.JSON(fiber.Map{ + "redirectTo": acceptResp.RedirectTo, + "status": "ok", + }) +} + func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error { startTime := time.Now() ale := logger.NewAuditLogEntry(c, "login") @@ -1704,28 +2275,16 @@ func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error { return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured") } - authInfo, err := h.IdpProvider.SignIn(loginID, req.Password) + authInfo, err := h.authenticatePasswordLogin(c.Context(), loginID, req.Password) if err != nil { - if errors.Is(err, domain.ErrNotSupported) { - return errorJSONCode(c, fiber.StatusNotImplemented, "not_supported", "Login method not supported") - } - ale.Status = fiber.StatusUnauthorized + status, code, message := passwordLoginErrorSpec(err) + ale.Status = status ale.LatencyMs = time.Since(startTime) ale.ProviderError = err.Error() ale.Log(slog.LevelWarn, "IDP sign-in failed", slog.String("provider", h.IdpProvider.Name())) - if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "identity") { - return errorJSONCode(c, fiber.StatusNotFound, "not_found", "User not registered") - } - return errorJSONCode(c, fiber.StatusUnauthorized, "password_or_email_mismatch", "Invalid credentials") + return errorJSONCode(c, status, code, message) } - subject, resolveErr := h.resolveKratosIdentityIDFromLoginID(c.Context(), loginID) - if resolveErr != nil || subject == "" { - slog.Error("Failed to resolve kratos identity after login", "loginID", loginID, "error", resolveErr) - return fiber.NewError(fiber.StatusInternalServerError, "Failed to resolve user identity") - } - authInfo.Subject = subject - ale.Status = fiber.StatusOK ale.LatencyMs = time.Since(startTime) setSessionIDLocal(c, authInfo.SessionToken) @@ -1746,7 +2305,7 @@ func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error { } } - acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), req.LoginChallenge, subject) + acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), req.LoginChallenge, authInfo.Subject) if err != nil { slog.Error("failed to accept hydra login request", "error", err) return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request") diff --git a/backend/internal/handler/auth_handler_link_test.go b/backend/internal/handler/auth_handler_link_test.go index 4a51cc6e..08b663d3 100644 --- a/backend/internal/handler/auth_handler_link_test.go +++ b/backend/internal/handler/auth_handler_link_test.go @@ -2,14 +2,17 @@ package handler import ( "baron-sso-backend/internal/domain" + "baron-sso-backend/internal/service" "bytes" "encoding/json" "net/http" "net/http/httptest" + "strings" "testing" "github.com/gofiber/fiber/v2" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) // Mock services @@ -21,6 +24,14 @@ type mockSmsService struct{} func (m *mockSmsService) SendSms(to, content string) error { return nil } +func newHeadlessLinkTestApp(h *AuthHandler) *fiber.App { + app := fiber.New() + app.Post("/api/v1/auth/headless/link/init", h.HeadlessLinkInit) + app.Post("/api/v1/auth/headless/link/poll", h.HeadlessLinkPoll) + app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink) + return app +} + func TestEnchantedLinkFlow_Email_Success(t *testing.T) { redis := &mockRedisRepo{data: make(map[string]string)} // Force "Not Supported" for InitiateLinkLogin only to trigger custom Enchanted Link logic @@ -144,3 +155,162 @@ func TestPollEnchantedLink_ExpiredToken_ReturnsCode(t *testing.T) { assert.Equal(t, "expired_token", got["error"]) assert.Equal(t, "expired_token", got["code"]) } + +func TestHeadlessLinkInit_TrustedClientSuccess(t *testing.T) { + redis := &mockRedisRepo{data: make(map[string]string)} + privateKey, jwks := mustHeadlessRSAJWK(t) + + idp := &mockIdpProvider{ + userExists: true, + initiateLinkErr: domain.ErrNotSupported, + } + + hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet { + _ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{ + Challenge: "challenge-123", + Client: domain.HydraClient{ + ClientID: "trusted-rp", + TokenEndpointAuthMethod: "private_key_jwt", + JWKS: jwks, + Metadata: map[string]interface{}{ + "status": "active", + "headless_login_enabled": true, + }, + }, + }) + return + } + http.NotFound(w, r) + }) + + h := &AuthHandler{ + RedisService: redis, + IdpProvider: idp, + SmsService: &mockSmsService{}, + Hydra: &service.HydraAdminService{ + AdminURL: "http://hydra.test", + HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)}, + }, + } + + app := newHeadlessLinkTestApp(h) + t.Setenv("USERFRONT_URL", "http://userfront.test") + + body, _ := json.Marshal(map[string]string{ + "client_id": "trusted-rp", + "client_assertion": mustHeadlessClientAssertion(t, privateKey, "trusted-rp", "http://example.com/api/v1/auth/headless/link/init"), + "loginId": "010-1234-5678", + "login_challenge": "challenge-123", + }) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + resp, _ := app.Test(req, -1) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var got map[string]interface{} + _ = json.NewDecoder(resp.Body).Decode(&got) + assert.NotEmpty(t, got["pendingRef"]) + _, hasUserCode := got["userCode"] + assert.False(t, hasUserCode) +} + +func TestHeadlessLinkPoll_AfterApprovalReturnsRedirect(t *testing.T) { + redis := &mockRedisRepo{data: make(map[string]string)} + privateKey, jwks := mustHeadlessRSAJWK(t) + + idp := &mockIdpProvider{ + userExists: true, + initiateLinkErr: domain.ErrNotSupported, + } + + hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet: + _ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{ + Challenge: "challenge-123", + Client: domain.HydraClient{ + ClientID: "trusted-rp", + TokenEndpointAuthMethod: "private_key_jwt", + JWKS: jwks, + Metadata: map[string]interface{}{ + "status": "active", + "headless_login_enabled": true, + }, + }, + }) + return + case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut: + _ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"}) + return + } + http.NotFound(w, r) + }) + + mockKratos := new(MockKratosAdminService) + mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "+821012345678").Return("kratos-identity-id", nil) + + h := &AuthHandler{ + RedisService: redis, + IdpProvider: idp, + SmsService: &mockSmsService{}, + KratosAdmin: mockKratos, + Hydra: &service.HydraAdminService{ + AdminURL: "http://hydra.test", + HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)}, + }, + } + + app := newHeadlessLinkTestApp(h) + t.Setenv("USERFRONT_URL", "http://userfront.test") + + initBody, _ := json.Marshal(map[string]string{ + "client_id": "trusted-rp", + "client_assertion": mustHeadlessClientAssertion(t, privateKey, "trusted-rp", "http://example.com/api/v1/auth/headless/link/init"), + "loginId": "010-1234-5678", + "login_challenge": "challenge-123", + }) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(initBody)) + req.Header.Set("Content-Type", "application/json") + resp, _ := app.Test(req, -1) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var initResp map[string]interface{} + _ = json.NewDecoder(resp.Body).Decode(&initResp) + pendingRef := initResp["pendingRef"].(string) + assert.NotEmpty(t, pendingRef) + + var token string + for k := range redis.data { + if len(k) > 16 && k[:16] == "enchanted_token:" { + token = k[16:] + break + } + } + assert.NotEmpty(t, token) + + verifyBody, _ := json.Marshal(map[string]interface{}{ + "token": token, + "verifyOnly": true, + }) + req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody)) + req.Header.Set("Content-Type", "application/json") + resp, _ = app.Test(req, -1) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + pollBody, _ := json.Marshal(map[string]string{ + "client_id": "trusted-rp", + "client_assertion": mustHeadlessClientAssertion(t, privateKey, "trusted-rp", "http://example.com/api/v1/auth/headless/link/poll"), + "pendingRef": pendingRef, + }) + req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/poll", bytes.NewReader(pollBody)) + req.Header.Set("Content-Type", "application/json") + resp, _ = app.Test(req, -1) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + var pollResp map[string]interface{} + _ = json.NewDecoder(resp.Body).Decode(&pollResp) + assert.Equal(t, "http://rp/cb", pollResp["redirectTo"]) + assert.Equal(t, "ok", pollResp["status"]) +} diff --git a/backend/internal/handler/auth_handler_login_test.go b/backend/internal/handler/auth_handler_login_test.go index 6741696c..4cdf17c5 100644 --- a/backend/internal/handler/auth_handler_login_test.go +++ b/backend/internal/handler/auth_handler_login_test.go @@ -10,6 +10,8 @@ import ( "baron-sso-backend/internal/service" "bytes" "context" + "crypto/rand" + "crypto/rsa" "encoding/json" "errors" "io" @@ -17,7 +19,10 @@ import ( "net/http/httptest" "strings" "testing" + "time" + "github.com/go-jose/go-jose/v4" + josejwt "github.com/go-jose/go-jose/v4/jwt" "github.com/gofiber/fiber/v2" "github.com/stretchr/testify/mock" ) @@ -121,6 +126,79 @@ func newAuthLoginTestApp(h *AuthHandler) *fiber.App { return app } +func newHeadlessPasswordLoginTestApp(h *AuthHandler) *fiber.App { + app := fiber.New() + app.Post("/api/v1/auth/headless/password/login", h.HeadlessPasswordLogin) + return app +} + +func mustHeadlessRSAJWK(t *testing.T) (*rsa.PrivateKey, map[string]any) { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate rsa key: %v", err) + } + + keySet := jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + Key: &privateKey.PublicKey, + KeyID: "test-kid", + Use: "sig", + Algorithm: string(jose.RS256), + }, + }, + } + + raw, err := json.Marshal(keySet) + if err != nil { + t.Fatalf("failed to marshal jwks: %v", err) + } + + var jwks map[string]any + if err := json.Unmarshal(raw, &jwks); err != nil { + t.Fatalf("failed to decode jwks map: %v", err) + } + + return privateKey, jwks +} + +func mustHeadlessClientAssertion(t *testing.T, privateKey *rsa.PrivateKey, clientID, audience string) string { + t.Helper() + + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.RS256, + Key: jose.JSONWebKey{ + Key: privateKey, + KeyID: "test-kid", + Use: "sig", + Algorithm: string(jose.RS256), + }, + }, nil) + if err != nil { + t.Fatalf("failed to create signer: %v", err) + } + + now := time.Now() + raw, err := josejwt.Signed(signer).Claims(josejwt.Claims{ + Issuer: clientID, + Subject: clientID, + Audience: josejwt.Audience{audience}, + Expiry: josejwt.NewNumericDate(now.Add(5 * time.Minute)), + IssuedAt: josejwt.NewNumericDate(now), + NotBefore: josejwt.NewNumericDate( + now.Add(-1 * time.Minute), + ), + ID: "assertion-1", + }).Serialize() + if err != nil { + t.Fatalf("failed to sign client assertion: %v", err) + } + + return raw +} + // mockHydraTransport simulates Hydra API responses func mockHydraTransport(handler http.Handler) http.RoundTripper { return roundTripFunc(func(req *http.Request) (*http.Response, error) { @@ -206,6 +284,342 @@ func TestPasswordLogin_OIDC_Success(t *testing.T) { } } +func TestHeadlessPasswordLogin_TrustedClientSuccess(t *testing.T) { + mockIdp := new(MockIdentityProvider) + mockIdp.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{ + SessionToken: &domain.Token{JWT: "valid-jwt"}, + Subject: "kratos-identity-id", + }, nil) + + privateKey, jwks := mustHeadlessRSAJWK(t) + jwksBody, _ := json.Marshal(jwks) + jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(jwksBody) + })) + defer jwksServer.Close() + + hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet: + json.NewEncoder(w).Encode(domain.HydraLoginRequest{ + Challenge: "challenge-123", + Client: domain.HydraClient{ + ClientID: "trusted-rp", + TokenEndpointAuthMethod: "private_key_jwt", + JWKSUri: jwksServer.URL + "/.well-known/jwks.json", + Metadata: map[string]interface{}{ + "status": "active", + "headless_login_enabled": true, + }, + }, + }) + case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut: + json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"}) + default: + http.NotFound(w, r) + } + }) + + mockKratos := new(MockKratosAdminService) + mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil) + + h := &AuthHandler{ + IdpProvider: mockIdp, + KratosAdmin: mockKratos, + Hydra: &service.HydraAdminService{ + AdminURL: "http://hydra.test", + HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)}, + }, + } + + app := newHeadlessPasswordLoginTestApp(h) + + clientAssertion := mustHeadlessClientAssertion( + t, + privateKey, + "trusted-rp", + "http://example.com/api/v1/auth/headless/password/login", + ) + body, _ := json.Marshal(map[string]string{ + "client_id": "trusted-rp", + "client_assertion": clientAssertion, + "loginId": "employee001", + "password": "password", + "login_challenge": "challenge-123", + }) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp, err := app.Test(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 200, got %d, body: %s", resp.StatusCode, string(bodyBytes)) + } + + var got map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&got); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if got["redirectTo"] != "http://rp/cb" { + t.Fatalf("expected redirectTo http://rp/cb, got %v", got["redirectTo"]) + } + if _, ok := got["sessionJwt"]; ok { + t.Fatalf("expected headless response to omit sessionJwt, got %v", got["sessionJwt"]) + } +} + +func TestHeadlessPasswordLogin_MissingClientAssertionRejected(t *testing.T) { + mockIdp := new(MockIdentityProvider) + mockIdp.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{ + SessionToken: &domain.Token{JWT: "valid-jwt"}, + Subject: "kratos-identity-id", + }, nil) + mockKratos := new(MockKratosAdminService) + mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil) + + hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet: + json.NewEncoder(w).Encode(domain.HydraLoginRequest{ + Challenge: "challenge-123", + Client: domain.HydraClient{ + ClientID: "trusted-rp", + TokenEndpointAuthMethod: "private_key_jwt", + JWKS: map[string]any{ + "keys": []map[string]any{}, + }, + Metadata: map[string]interface{}{ + "status": "active", + "headless_login_enabled": true, + }, + }, + }) + return + case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut: + json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"}) + return + } + http.NotFound(w, r) + }) + + h := &AuthHandler{ + IdpProvider: mockIdp, + KratosAdmin: mockKratos, + Hydra: &service.HydraAdminService{ + AdminURL: "http://hydra.test", + HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)}, + }, + } + + app := newHeadlessPasswordLoginTestApp(h) + + body, _ := json.Marshal(map[string]string{ + "client_id": "trusted-rp", + "loginId": "employee001", + "password": "password", + "login_challenge": "challenge-123", + }) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp, err := app.Test(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 400, got %d, body: %s", resp.StatusCode, string(bodyBytes)) + } +} + +func TestHeadlessPasswordLogin_InvalidClientAssertionRejected(t *testing.T) { + mockIdp := new(MockIdentityProvider) + mockIdp.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{ + SessionToken: &domain.Token{JWT: "valid-jwt"}, + Subject: "kratos-identity-id", + }, nil) + mockKratos := new(MockKratosAdminService) + mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil) + + validKey, jwks := mustHeadlessRSAJWK(t) + invalidKey, _ := mustHeadlessRSAJWK(t) + _ = validKey + + hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet: + json.NewEncoder(w).Encode(domain.HydraLoginRequest{ + Challenge: "challenge-123", + Client: domain.HydraClient{ + ClientID: "trusted-rp", + TokenEndpointAuthMethod: "private_key_jwt", + JWKS: jwks, + Metadata: map[string]interface{}{ + "status": "active", + "headless_login_enabled": true, + }, + }, + }) + return + case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut: + json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"}) + return + } + http.NotFound(w, r) + }) + + h := &AuthHandler{ + IdpProvider: mockIdp, + KratosAdmin: mockKratos, + Hydra: &service.HydraAdminService{ + AdminURL: "http://hydra.test", + HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)}, + }, + } + + app := newHeadlessPasswordLoginTestApp(h) + + clientAssertion := mustHeadlessClientAssertion( + t, + invalidKey, + "trusted-rp", + "http://example.com/api/v1/auth/headless/password/login", + ) + body, _ := json.Marshal(map[string]string{ + "client_id": "trusted-rp", + "client_assertion": clientAssertion, + "loginId": "employee001", + "password": "password", + "login_challenge": "challenge-123", + }) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp, err := app.Test(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusUnauthorized { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 401, got %d, body: %s", resp.StatusCode, string(bodyBytes)) + } +} + +func TestHeadlessPasswordLogin_HeadlessDisabledRejected(t *testing.T) { + mockIdp := new(MockIdentityProvider) + + hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet { + json.NewEncoder(w).Encode(domain.HydraLoginRequest{ + Challenge: "challenge-123", + Client: domain.HydraClient{ + ClientID: "trusted-rp", + TokenEndpointAuthMethod: "private_key_jwt", + JWKSUri: "https://rp.example.com/.well-known/jwks.json", + Metadata: map[string]interface{}{ + "status": "active", + }, + }, + }) + return + } + http.NotFound(w, r) + }) + + h := &AuthHandler{ + IdpProvider: mockIdp, + Hydra: &service.HydraAdminService{ + AdminURL: "http://hydra.test", + HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)}, + }, + } + + app := newHeadlessPasswordLoginTestApp(h) + + body, _ := json.Marshal(map[string]string{ + "client_id": "trusted-rp", + "loginId": "employee001", + "password": "password", + "login_challenge": "challenge-123", + }) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp, err := app.Test(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusForbidden { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 403, got %d, body: %s", resp.StatusCode, string(bodyBytes)) + } +} + +func TestHeadlessPasswordLogin_ClientIDMismatchRejected(t *testing.T) { + mockIdp := new(MockIdentityProvider) + + hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet { + json.NewEncoder(w).Encode(domain.HydraLoginRequest{ + Challenge: "challenge-123", + Client: domain.HydraClient{ + ClientID: "other-rp", + TokenEndpointAuthMethod: "private_key_jwt", + JWKSUri: "https://rp.example.com/.well-known/jwks.json", + Metadata: map[string]interface{}{ + "status": "active", + "headless_login_enabled": true, + }, + }, + }) + return + } + http.NotFound(w, r) + }) + + h := &AuthHandler{ + IdpProvider: mockIdp, + Hydra: &service.HydraAdminService{ + AdminURL: "http://hydra.test", + HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)}, + }, + } + + app := newHeadlessPasswordLoginTestApp(h) + + body, _ := json.Marshal(map[string]string{ + "client_id": "trusted-rp", + "loginId": "employee001", + "password": "password", + "login_challenge": "challenge-123", + }) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp, err := app.Test(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusForbidden { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 403, got %d, body: %s", resp.StatusCode, string(bodyBytes)) + } +} + func TestPasswordLogin_OIDC_InactiveClient(t *testing.T) { mockIdp := new(MockIdentityProvider) mockIdp.On("SignIn", "user@example.com", "password").Return(&domain.AuthInfo{