forked from baron/baron-sso
feat(auth): add trusted rp headless login flows
This commit is contained in:
@@ -526,6 +526,9 @@ func main() {
|
||||
auth.Post("/login/code/verify", authHandler.VerifyLoginCode)
|
||||
auth.Post("/login/code/verify-short", authHandler.VerifyLoginShortCode)
|
||||
auth.Post("/password/login", authHandler.PasswordLogin)
|
||||
auth.Post("/headless/password/login", authHandler.HeadlessPasswordLogin)
|
||||
auth.Post("/headless/link/init", authHandler.HeadlessLinkInit)
|
||||
auth.Post("/headless/link/poll", authHandler.HeadlessLinkPoll)
|
||||
auth.Get("/tenant-info", authHandler.GetTenantInfo)
|
||||
auth.Get("/consent", authHandler.GetConsentRequest)
|
||||
auth.Post("/consent/accept", authHandler.AcceptConsentRequest)
|
||||
|
||||
@@ -25,6 +25,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
josejwt "github.com/go-jose/go-jose/v4/jwt"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
@@ -46,6 +48,7 @@ const (
|
||||
prefixLoginCodeSmsOnly = "login_code_sms_only:"
|
||||
prefixLoginCodeQrPending = "login_code_qr_pending:"
|
||||
prefixLoginCodeQr = "login_code_qr:"
|
||||
prefixHeadlessLinkState = "headless_link_state:"
|
||||
prefixPollMeta = "poll_meta:"
|
||||
prefixQrRef = "qr_ref:"
|
||||
prefixQrMeta = "qr_meta:"
|
||||
@@ -75,6 +78,7 @@ const (
|
||||
loginCodeExpiration = 10 * time.Minute
|
||||
linkResendCooldown = 60 * time.Second
|
||||
prefixDrySend = "dry_send:"
|
||||
headlessJWKSFetchTTL = 5 * time.Second
|
||||
)
|
||||
|
||||
type AuthHandler struct {
|
||||
@@ -100,6 +104,40 @@ type signupState struct {
|
||||
ExpiresAt int64 `json:"expires_at"` // Unix timestamp
|
||||
}
|
||||
|
||||
type headlessLinkState struct {
|
||||
ClientID string `json:"clientId"`
|
||||
LoginChallenge string `json:"loginChallenge"`
|
||||
LoginID string `json:"loginId"`
|
||||
RedirectTo string `json:"redirectTo,omitempty"`
|
||||
}
|
||||
|
||||
type headlessClientAssertionClaims struct {
|
||||
Issuer string `json:"iss"`
|
||||
Subject string `json:"sub"`
|
||||
Audience headlessAssertionAud `json:"aud"`
|
||||
ExpiresAt int64 `json:"exp"`
|
||||
IssuedAt int64 `json:"iat,omitempty"`
|
||||
NotBefore int64 `json:"nbf,omitempty"`
|
||||
ID string `json:"jti,omitempty"`
|
||||
}
|
||||
|
||||
type headlessAssertionAud []string
|
||||
|
||||
func (a *headlessAssertionAud) UnmarshalJSON(data []byte) error {
|
||||
var single string
|
||||
if err := json.Unmarshal(data, &single); err == nil {
|
||||
*a = []string{single}
|
||||
return nil
|
||||
}
|
||||
|
||||
var list []string
|
||||
if err := json.Unmarshal(data, &list); err != nil {
|
||||
return err
|
||||
}
|
||||
*a = list
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateSecureToken - Helper to generate secure random strings
|
||||
func GenerateSecureToken(length int) string {
|
||||
b := make([]byte, length)
|
||||
@@ -1224,87 +1262,9 @@ func (h *AuthHandler) PollEnchantedLink(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
if data["status"] == "approved" {
|
||||
loginID := data["loginId"]
|
||||
if loginID == "" {
|
||||
loginID = data["login_id"]
|
||||
}
|
||||
if loginID == "" {
|
||||
slog.Warn("[Poll] Approved but missing loginId", "pendingRef", req.PendingRef)
|
||||
return errorJSON(c, fiber.StatusBadRequest, "Invalid session reference")
|
||||
}
|
||||
if h.IdpProvider == nil {
|
||||
return errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
|
||||
}
|
||||
|
||||
loginStrategy := h.loadLoginStrategy(req.PendingRef)
|
||||
if loginStrategy == "" {
|
||||
loginStrategy = loginFlowLink
|
||||
}
|
||||
|
||||
var authInfo *domain.AuthInfo
|
||||
var err error
|
||||
if loginStrategy == loginFlowCode {
|
||||
code, _ := h.RedisService.Get(prefixLoginCodeValue + req.PendingRef)
|
||||
code = normalizeLoginCode(code)
|
||||
if code == "" {
|
||||
slog.Warn("[Poll] Missing login code for approved flow", "pendingRef", req.PendingRef)
|
||||
return errorJSON(c, fiber.StatusBadRequest, "Login code expired")
|
||||
}
|
||||
flowID, _ := h.RedisService.Get(prefixLoginCode + loginID)
|
||||
if flowID == "" {
|
||||
return errorJSON(c, fiber.StatusNotFound, "Login flow expired")
|
||||
}
|
||||
authInfo, err = h.IdpProvider.VerifyLoginCode(loginID, flowID, code)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
return errorJSON(c, fiber.StatusNotImplemented, "Login method not supported")
|
||||
}
|
||||
slog.Error("[Poll] IDP code verify failed", "error", err)
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "Failed to verify login code")
|
||||
}
|
||||
} else {
|
||||
authInfo, err = h.IdpProvider.IssueSession(loginID)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
return errorJSON(c, fiber.StatusNotImplemented, "Login method not supported")
|
||||
}
|
||||
slog.Error("[Poll] IDP session issue failed", "error", err)
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session")
|
||||
}
|
||||
}
|
||||
if authInfo == nil || authInfo.SessionToken == nil || authInfo.SessionToken.JWT == "" {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session")
|
||||
}
|
||||
|
||||
c.Locals("login_id", loginID)
|
||||
setSessionIDLocal(c, authInfo.SessionToken)
|
||||
sessionID := extractSessionIDFromToken(authInfo.SessionToken)
|
||||
if sessionID == "" && authInfo.SessionToken != nil && authInfo.SessionToken.JWT != "" {
|
||||
if resolved, err := h.getKratosSessionID(authInfo.SessionToken.JWT); err == nil && resolved != "" {
|
||||
sessionID = resolved
|
||||
authInfo.SessionToken.SessionID = resolved
|
||||
setSessionIDLocal(c, authInfo.SessionToken)
|
||||
}
|
||||
}
|
||||
|
||||
sessionData := map[string]string{
|
||||
"status": statusSuccess,
|
||||
"jwt": authInfo.SessionToken.JWT,
|
||||
}
|
||||
if sessionID != "" {
|
||||
sessionData["session_id"] = sessionID
|
||||
}
|
||||
sessionDataJSON, _ := json.Marshal(sessionData)
|
||||
h.RedisService.Set(prefixSession+req.PendingRef, string(sessionDataJSON), defaultExpiration)
|
||||
|
||||
h.writeLinkAuditLog(loginID, req.PendingRef, authInfo.SessionToken, c)
|
||||
h.clearLoginMeta(req.PendingRef)
|
||||
if loginStrategy == loginFlowCode {
|
||||
h.RedisService.Delete(prefixLoginCode + loginID)
|
||||
h.RedisService.Delete(prefixLoginCodePending + loginID)
|
||||
h.RedisService.Delete(prefixLoginCodeSmsTarget + loginID)
|
||||
h.RedisService.Delete(prefixLoginCodeSmsLookup + loginID)
|
||||
h.RedisService.Delete(prefixLoginCodeValue + req.PendingRef)
|
||||
_, authInfo, err := h.completeApprovedLinkLogin(c, req.PendingRef)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
@@ -1671,6 +1631,617 @@ func logOidcRedirectSummary(source, redirectTo string) {
|
||||
)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) authenticatePasswordLogin(ctx context.Context, loginID, password string) (*domain.AuthInfo, error) {
|
||||
if h.IdpProvider == nil {
|
||||
return nil, fmt.Errorf("authentication service not configured")
|
||||
}
|
||||
|
||||
authInfo, err := h.IdpProvider.SignIn(loginID, password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
subject, resolveErr := h.resolveKratosIdentityIDFromLoginID(ctx, loginID)
|
||||
if resolveErr != nil || subject == "" {
|
||||
slog.Error("Failed to resolve kratos identity after login", "loginID", loginID, "error", resolveErr)
|
||||
return nil, fmt.Errorf("failed to resolve user identity")
|
||||
}
|
||||
|
||||
authInfo.Subject = subject
|
||||
return authInfo, nil
|
||||
}
|
||||
|
||||
func passwordLoginErrorSpec(err error) (int, string, string) {
|
||||
if err == nil {
|
||||
return fiber.StatusOK, "", ""
|
||||
}
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
return fiber.StatusNotImplemented, "not_supported", "Login method not supported"
|
||||
}
|
||||
if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "identity") {
|
||||
return fiber.StatusNotFound, "not_found", "User not registered"
|
||||
}
|
||||
if strings.Contains(err.Error(), "failed to resolve user identity") {
|
||||
return fiber.StatusInternalServerError, "internal_error", "Failed to resolve user identity"
|
||||
}
|
||||
return fiber.StatusUnauthorized, "password_or_email_mismatch", "Invalid credentials"
|
||||
}
|
||||
|
||||
func headlessAssertionAudiences(c *fiber.Ctx) []string {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
path := strings.TrimSpace(c.Path())
|
||||
if path == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
base := strings.TrimRight(strings.TrimSpace(c.BaseURL()), "/")
|
||||
if base == "" {
|
||||
return []string{path}
|
||||
}
|
||||
|
||||
return []string{base + path, path}
|
||||
}
|
||||
|
||||
func containsHeadlessAudience(expected []string, actual headlessAssertionAud) bool {
|
||||
for _, audience := range actual {
|
||||
for _, candidate := range expected {
|
||||
if strings.TrimSpace(audience) == strings.TrimSpace(candidate) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *AuthHandler) loadHeadlessJWKS(ctx context.Context, client domain.HydraClient) (*jose.JSONWebKeySet, error) {
|
||||
var raw []byte
|
||||
switch {
|
||||
case client.JWKS != nil:
|
||||
data, err := json.Marshal(client.JWKS)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode jwks: %w", err)
|
||||
}
|
||||
raw = data
|
||||
case strings.TrimSpace(client.JWKSUri) != "":
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, strings.TrimSpace(client.JWKSUri), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build jwks request: %w", err)
|
||||
}
|
||||
client := &http.Client{Timeout: headlessJWKSFetchTTL}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch jwks: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return nil, fmt.Errorf("failed to fetch jwks status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read jwks response: %w", err)
|
||||
}
|
||||
raw = body
|
||||
default:
|
||||
return nil, fmt.Errorf("trusted rp public key is not configured")
|
||||
}
|
||||
|
||||
var keySet jose.JSONWebKeySet
|
||||
if err := json.Unmarshal(raw, &keySet); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode jwks: %w", err)
|
||||
}
|
||||
if len(keySet.Keys) == 0 {
|
||||
return nil, fmt.Errorf("trusted rp jwks has no keys")
|
||||
}
|
||||
return &keySet, nil
|
||||
}
|
||||
|
||||
func validateHeadlessClientAssertionClaims(c *fiber.Ctx, claims headlessClientAssertionClaims, clientID string) error {
|
||||
now := time.Now().Unix()
|
||||
if claims.Issuer != clientID || claims.Subject != clientID {
|
||||
return fmt.Errorf("client assertion iss/sub mismatch")
|
||||
}
|
||||
if claims.ExpiresAt == 0 || claims.ExpiresAt <= now {
|
||||
return fmt.Errorf("client assertion expired")
|
||||
}
|
||||
if claims.NotBefore != 0 && claims.NotBefore > now {
|
||||
return fmt.Errorf("client assertion not active yet")
|
||||
}
|
||||
if claims.IssuedAt != 0 && claims.IssuedAt > now+60 {
|
||||
return fmt.Errorf("client assertion issued in the future")
|
||||
}
|
||||
if !containsHeadlessAudience(headlessAssertionAudiences(c), claims.Audience) {
|
||||
return fmt.Errorf("client assertion audience mismatch")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.HydraClient, clientID, clientAssertion string) error {
|
||||
assertion := strings.TrimSpace(clientAssertion)
|
||||
if assertion == "" {
|
||||
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_assertion is required")
|
||||
}
|
||||
|
||||
keySet, err := h.loadHeadlessJWKS(c.Context(), client)
|
||||
if err != nil {
|
||||
slog.Error("failed to load jwks for headless client assertion", "clientID", clientID, "error", err)
|
||||
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
|
||||
}
|
||||
|
||||
token, err := josejwt.ParseSigned(assertion, []jose.SignatureAlgorithm{
|
||||
jose.RS256, jose.RS384, jose.RS512,
|
||||
jose.PS256, jose.PS384, jose.PS512,
|
||||
jose.ES256, jose.ES384, jose.ES512,
|
||||
jose.EdDSA,
|
||||
})
|
||||
if err != nil {
|
||||
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
|
||||
}
|
||||
|
||||
expectedKid := ""
|
||||
if len(token.Headers) > 0 {
|
||||
expectedKid = strings.TrimSpace(token.Headers[0].KeyID)
|
||||
}
|
||||
|
||||
for _, key := range keySet.Keys {
|
||||
if expectedKid != "" && key.KeyID != "" && key.KeyID != expectedKid {
|
||||
continue
|
||||
}
|
||||
|
||||
var claims headlessClientAssertionClaims
|
||||
if err := token.Claims(key.Key, &claims); err != nil {
|
||||
continue
|
||||
}
|
||||
if err := validateHeadlessClientAssertionClaims(c, claims, clientID); err != nil {
|
||||
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
|
||||
}
|
||||
|
||||
func (h *AuthHandler) storeHeadlessLinkState(pendingRef string, state headlessLinkState, ttl time.Duration) {
|
||||
if h.RedisService == nil || pendingRef == "" {
|
||||
return
|
||||
}
|
||||
raw, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = h.RedisService.Set(prefixHeadlessLinkState+pendingRef, string(raw), ttl)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) loadHeadlessLinkState(pendingRef string) (headlessLinkState, bool) {
|
||||
if h.RedisService == nil || pendingRef == "" {
|
||||
return headlessLinkState{}, false
|
||||
}
|
||||
raw, err := h.RedisService.Get(prefixHeadlessLinkState + pendingRef)
|
||||
if err != nil || raw == "" {
|
||||
return headlessLinkState{}, false
|
||||
}
|
||||
var state headlessLinkState
|
||||
if err := json.Unmarshal([]byte(raw), &state); err != nil {
|
||||
return headlessLinkState{}, false
|
||||
}
|
||||
return state, true
|
||||
}
|
||||
|
||||
func (h *AuthHandler) completeApprovedLinkLogin(c *fiber.Ctx, pendingRef string) (string, *domain.AuthInfo, error) {
|
||||
val, err := h.RedisService.Get(prefixSession + pendingRef)
|
||||
if err != nil || val == "" {
|
||||
return "", nil, errorJSON(c, fiber.StatusBadRequest, "Invalid session reference")
|
||||
}
|
||||
|
||||
var data map[string]string
|
||||
_ = json.Unmarshal([]byte(val), &data)
|
||||
loginID := data["loginId"]
|
||||
if loginID == "" {
|
||||
loginID = data["login_id"]
|
||||
}
|
||||
if loginID == "" {
|
||||
slog.Warn("[Poll] Approved but missing loginId", "pendingRef", pendingRef)
|
||||
return "", nil, errorJSON(c, fiber.StatusBadRequest, "Invalid session reference")
|
||||
}
|
||||
if h.IdpProvider == nil {
|
||||
return "", nil, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
|
||||
}
|
||||
|
||||
loginStrategy := h.loadLoginStrategy(pendingRef)
|
||||
if loginStrategy == "" {
|
||||
loginStrategy = loginFlowLink
|
||||
}
|
||||
|
||||
var authInfo *domain.AuthInfo
|
||||
if loginStrategy == loginFlowCode {
|
||||
code, _ := h.RedisService.Get(prefixLoginCodeValue + pendingRef)
|
||||
code = normalizeLoginCode(code)
|
||||
if code == "" {
|
||||
slog.Warn("[Poll] Missing login code for approved flow", "pendingRef", pendingRef)
|
||||
return "", nil, errorJSON(c, fiber.StatusBadRequest, "Login code expired")
|
||||
}
|
||||
flowID, _ := h.RedisService.Get(prefixLoginCode + loginID)
|
||||
if flowID == "" {
|
||||
return "", nil, errorJSON(c, fiber.StatusNotFound, "Login flow expired")
|
||||
}
|
||||
authInfo, err = h.IdpProvider.VerifyLoginCode(loginID, flowID, code)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
return "", nil, errorJSON(c, fiber.StatusNotImplemented, "Login method not supported")
|
||||
}
|
||||
slog.Error("[Poll] IDP code verify failed", "error", err)
|
||||
return "", nil, errorJSON(c, fiber.StatusInternalServerError, "Failed to verify login code")
|
||||
}
|
||||
} else {
|
||||
authInfo, err = h.IdpProvider.IssueSession(loginID)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
return "", nil, errorJSON(c, fiber.StatusNotImplemented, "Login method not supported")
|
||||
}
|
||||
slog.Error("[Poll] IDP session issue failed", "error", err)
|
||||
return "", nil, errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session")
|
||||
}
|
||||
}
|
||||
if authInfo == nil || authInfo.SessionToken == nil || authInfo.SessionToken.JWT == "" {
|
||||
return "", nil, errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session")
|
||||
}
|
||||
|
||||
c.Locals("login_id", loginID)
|
||||
setSessionIDLocal(c, authInfo.SessionToken)
|
||||
sessionID := extractSessionIDFromToken(authInfo.SessionToken)
|
||||
if sessionID == "" && authInfo.SessionToken != nil && authInfo.SessionToken.JWT != "" {
|
||||
if resolved, err := h.getKratosSessionID(authInfo.SessionToken.JWT); err == nil && resolved != "" {
|
||||
sessionID = resolved
|
||||
authInfo.SessionToken.SessionID = resolved
|
||||
setSessionIDLocal(c, authInfo.SessionToken)
|
||||
}
|
||||
}
|
||||
|
||||
sessionData := map[string]string{
|
||||
"status": statusSuccess,
|
||||
"jwt": authInfo.SessionToken.JWT,
|
||||
}
|
||||
if sessionID != "" {
|
||||
sessionData["session_id"] = sessionID
|
||||
}
|
||||
sessionDataJSON, _ := json.Marshal(sessionData)
|
||||
_ = h.RedisService.Set(prefixSession+pendingRef, string(sessionDataJSON), defaultExpiration)
|
||||
|
||||
h.writeLinkAuditLog(loginID, pendingRef, authInfo.SessionToken, c)
|
||||
h.clearLoginMeta(pendingRef)
|
||||
if loginStrategy == loginFlowCode {
|
||||
_ = h.RedisService.Delete(prefixLoginCode + loginID)
|
||||
_ = h.RedisService.Delete(prefixLoginCodePending + loginID)
|
||||
_ = h.RedisService.Delete(prefixLoginCodeSmsTarget + loginID)
|
||||
_ = h.RedisService.Delete(prefixLoginCodeSmsLookup + loginID)
|
||||
_ = h.RedisService.Delete(prefixLoginCodeValue + pendingRef)
|
||||
}
|
||||
|
||||
return loginID, authInfo, nil
|
||||
}
|
||||
|
||||
func (h *AuthHandler) validateHeadlessPasswordLoginClient(loginReq *domain.HydraLoginRequest, clientID string) error {
|
||||
if loginReq == nil {
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request")
|
||||
}
|
||||
|
||||
if strings.TrimSpace(loginReq.Client.ClientID) != strings.TrimSpace(clientID) {
|
||||
return fiber.NewError(fiber.StatusForbidden, "The client application is not allowed to use this login request.")
|
||||
}
|
||||
|
||||
if metadata := loginReq.Client.Metadata; metadata != nil {
|
||||
if status, ok := metadata["status"].(string); ok && strings.ToLower(status) == "inactive" {
|
||||
return fiber.NewError(fiber.StatusForbidden, "The client application is disabled.")
|
||||
}
|
||||
}
|
||||
|
||||
if !loginReq.Client.IsHeadlessLoginEnabled() {
|
||||
return fiber.NewError(fiber.StatusForbidden, "The client application is not allowed to use headless password login.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *AuthHandler) HeadlessPasswordLogin(c *fiber.Ctx) error {
|
||||
var req struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientAssertion string `json:"client_assertion"`
|
||||
LoginID string `json:"loginId"`
|
||||
Password string `json:"password"`
|
||||
LoginChallenge string `json:"login_challenge"`
|
||||
}
|
||||
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "Invalid request body")
|
||||
}
|
||||
|
||||
clientID := strings.TrimSpace(req.ClientID)
|
||||
loginID := strings.TrimSpace(req.LoginID)
|
||||
loginChallenge := strings.TrimSpace(req.LoginChallenge)
|
||||
if clientID == "" || loginID == "" || strings.TrimSpace(req.Password) == "" || loginChallenge == "" {
|
||||
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_id, loginId, password and login_challenge are required")
|
||||
}
|
||||
|
||||
if h.IdpProvider == nil || h.Hydra == nil {
|
||||
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
|
||||
}
|
||||
|
||||
loginReq, err := h.Hydra.GetLoginRequest(c.Context(), loginChallenge)
|
||||
if err != nil {
|
||||
slog.Error("failed to get hydra login request for headless password login", "error", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request")
|
||||
}
|
||||
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
authInfo, authErr := h.authenticatePasswordLogin(c.Context(), loginID, req.Password)
|
||||
if authErr != nil {
|
||||
status, code, message := passwordLoginErrorSpec(authErr)
|
||||
return errorJSONCode(c, status, code, message)
|
||||
}
|
||||
|
||||
setSessionIDLocal(c, authInfo.SessionToken)
|
||||
|
||||
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), loginChallenge, authInfo.Subject)
|
||||
if err != nil {
|
||||
slog.Error("failed to accept hydra login request in headless password login", "error", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request")
|
||||
}
|
||||
|
||||
logOidcRedirectSummary("headless_password_login", acceptResp.RedirectTo)
|
||||
return c.JSON(fiber.Map{
|
||||
"redirectTo": acceptResp.RedirectTo,
|
||||
"status": "ok",
|
||||
"provider": h.IdpProvider.Name(),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) startHeadlessPhoneLink(c *fiber.Ctx, loginID string) (fiber.Map, string, string, time.Duration, error) {
|
||||
rawLoginID := strings.ReplaceAll(loginID, "-", "")
|
||||
rawLoginID = strings.ReplaceAll(rawLoginID, " ", "")
|
||||
if rawLoginID == "" || strings.Contains(rawLoginID, "@") {
|
||||
return nil, "", "", 0, errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "phone-based loginId is required")
|
||||
}
|
||||
|
||||
lookupLoginID := normalizePhoneForLoginID(rawLoginID)
|
||||
if h.IdpProvider == nil {
|
||||
return nil, "", "", 0, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
|
||||
}
|
||||
exists, err := h.IdpProvider.UserExists(lookupLoginID)
|
||||
if err != nil {
|
||||
slog.Warn("[HeadlessLink] IDP user lookup failed", "loginID", rawLoginID, "error", err)
|
||||
return nil, "", "", 0, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
|
||||
}
|
||||
if !exists {
|
||||
slog.Warn("[HeadlessLink] User not found", "loginID", rawLoginID)
|
||||
return nil, "", "", 0, errorJSON(c, fiber.StatusNotFound, "User not registered")
|
||||
}
|
||||
|
||||
userfrontURL := h.resolveUserfrontURL(c)
|
||||
if init, err := h.IdpProvider.InitiateLinkLogin(lookupLoginID, userfrontURL); err == nil && init != nil && init.Mode != "" {
|
||||
keyLoginID := lookupLoginID
|
||||
if init.LoginID != "" {
|
||||
keyLoginID = init.LoginID
|
||||
}
|
||||
if init.FlowID != "" {
|
||||
_ = h.RedisService.Set(prefixLoginCode+keyLoginID, init.FlowID, loginCodeExpiration)
|
||||
}
|
||||
pendingRef := GenerateSecureToken(3)
|
||||
sessionData, _ := json.Marshal(map[string]string{
|
||||
"status": statusPending,
|
||||
"loginId": keyLoginID,
|
||||
})
|
||||
_ = h.RedisService.Set(prefixSession+pendingRef, string(sessionData), loginCodeExpiration)
|
||||
h.storeLoginMeta(pendingRef, rawLoginID, "sms", loginFlowLink, loginFlowCode, loginCodeExpiration)
|
||||
_ = h.RedisService.Set(prefixLoginCodePending+keyLoginID, pendingRef, loginCodeExpiration)
|
||||
if keyLoginID != lookupLoginID {
|
||||
_ = h.RedisService.Set(prefixLoginCodeSmsTarget+keyLoginID, lookupLoginID, loginCodeExpiration)
|
||||
_ = h.RedisService.Set(prefixLoginCodeSmsLookup+lookupLoginID, keyLoginID, loginCodeExpiration)
|
||||
}
|
||||
expiresIn := int(loginCodeExpiration.Seconds())
|
||||
if !init.ExpiresAt.IsZero() {
|
||||
if seconds := int(time.Until(init.ExpiresAt).Seconds()); seconds > 0 {
|
||||
expiresIn = seconds
|
||||
}
|
||||
}
|
||||
return fiber.Map{
|
||||
"pendingRef": pendingRef,
|
||||
"status": "pending",
|
||||
"mode": init.Mode,
|
||||
"provider": h.IdpProvider.Name(),
|
||||
"expiresIn": expiresIn,
|
||||
"interval": int(minPollInterval.Seconds()),
|
||||
"resendAfter": int(linkResendCooldown.Seconds()),
|
||||
}, pendingRef, keyLoginID, loginCodeExpiration, nil
|
||||
} else if err != nil && !errors.Is(err, domain.ErrNotSupported) {
|
||||
slog.Error("[HeadlessLink] Link login init failed", "provider", h.IdpProvider.Name(), "error", err)
|
||||
return nil, "", "", 0, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
|
||||
}
|
||||
|
||||
if h.SmsService == nil {
|
||||
return nil, "", "", 0, errorJSON(c, fiber.StatusInternalServerError, "SMS service not configured")
|
||||
}
|
||||
|
||||
token := GenerateSecureToken(3)
|
||||
pendingRef := GenerateSecureToken(3)
|
||||
sessionData, _ := json.Marshal(map[string]string{
|
||||
"status": statusPending,
|
||||
"loginId": lookupLoginID,
|
||||
})
|
||||
_ = h.RedisService.Set(prefixSession+pendingRef, string(sessionData), defaultExpiration)
|
||||
_ = h.RedisService.Set(prefixToken+token, fmt.Sprintf(`{"pendingRef":"%s","loginId":"%s"}`, pendingRef, lookupLoginID), defaultExpiration)
|
||||
h.storeLoginMeta(pendingRef, rawLoginID, "sms", loginFlowLink, loginFlowLink, defaultExpiration)
|
||||
|
||||
link := fmt.Sprintf("%s/verify/%s", strings.TrimRight(userfrontURL, "/"), token)
|
||||
content := fmt.Sprintf("[Baron 로그인] 로그인 링크: %s", link)
|
||||
if err := h.SmsService.SendSms(rawLoginID, content); err != nil {
|
||||
slog.Error("[HeadlessLink] SMS send failed", "error", err)
|
||||
return nil, "", "", 0, errorJSON(c, fiber.StatusInternalServerError, "Failed to send SMS")
|
||||
}
|
||||
|
||||
return fiber.Map{
|
||||
"pendingRef": pendingRef,
|
||||
"status": "pending",
|
||||
"provider": h.IdpProvider.Name(),
|
||||
"expiresIn": int(defaultExpiration.Seconds()),
|
||||
"interval": int(minPollInterval.Seconds()),
|
||||
"resendAfter": int(linkResendCooldown.Seconds()),
|
||||
}, pendingRef, lookupLoginID, defaultExpiration, nil
|
||||
}
|
||||
|
||||
func (h *AuthHandler) HeadlessLinkInit(c *fiber.Ctx) error {
|
||||
var req struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientAssertion string `json:"client_assertion"`
|
||||
LoginID string `json:"loginId"`
|
||||
LoginChallenge string `json:"login_challenge"`
|
||||
}
|
||||
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "Invalid request body")
|
||||
}
|
||||
|
||||
clientID := strings.TrimSpace(req.ClientID)
|
||||
loginChallenge := strings.TrimSpace(req.LoginChallenge)
|
||||
loginID := strings.TrimSpace(req.LoginID)
|
||||
if clientID == "" || loginChallenge == "" || loginID == "" {
|
||||
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_id, client_assertion, loginId and login_challenge are required")
|
||||
}
|
||||
if h.Hydra == nil {
|
||||
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
|
||||
}
|
||||
|
||||
loginReq, err := h.Hydra.GetLoginRequest(c.Context(), loginChallenge)
|
||||
if err != nil {
|
||||
slog.Error("failed to get hydra login request for headless link init", "error", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request")
|
||||
}
|
||||
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, pendingRef, resolvedLoginID, ttl, err := h.startHeadlessPhoneLink(c, loginID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.storeHeadlessLinkState(pendingRef, headlessLinkState{
|
||||
ClientID: clientID,
|
||||
LoginChallenge: loginChallenge,
|
||||
LoginID: resolvedLoginID,
|
||||
}, ttl)
|
||||
return c.JSON(resp)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) HeadlessLinkPoll(c *fiber.Ctx) error {
|
||||
var req struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientAssertion string `json:"client_assertion"`
|
||||
PendingRef string `json:"pendingRef"`
|
||||
}
|
||||
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "Invalid request body")
|
||||
}
|
||||
|
||||
clientID := strings.TrimSpace(req.ClientID)
|
||||
pendingRef := strings.TrimSpace(req.PendingRef)
|
||||
if clientID == "" || pendingRef == "" {
|
||||
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_id, client_assertion and pendingRef are required")
|
||||
}
|
||||
|
||||
state, ok := h.loadHeadlessLinkState(pendingRef)
|
||||
if !ok {
|
||||
return c.JSON(fiber.Map{
|
||||
"error": "expired_token",
|
||||
"code": "expired_token",
|
||||
})
|
||||
}
|
||||
if state.ClientID != clientID {
|
||||
return fiber.NewError(fiber.StatusForbidden, "The client application is not allowed to use this pending login.")
|
||||
}
|
||||
if state.RedirectTo != "" {
|
||||
return c.JSON(fiber.Map{
|
||||
"redirectTo": state.RedirectTo,
|
||||
"status": "ok",
|
||||
})
|
||||
}
|
||||
if h.Hydra == nil {
|
||||
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
|
||||
}
|
||||
|
||||
loginReq, err := h.Hydra.GetLoginRequest(c.Context(), state.LoginChallenge)
|
||||
if err != nil {
|
||||
slog.Error("failed to get hydra login request for headless link poll", "error", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request")
|
||||
}
|
||||
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val, err := h.RedisService.Get(prefixSession + pendingRef)
|
||||
if err != nil || val == "" {
|
||||
return c.JSON(fiber.Map{
|
||||
"error": "expired_token",
|
||||
"code": "expired_token",
|
||||
})
|
||||
}
|
||||
|
||||
var session map[string]string
|
||||
_ = json.Unmarshal([]byte(val), &session)
|
||||
if session["status"] == statusPending {
|
||||
return c.JSON(fiber.Map{
|
||||
"error": "authorization_pending",
|
||||
"code": "authorization_pending",
|
||||
"interval": int(minPollInterval.Seconds()),
|
||||
})
|
||||
}
|
||||
|
||||
loginID := state.LoginID
|
||||
if session["status"] == "approved" {
|
||||
completedLoginID, _, err := h.completeApprovedLinkLogin(c, pendingRef)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
loginID = completedLoginID
|
||||
}
|
||||
|
||||
if loginID == "" {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "Failed to resolve approved user identity")
|
||||
}
|
||||
subject, err := h.resolveKratosIdentityIDFromLoginID(c.Context(), loginID)
|
||||
if err != nil || subject == "" {
|
||||
slog.Error("failed to resolve kratos identity for headless link poll", "loginID", loginID, "error", err)
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "Failed to resolve user identity")
|
||||
}
|
||||
|
||||
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), state.LoginChallenge, subject)
|
||||
if err != nil {
|
||||
slog.Error("failed to accept hydra login request in headless link poll", "error", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request")
|
||||
}
|
||||
|
||||
state.RedirectTo = acceptResp.RedirectTo
|
||||
h.storeHeadlessLinkState(pendingRef, state, defaultExpiration)
|
||||
logOidcRedirectSummary("headless_link_poll", acceptResp.RedirectTo)
|
||||
return c.JSON(fiber.Map{
|
||||
"redirectTo": acceptResp.RedirectTo,
|
||||
"status": "ok",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
|
||||
startTime := time.Now()
|
||||
ale := logger.NewAuditLogEntry(c, "login")
|
||||
@@ -1704,28 +2275,16 @@ func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
|
||||
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
|
||||
}
|
||||
|
||||
authInfo, err := h.IdpProvider.SignIn(loginID, req.Password)
|
||||
authInfo, err := h.authenticatePasswordLogin(c.Context(), loginID, req.Password)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
return errorJSONCode(c, fiber.StatusNotImplemented, "not_supported", "Login method not supported")
|
||||
}
|
||||
ale.Status = fiber.StatusUnauthorized
|
||||
status, code, message := passwordLoginErrorSpec(err)
|
||||
ale.Status = status
|
||||
ale.LatencyMs = time.Since(startTime)
|
||||
ale.ProviderError = err.Error()
|
||||
ale.Log(slog.LevelWarn, "IDP sign-in failed", slog.String("provider", h.IdpProvider.Name()))
|
||||
if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "identity") {
|
||||
return errorJSONCode(c, fiber.StatusNotFound, "not_found", "User not registered")
|
||||
}
|
||||
return errorJSONCode(c, fiber.StatusUnauthorized, "password_or_email_mismatch", "Invalid credentials")
|
||||
return errorJSONCode(c, status, code, message)
|
||||
}
|
||||
|
||||
subject, resolveErr := h.resolveKratosIdentityIDFromLoginID(c.Context(), loginID)
|
||||
if resolveErr != nil || subject == "" {
|
||||
slog.Error("Failed to resolve kratos identity after login", "loginID", loginID, "error", resolveErr)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "Failed to resolve user identity")
|
||||
}
|
||||
authInfo.Subject = subject
|
||||
|
||||
ale.Status = fiber.StatusOK
|
||||
ale.LatencyMs = time.Since(startTime)
|
||||
setSessionIDLocal(c, authInfo.SessionToken)
|
||||
@@ -1746,7 +2305,7 @@ func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
|
||||
}
|
||||
}
|
||||
|
||||
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), req.LoginChallenge, subject)
|
||||
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), req.LoginChallenge, authInfo.Subject)
|
||||
if err != nil {
|
||||
slog.Error("failed to accept hydra login request", "error", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request")
|
||||
|
||||
@@ -2,14 +2,17 @@ package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// Mock services
|
||||
@@ -21,6 +24,14 @@ type mockSmsService struct{}
|
||||
|
||||
func (m *mockSmsService) SendSms(to, content string) error { return nil }
|
||||
|
||||
func newHeadlessLinkTestApp(h *AuthHandler) *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/headless/link/init", h.HeadlessLinkInit)
|
||||
app.Post("/api/v1/auth/headless/link/poll", h.HeadlessLinkPoll)
|
||||
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
|
||||
return app
|
||||
}
|
||||
|
||||
func TestEnchantedLinkFlow_Email_Success(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
// Force "Not Supported" for InitiateLinkLogin only to trigger custom Enchanted Link logic
|
||||
@@ -144,3 +155,162 @@ func TestPollEnchantedLink_ExpiredToken_ReturnsCode(t *testing.T) {
|
||||
assert.Equal(t, "expired_token", got["error"])
|
||||
assert.Equal(t, "expired_token", got["code"])
|
||||
}
|
||||
|
||||
func TestHeadlessLinkInit_TrustedClientSuccess(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||
|
||||
idp := &mockIdpProvider{
|
||||
userExists: true,
|
||||
initiateLinkErr: domain.ErrNotSupported,
|
||||
}
|
||||
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet {
|
||||
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||
Challenge: "challenge-123",
|
||||
Client: domain.HydraClient{
|
||||
ClientID: "trusted-rp",
|
||||
TokenEndpointAuthMethod: "private_key_jwt",
|
||||
JWKS: jwks,
|
||||
Metadata: map[string]interface{}{
|
||||
"status": "active",
|
||||
"headless_login_enabled": true,
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: idp,
|
||||
SmsService: &mockSmsService{},
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||
},
|
||||
}
|
||||
|
||||
app := newHeadlessLinkTestApp(h)
|
||||
t.Setenv("USERFRONT_URL", "http://userfront.test")
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"client_id": "trusted-rp",
|
||||
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "trusted-rp", "http://example.com/api/v1/auth/headless/link/init"),
|
||||
"loginId": "010-1234-5678",
|
||||
"login_challenge": "challenge-123",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var got map[string]interface{}
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.NotEmpty(t, got["pendingRef"])
|
||||
_, hasUserCode := got["userCode"]
|
||||
assert.False(t, hasUserCode)
|
||||
}
|
||||
|
||||
func TestHeadlessLinkPoll_AfterApprovalReturnsRedirect(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||
|
||||
idp := &mockIdpProvider{
|
||||
userExists: true,
|
||||
initiateLinkErr: domain.ErrNotSupported,
|
||||
}
|
||||
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
|
||||
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||
Challenge: "challenge-123",
|
||||
Client: domain.HydraClient{
|
||||
ClientID: "trusted-rp",
|
||||
TokenEndpointAuthMethod: "private_key_jwt",
|
||||
JWKS: jwks,
|
||||
Metadata: map[string]interface{}{
|
||||
"status": "active",
|
||||
"headless_login_enabled": true,
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "+821012345678").Return("kratos-identity-id", nil)
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: idp,
|
||||
SmsService: &mockSmsService{},
|
||||
KratosAdmin: mockKratos,
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||
},
|
||||
}
|
||||
|
||||
app := newHeadlessLinkTestApp(h)
|
||||
t.Setenv("USERFRONT_URL", "http://userfront.test")
|
||||
|
||||
initBody, _ := json.Marshal(map[string]string{
|
||||
"client_id": "trusted-rp",
|
||||
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "trusted-rp", "http://example.com/api/v1/auth/headless/link/init"),
|
||||
"loginId": "010-1234-5678",
|
||||
"login_challenge": "challenge-123",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(initBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var initResp map[string]interface{}
|
||||
_ = json.NewDecoder(resp.Body).Decode(&initResp)
|
||||
pendingRef := initResp["pendingRef"].(string)
|
||||
assert.NotEmpty(t, pendingRef)
|
||||
|
||||
var token string
|
||||
for k := range redis.data {
|
||||
if len(k) > 16 && k[:16] == "enchanted_token:" {
|
||||
token = k[16:]
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
verifyBody, _ := json.Marshal(map[string]interface{}{
|
||||
"token": token,
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
pollBody, _ := json.Marshal(map[string]string{
|
||||
"client_id": "trusted-rp",
|
||||
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "trusted-rp", "http://example.com/api/v1/auth/headless/link/poll"),
|
||||
"pendingRef": pendingRef,
|
||||
})
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/poll", bytes.NewReader(pollBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
var pollResp map[string]interface{}
|
||||
_ = json.NewDecoder(resp.Body).Decode(&pollResp)
|
||||
assert.Equal(t, "http://rp/cb", pollResp["redirectTo"])
|
||||
assert.Equal(t, "ok", pollResp["status"])
|
||||
}
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
@@ -17,7 +19,10 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
josejwt "github.com/go-jose/go-jose/v4/jwt"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
@@ -121,6 +126,79 @@ func newAuthLoginTestApp(h *AuthHandler) *fiber.App {
|
||||
return app
|
||||
}
|
||||
|
||||
func newHeadlessPasswordLoginTestApp(h *AuthHandler) *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/headless/password/login", h.HeadlessPasswordLogin)
|
||||
return app
|
||||
}
|
||||
|
||||
func mustHeadlessRSAJWK(t *testing.T) (*rsa.PrivateKey, map[string]any) {
|
||||
t.Helper()
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate rsa key: %v", err)
|
||||
}
|
||||
|
||||
keySet := jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{
|
||||
{
|
||||
Key: &privateKey.PublicKey,
|
||||
KeyID: "test-kid",
|
||||
Use: "sig",
|
||||
Algorithm: string(jose.RS256),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(keySet)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal jwks: %v", err)
|
||||
}
|
||||
|
||||
var jwks map[string]any
|
||||
if err := json.Unmarshal(raw, &jwks); err != nil {
|
||||
t.Fatalf("failed to decode jwks map: %v", err)
|
||||
}
|
||||
|
||||
return privateKey, jwks
|
||||
}
|
||||
|
||||
func mustHeadlessClientAssertion(t *testing.T, privateKey *rsa.PrivateKey, clientID, audience string) string {
|
||||
t.Helper()
|
||||
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.RS256,
|
||||
Key: jose.JSONWebKey{
|
||||
Key: privateKey,
|
||||
KeyID: "test-kid",
|
||||
Use: "sig",
|
||||
Algorithm: string(jose.RS256),
|
||||
},
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create signer: %v", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
raw, err := josejwt.Signed(signer).Claims(josejwt.Claims{
|
||||
Issuer: clientID,
|
||||
Subject: clientID,
|
||||
Audience: josejwt.Audience{audience},
|
||||
Expiry: josejwt.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
IssuedAt: josejwt.NewNumericDate(now),
|
||||
NotBefore: josejwt.NewNumericDate(
|
||||
now.Add(-1 * time.Minute),
|
||||
),
|
||||
ID: "assertion-1",
|
||||
}).Serialize()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to sign client assertion: %v", err)
|
||||
}
|
||||
|
||||
return raw
|
||||
}
|
||||
|
||||
// mockHydraTransport simulates Hydra API responses
|
||||
func mockHydraTransport(handler http.Handler) http.RoundTripper {
|
||||
return roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
@@ -206,6 +284,342 @@ func TestPasswordLogin_OIDC_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadlessPasswordLogin_TrustedClientSuccess(t *testing.T) {
|
||||
mockIdp := new(MockIdentityProvider)
|
||||
mockIdp.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{
|
||||
SessionToken: &domain.Token{JWT: "valid-jwt"},
|
||||
Subject: "kratos-identity-id",
|
||||
}, nil)
|
||||
|
||||
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||
jwksBody, _ := json.Marshal(jwks)
|
||||
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(jwksBody)
|
||||
}))
|
||||
defer jwksServer.Close()
|
||||
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
|
||||
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||
Challenge: "challenge-123",
|
||||
Client: domain.HydraClient{
|
||||
ClientID: "trusted-rp",
|
||||
TokenEndpointAuthMethod: "private_key_jwt",
|
||||
JWKSUri: jwksServer.URL + "/.well-known/jwks.json",
|
||||
Metadata: map[string]interface{}{
|
||||
"status": "active",
|
||||
"headless_login_enabled": true,
|
||||
},
|
||||
},
|
||||
})
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
|
||||
json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil)
|
||||
|
||||
h := &AuthHandler{
|
||||
IdpProvider: mockIdp,
|
||||
KratosAdmin: mockKratos,
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||
},
|
||||
}
|
||||
|
||||
app := newHeadlessPasswordLoginTestApp(h)
|
||||
|
||||
clientAssertion := mustHeadlessClientAssertion(
|
||||
t,
|
||||
privateKey,
|
||||
"trusted-rp",
|
||||
"http://example.com/api/v1/auth/headless/password/login",
|
||||
)
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"client_id": "trusted-rp",
|
||||
"client_assertion": clientAssertion,
|
||||
"loginId": "employee001",
|
||||
"password": "password",
|
||||
"login_challenge": "challenge-123",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("expected 200, got %d, body: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var got map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if got["redirectTo"] != "http://rp/cb" {
|
||||
t.Fatalf("expected redirectTo http://rp/cb, got %v", got["redirectTo"])
|
||||
}
|
||||
if _, ok := got["sessionJwt"]; ok {
|
||||
t.Fatalf("expected headless response to omit sessionJwt, got %v", got["sessionJwt"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadlessPasswordLogin_MissingClientAssertionRejected(t *testing.T) {
|
||||
mockIdp := new(MockIdentityProvider)
|
||||
mockIdp.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{
|
||||
SessionToken: &domain.Token{JWT: "valid-jwt"},
|
||||
Subject: "kratos-identity-id",
|
||||
}, nil)
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil)
|
||||
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
|
||||
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||
Challenge: "challenge-123",
|
||||
Client: domain.HydraClient{
|
||||
ClientID: "trusted-rp",
|
||||
TokenEndpointAuthMethod: "private_key_jwt",
|
||||
JWKS: map[string]any{
|
||||
"keys": []map[string]any{},
|
||||
},
|
||||
Metadata: map[string]interface{}{
|
||||
"status": "active",
|
||||
"headless_login_enabled": true,
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
|
||||
json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
h := &AuthHandler{
|
||||
IdpProvider: mockIdp,
|
||||
KratosAdmin: mockKratos,
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||
},
|
||||
}
|
||||
|
||||
app := newHeadlessPasswordLoginTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"client_id": "trusted-rp",
|
||||
"loginId": "employee001",
|
||||
"password": "password",
|
||||
"login_challenge": "challenge-123",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("expected 400, got %d, body: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadlessPasswordLogin_InvalidClientAssertionRejected(t *testing.T) {
|
||||
mockIdp := new(MockIdentityProvider)
|
||||
mockIdp.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{
|
||||
SessionToken: &domain.Token{JWT: "valid-jwt"},
|
||||
Subject: "kratos-identity-id",
|
||||
}, nil)
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil)
|
||||
|
||||
validKey, jwks := mustHeadlessRSAJWK(t)
|
||||
invalidKey, _ := mustHeadlessRSAJWK(t)
|
||||
_ = validKey
|
||||
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
|
||||
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||
Challenge: "challenge-123",
|
||||
Client: domain.HydraClient{
|
||||
ClientID: "trusted-rp",
|
||||
TokenEndpointAuthMethod: "private_key_jwt",
|
||||
JWKS: jwks,
|
||||
Metadata: map[string]interface{}{
|
||||
"status": "active",
|
||||
"headless_login_enabled": true,
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
|
||||
json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
h := &AuthHandler{
|
||||
IdpProvider: mockIdp,
|
||||
KratosAdmin: mockKratos,
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||
},
|
||||
}
|
||||
|
||||
app := newHeadlessPasswordLoginTestApp(h)
|
||||
|
||||
clientAssertion := mustHeadlessClientAssertion(
|
||||
t,
|
||||
invalidKey,
|
||||
"trusted-rp",
|
||||
"http://example.com/api/v1/auth/headless/password/login",
|
||||
)
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"client_id": "trusted-rp",
|
||||
"client_assertion": clientAssertion,
|
||||
"loginId": "employee001",
|
||||
"password": "password",
|
||||
"login_challenge": "challenge-123",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("expected 401, got %d, body: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadlessPasswordLogin_HeadlessDisabledRejected(t *testing.T) {
|
||||
mockIdp := new(MockIdentityProvider)
|
||||
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet {
|
||||
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||
Challenge: "challenge-123",
|
||||
Client: domain.HydraClient{
|
||||
ClientID: "trusted-rp",
|
||||
TokenEndpointAuthMethod: "private_key_jwt",
|
||||
JWKSUri: "https://rp.example.com/.well-known/jwks.json",
|
||||
Metadata: map[string]interface{}{
|
||||
"status": "active",
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
h := &AuthHandler{
|
||||
IdpProvider: mockIdp,
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||
},
|
||||
}
|
||||
|
||||
app := newHeadlessPasswordLoginTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"client_id": "trusted-rp",
|
||||
"loginId": "employee001",
|
||||
"password": "password",
|
||||
"login_challenge": "challenge-123",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("expected 403, got %d, body: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadlessPasswordLogin_ClientIDMismatchRejected(t *testing.T) {
|
||||
mockIdp := new(MockIdentityProvider)
|
||||
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet {
|
||||
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||
Challenge: "challenge-123",
|
||||
Client: domain.HydraClient{
|
||||
ClientID: "other-rp",
|
||||
TokenEndpointAuthMethod: "private_key_jwt",
|
||||
JWKSUri: "https://rp.example.com/.well-known/jwks.json",
|
||||
Metadata: map[string]interface{}{
|
||||
"status": "active",
|
||||
"headless_login_enabled": true,
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
h := &AuthHandler{
|
||||
IdpProvider: mockIdp,
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||
},
|
||||
}
|
||||
|
||||
app := newHeadlessPasswordLoginTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"client_id": "trusted-rp",
|
||||
"loginId": "employee001",
|
||||
"password": "password",
|
||||
"login_challenge": "challenge-123",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("expected 403, got %d, body: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordLogin_OIDC_InactiveClient(t *testing.T) {
|
||||
mockIdp := new(MockIdentityProvider)
|
||||
mockIdp.On("SignIn", "user@example.com", "password").Return(&domain.AuthInfo{
|
||||
|
||||
Reference in New Issue
Block a user