forked from baron/baron-sso
feat(auth): add trusted rp headless login flows
This commit is contained in:
@@ -25,6 +25,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
josejwt "github.com/go-jose/go-jose/v4/jwt"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
@@ -46,6 +48,7 @@ const (
|
||||
prefixLoginCodeSmsOnly = "login_code_sms_only:"
|
||||
prefixLoginCodeQrPending = "login_code_qr_pending:"
|
||||
prefixLoginCodeQr = "login_code_qr:"
|
||||
prefixHeadlessLinkState = "headless_link_state:"
|
||||
prefixPollMeta = "poll_meta:"
|
||||
prefixQrRef = "qr_ref:"
|
||||
prefixQrMeta = "qr_meta:"
|
||||
@@ -75,6 +78,7 @@ const (
|
||||
loginCodeExpiration = 10 * time.Minute
|
||||
linkResendCooldown = 60 * time.Second
|
||||
prefixDrySend = "dry_send:"
|
||||
headlessJWKSFetchTTL = 5 * time.Second
|
||||
)
|
||||
|
||||
type AuthHandler struct {
|
||||
@@ -100,6 +104,40 @@ type signupState struct {
|
||||
ExpiresAt int64 `json:"expires_at"` // Unix timestamp
|
||||
}
|
||||
|
||||
type headlessLinkState struct {
|
||||
ClientID string `json:"clientId"`
|
||||
LoginChallenge string `json:"loginChallenge"`
|
||||
LoginID string `json:"loginId"`
|
||||
RedirectTo string `json:"redirectTo,omitempty"`
|
||||
}
|
||||
|
||||
type headlessClientAssertionClaims struct {
|
||||
Issuer string `json:"iss"`
|
||||
Subject string `json:"sub"`
|
||||
Audience headlessAssertionAud `json:"aud"`
|
||||
ExpiresAt int64 `json:"exp"`
|
||||
IssuedAt int64 `json:"iat,omitempty"`
|
||||
NotBefore int64 `json:"nbf,omitempty"`
|
||||
ID string `json:"jti,omitempty"`
|
||||
}
|
||||
|
||||
type headlessAssertionAud []string
|
||||
|
||||
func (a *headlessAssertionAud) UnmarshalJSON(data []byte) error {
|
||||
var single string
|
||||
if err := json.Unmarshal(data, &single); err == nil {
|
||||
*a = []string{single}
|
||||
return nil
|
||||
}
|
||||
|
||||
var list []string
|
||||
if err := json.Unmarshal(data, &list); err != nil {
|
||||
return err
|
||||
}
|
||||
*a = list
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateSecureToken - Helper to generate secure random strings
|
||||
func GenerateSecureToken(length int) string {
|
||||
b := make([]byte, length)
|
||||
@@ -1224,87 +1262,9 @@ func (h *AuthHandler) PollEnchantedLink(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
if data["status"] == "approved" {
|
||||
loginID := data["loginId"]
|
||||
if loginID == "" {
|
||||
loginID = data["login_id"]
|
||||
}
|
||||
if loginID == "" {
|
||||
slog.Warn("[Poll] Approved but missing loginId", "pendingRef", req.PendingRef)
|
||||
return errorJSON(c, fiber.StatusBadRequest, "Invalid session reference")
|
||||
}
|
||||
if h.IdpProvider == nil {
|
||||
return errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
|
||||
}
|
||||
|
||||
loginStrategy := h.loadLoginStrategy(req.PendingRef)
|
||||
if loginStrategy == "" {
|
||||
loginStrategy = loginFlowLink
|
||||
}
|
||||
|
||||
var authInfo *domain.AuthInfo
|
||||
var err error
|
||||
if loginStrategy == loginFlowCode {
|
||||
code, _ := h.RedisService.Get(prefixLoginCodeValue + req.PendingRef)
|
||||
code = normalizeLoginCode(code)
|
||||
if code == "" {
|
||||
slog.Warn("[Poll] Missing login code for approved flow", "pendingRef", req.PendingRef)
|
||||
return errorJSON(c, fiber.StatusBadRequest, "Login code expired")
|
||||
}
|
||||
flowID, _ := h.RedisService.Get(prefixLoginCode + loginID)
|
||||
if flowID == "" {
|
||||
return errorJSON(c, fiber.StatusNotFound, "Login flow expired")
|
||||
}
|
||||
authInfo, err = h.IdpProvider.VerifyLoginCode(loginID, flowID, code)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
return errorJSON(c, fiber.StatusNotImplemented, "Login method not supported")
|
||||
}
|
||||
slog.Error("[Poll] IDP code verify failed", "error", err)
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "Failed to verify login code")
|
||||
}
|
||||
} else {
|
||||
authInfo, err = h.IdpProvider.IssueSession(loginID)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
return errorJSON(c, fiber.StatusNotImplemented, "Login method not supported")
|
||||
}
|
||||
slog.Error("[Poll] IDP session issue failed", "error", err)
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session")
|
||||
}
|
||||
}
|
||||
if authInfo == nil || authInfo.SessionToken == nil || authInfo.SessionToken.JWT == "" {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session")
|
||||
}
|
||||
|
||||
c.Locals("login_id", loginID)
|
||||
setSessionIDLocal(c, authInfo.SessionToken)
|
||||
sessionID := extractSessionIDFromToken(authInfo.SessionToken)
|
||||
if sessionID == "" && authInfo.SessionToken != nil && authInfo.SessionToken.JWT != "" {
|
||||
if resolved, err := h.getKratosSessionID(authInfo.SessionToken.JWT); err == nil && resolved != "" {
|
||||
sessionID = resolved
|
||||
authInfo.SessionToken.SessionID = resolved
|
||||
setSessionIDLocal(c, authInfo.SessionToken)
|
||||
}
|
||||
}
|
||||
|
||||
sessionData := map[string]string{
|
||||
"status": statusSuccess,
|
||||
"jwt": authInfo.SessionToken.JWT,
|
||||
}
|
||||
if sessionID != "" {
|
||||
sessionData["session_id"] = sessionID
|
||||
}
|
||||
sessionDataJSON, _ := json.Marshal(sessionData)
|
||||
h.RedisService.Set(prefixSession+req.PendingRef, string(sessionDataJSON), defaultExpiration)
|
||||
|
||||
h.writeLinkAuditLog(loginID, req.PendingRef, authInfo.SessionToken, c)
|
||||
h.clearLoginMeta(req.PendingRef)
|
||||
if loginStrategy == loginFlowCode {
|
||||
h.RedisService.Delete(prefixLoginCode + loginID)
|
||||
h.RedisService.Delete(prefixLoginCodePending + loginID)
|
||||
h.RedisService.Delete(prefixLoginCodeSmsTarget + loginID)
|
||||
h.RedisService.Delete(prefixLoginCodeSmsLookup + loginID)
|
||||
h.RedisService.Delete(prefixLoginCodeValue + req.PendingRef)
|
||||
_, authInfo, err := h.completeApprovedLinkLogin(c, req.PendingRef)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
@@ -1671,6 +1631,617 @@ func logOidcRedirectSummary(source, redirectTo string) {
|
||||
)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) authenticatePasswordLogin(ctx context.Context, loginID, password string) (*domain.AuthInfo, error) {
|
||||
if h.IdpProvider == nil {
|
||||
return nil, fmt.Errorf("authentication service not configured")
|
||||
}
|
||||
|
||||
authInfo, err := h.IdpProvider.SignIn(loginID, password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
subject, resolveErr := h.resolveKratosIdentityIDFromLoginID(ctx, loginID)
|
||||
if resolveErr != nil || subject == "" {
|
||||
slog.Error("Failed to resolve kratos identity after login", "loginID", loginID, "error", resolveErr)
|
||||
return nil, fmt.Errorf("failed to resolve user identity")
|
||||
}
|
||||
|
||||
authInfo.Subject = subject
|
||||
return authInfo, nil
|
||||
}
|
||||
|
||||
func passwordLoginErrorSpec(err error) (int, string, string) {
|
||||
if err == nil {
|
||||
return fiber.StatusOK, "", ""
|
||||
}
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
return fiber.StatusNotImplemented, "not_supported", "Login method not supported"
|
||||
}
|
||||
if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "identity") {
|
||||
return fiber.StatusNotFound, "not_found", "User not registered"
|
||||
}
|
||||
if strings.Contains(err.Error(), "failed to resolve user identity") {
|
||||
return fiber.StatusInternalServerError, "internal_error", "Failed to resolve user identity"
|
||||
}
|
||||
return fiber.StatusUnauthorized, "password_or_email_mismatch", "Invalid credentials"
|
||||
}
|
||||
|
||||
func headlessAssertionAudiences(c *fiber.Ctx) []string {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
path := strings.TrimSpace(c.Path())
|
||||
if path == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
base := strings.TrimRight(strings.TrimSpace(c.BaseURL()), "/")
|
||||
if base == "" {
|
||||
return []string{path}
|
||||
}
|
||||
|
||||
return []string{base + path, path}
|
||||
}
|
||||
|
||||
func containsHeadlessAudience(expected []string, actual headlessAssertionAud) bool {
|
||||
for _, audience := range actual {
|
||||
for _, candidate := range expected {
|
||||
if strings.TrimSpace(audience) == strings.TrimSpace(candidate) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *AuthHandler) loadHeadlessJWKS(ctx context.Context, client domain.HydraClient) (*jose.JSONWebKeySet, error) {
|
||||
var raw []byte
|
||||
switch {
|
||||
case client.JWKS != nil:
|
||||
data, err := json.Marshal(client.JWKS)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode jwks: %w", err)
|
||||
}
|
||||
raw = data
|
||||
case strings.TrimSpace(client.JWKSUri) != "":
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, strings.TrimSpace(client.JWKSUri), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build jwks request: %w", err)
|
||||
}
|
||||
client := &http.Client{Timeout: headlessJWKSFetchTTL}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch jwks: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return nil, fmt.Errorf("failed to fetch jwks status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read jwks response: %w", err)
|
||||
}
|
||||
raw = body
|
||||
default:
|
||||
return nil, fmt.Errorf("trusted rp public key is not configured")
|
||||
}
|
||||
|
||||
var keySet jose.JSONWebKeySet
|
||||
if err := json.Unmarshal(raw, &keySet); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode jwks: %w", err)
|
||||
}
|
||||
if len(keySet.Keys) == 0 {
|
||||
return nil, fmt.Errorf("trusted rp jwks has no keys")
|
||||
}
|
||||
return &keySet, nil
|
||||
}
|
||||
|
||||
func validateHeadlessClientAssertionClaims(c *fiber.Ctx, claims headlessClientAssertionClaims, clientID string) error {
|
||||
now := time.Now().Unix()
|
||||
if claims.Issuer != clientID || claims.Subject != clientID {
|
||||
return fmt.Errorf("client assertion iss/sub mismatch")
|
||||
}
|
||||
if claims.ExpiresAt == 0 || claims.ExpiresAt <= now {
|
||||
return fmt.Errorf("client assertion expired")
|
||||
}
|
||||
if claims.NotBefore != 0 && claims.NotBefore > now {
|
||||
return fmt.Errorf("client assertion not active yet")
|
||||
}
|
||||
if claims.IssuedAt != 0 && claims.IssuedAt > now+60 {
|
||||
return fmt.Errorf("client assertion issued in the future")
|
||||
}
|
||||
if !containsHeadlessAudience(headlessAssertionAudiences(c), claims.Audience) {
|
||||
return fmt.Errorf("client assertion audience mismatch")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.HydraClient, clientID, clientAssertion string) error {
|
||||
assertion := strings.TrimSpace(clientAssertion)
|
||||
if assertion == "" {
|
||||
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_assertion is required")
|
||||
}
|
||||
|
||||
keySet, err := h.loadHeadlessJWKS(c.Context(), client)
|
||||
if err != nil {
|
||||
slog.Error("failed to load jwks for headless client assertion", "clientID", clientID, "error", err)
|
||||
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
|
||||
}
|
||||
|
||||
token, err := josejwt.ParseSigned(assertion, []jose.SignatureAlgorithm{
|
||||
jose.RS256, jose.RS384, jose.RS512,
|
||||
jose.PS256, jose.PS384, jose.PS512,
|
||||
jose.ES256, jose.ES384, jose.ES512,
|
||||
jose.EdDSA,
|
||||
})
|
||||
if err != nil {
|
||||
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
|
||||
}
|
||||
|
||||
expectedKid := ""
|
||||
if len(token.Headers) > 0 {
|
||||
expectedKid = strings.TrimSpace(token.Headers[0].KeyID)
|
||||
}
|
||||
|
||||
for _, key := range keySet.Keys {
|
||||
if expectedKid != "" && key.KeyID != "" && key.KeyID != expectedKid {
|
||||
continue
|
||||
}
|
||||
|
||||
var claims headlessClientAssertionClaims
|
||||
if err := token.Claims(key.Key, &claims); err != nil {
|
||||
continue
|
||||
}
|
||||
if err := validateHeadlessClientAssertionClaims(c, claims, clientID); err != nil {
|
||||
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
|
||||
}
|
||||
|
||||
func (h *AuthHandler) storeHeadlessLinkState(pendingRef string, state headlessLinkState, ttl time.Duration) {
|
||||
if h.RedisService == nil || pendingRef == "" {
|
||||
return
|
||||
}
|
||||
raw, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = h.RedisService.Set(prefixHeadlessLinkState+pendingRef, string(raw), ttl)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) loadHeadlessLinkState(pendingRef string) (headlessLinkState, bool) {
|
||||
if h.RedisService == nil || pendingRef == "" {
|
||||
return headlessLinkState{}, false
|
||||
}
|
||||
raw, err := h.RedisService.Get(prefixHeadlessLinkState + pendingRef)
|
||||
if err != nil || raw == "" {
|
||||
return headlessLinkState{}, false
|
||||
}
|
||||
var state headlessLinkState
|
||||
if err := json.Unmarshal([]byte(raw), &state); err != nil {
|
||||
return headlessLinkState{}, false
|
||||
}
|
||||
return state, true
|
||||
}
|
||||
|
||||
func (h *AuthHandler) completeApprovedLinkLogin(c *fiber.Ctx, pendingRef string) (string, *domain.AuthInfo, error) {
|
||||
val, err := h.RedisService.Get(prefixSession + pendingRef)
|
||||
if err != nil || val == "" {
|
||||
return "", nil, errorJSON(c, fiber.StatusBadRequest, "Invalid session reference")
|
||||
}
|
||||
|
||||
var data map[string]string
|
||||
_ = json.Unmarshal([]byte(val), &data)
|
||||
loginID := data["loginId"]
|
||||
if loginID == "" {
|
||||
loginID = data["login_id"]
|
||||
}
|
||||
if loginID == "" {
|
||||
slog.Warn("[Poll] Approved but missing loginId", "pendingRef", pendingRef)
|
||||
return "", nil, errorJSON(c, fiber.StatusBadRequest, "Invalid session reference")
|
||||
}
|
||||
if h.IdpProvider == nil {
|
||||
return "", nil, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
|
||||
}
|
||||
|
||||
loginStrategy := h.loadLoginStrategy(pendingRef)
|
||||
if loginStrategy == "" {
|
||||
loginStrategy = loginFlowLink
|
||||
}
|
||||
|
||||
var authInfo *domain.AuthInfo
|
||||
if loginStrategy == loginFlowCode {
|
||||
code, _ := h.RedisService.Get(prefixLoginCodeValue + pendingRef)
|
||||
code = normalizeLoginCode(code)
|
||||
if code == "" {
|
||||
slog.Warn("[Poll] Missing login code for approved flow", "pendingRef", pendingRef)
|
||||
return "", nil, errorJSON(c, fiber.StatusBadRequest, "Login code expired")
|
||||
}
|
||||
flowID, _ := h.RedisService.Get(prefixLoginCode + loginID)
|
||||
if flowID == "" {
|
||||
return "", nil, errorJSON(c, fiber.StatusNotFound, "Login flow expired")
|
||||
}
|
||||
authInfo, err = h.IdpProvider.VerifyLoginCode(loginID, flowID, code)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
return "", nil, errorJSON(c, fiber.StatusNotImplemented, "Login method not supported")
|
||||
}
|
||||
slog.Error("[Poll] IDP code verify failed", "error", err)
|
||||
return "", nil, errorJSON(c, fiber.StatusInternalServerError, "Failed to verify login code")
|
||||
}
|
||||
} else {
|
||||
authInfo, err = h.IdpProvider.IssueSession(loginID)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
return "", nil, errorJSON(c, fiber.StatusNotImplemented, "Login method not supported")
|
||||
}
|
||||
slog.Error("[Poll] IDP session issue failed", "error", err)
|
||||
return "", nil, errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session")
|
||||
}
|
||||
}
|
||||
if authInfo == nil || authInfo.SessionToken == nil || authInfo.SessionToken.JWT == "" {
|
||||
return "", nil, errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session")
|
||||
}
|
||||
|
||||
c.Locals("login_id", loginID)
|
||||
setSessionIDLocal(c, authInfo.SessionToken)
|
||||
sessionID := extractSessionIDFromToken(authInfo.SessionToken)
|
||||
if sessionID == "" && authInfo.SessionToken != nil && authInfo.SessionToken.JWT != "" {
|
||||
if resolved, err := h.getKratosSessionID(authInfo.SessionToken.JWT); err == nil && resolved != "" {
|
||||
sessionID = resolved
|
||||
authInfo.SessionToken.SessionID = resolved
|
||||
setSessionIDLocal(c, authInfo.SessionToken)
|
||||
}
|
||||
}
|
||||
|
||||
sessionData := map[string]string{
|
||||
"status": statusSuccess,
|
||||
"jwt": authInfo.SessionToken.JWT,
|
||||
}
|
||||
if sessionID != "" {
|
||||
sessionData["session_id"] = sessionID
|
||||
}
|
||||
sessionDataJSON, _ := json.Marshal(sessionData)
|
||||
_ = h.RedisService.Set(prefixSession+pendingRef, string(sessionDataJSON), defaultExpiration)
|
||||
|
||||
h.writeLinkAuditLog(loginID, pendingRef, authInfo.SessionToken, c)
|
||||
h.clearLoginMeta(pendingRef)
|
||||
if loginStrategy == loginFlowCode {
|
||||
_ = h.RedisService.Delete(prefixLoginCode + loginID)
|
||||
_ = h.RedisService.Delete(prefixLoginCodePending + loginID)
|
||||
_ = h.RedisService.Delete(prefixLoginCodeSmsTarget + loginID)
|
||||
_ = h.RedisService.Delete(prefixLoginCodeSmsLookup + loginID)
|
||||
_ = h.RedisService.Delete(prefixLoginCodeValue + pendingRef)
|
||||
}
|
||||
|
||||
return loginID, authInfo, nil
|
||||
}
|
||||
|
||||
func (h *AuthHandler) validateHeadlessPasswordLoginClient(loginReq *domain.HydraLoginRequest, clientID string) error {
|
||||
if loginReq == nil {
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request")
|
||||
}
|
||||
|
||||
if strings.TrimSpace(loginReq.Client.ClientID) != strings.TrimSpace(clientID) {
|
||||
return fiber.NewError(fiber.StatusForbidden, "The client application is not allowed to use this login request.")
|
||||
}
|
||||
|
||||
if metadata := loginReq.Client.Metadata; metadata != nil {
|
||||
if status, ok := metadata["status"].(string); ok && strings.ToLower(status) == "inactive" {
|
||||
return fiber.NewError(fiber.StatusForbidden, "The client application is disabled.")
|
||||
}
|
||||
}
|
||||
|
||||
if !loginReq.Client.IsHeadlessLoginEnabled() {
|
||||
return fiber.NewError(fiber.StatusForbidden, "The client application is not allowed to use headless password login.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *AuthHandler) HeadlessPasswordLogin(c *fiber.Ctx) error {
|
||||
var req struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientAssertion string `json:"client_assertion"`
|
||||
LoginID string `json:"loginId"`
|
||||
Password string `json:"password"`
|
||||
LoginChallenge string `json:"login_challenge"`
|
||||
}
|
||||
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "Invalid request body")
|
||||
}
|
||||
|
||||
clientID := strings.TrimSpace(req.ClientID)
|
||||
loginID := strings.TrimSpace(req.LoginID)
|
||||
loginChallenge := strings.TrimSpace(req.LoginChallenge)
|
||||
if clientID == "" || loginID == "" || strings.TrimSpace(req.Password) == "" || loginChallenge == "" {
|
||||
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_id, loginId, password and login_challenge are required")
|
||||
}
|
||||
|
||||
if h.IdpProvider == nil || h.Hydra == nil {
|
||||
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
|
||||
}
|
||||
|
||||
loginReq, err := h.Hydra.GetLoginRequest(c.Context(), loginChallenge)
|
||||
if err != nil {
|
||||
slog.Error("failed to get hydra login request for headless password login", "error", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request")
|
||||
}
|
||||
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
authInfo, authErr := h.authenticatePasswordLogin(c.Context(), loginID, req.Password)
|
||||
if authErr != nil {
|
||||
status, code, message := passwordLoginErrorSpec(authErr)
|
||||
return errorJSONCode(c, status, code, message)
|
||||
}
|
||||
|
||||
setSessionIDLocal(c, authInfo.SessionToken)
|
||||
|
||||
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), loginChallenge, authInfo.Subject)
|
||||
if err != nil {
|
||||
slog.Error("failed to accept hydra login request in headless password login", "error", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request")
|
||||
}
|
||||
|
||||
logOidcRedirectSummary("headless_password_login", acceptResp.RedirectTo)
|
||||
return c.JSON(fiber.Map{
|
||||
"redirectTo": acceptResp.RedirectTo,
|
||||
"status": "ok",
|
||||
"provider": h.IdpProvider.Name(),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) startHeadlessPhoneLink(c *fiber.Ctx, loginID string) (fiber.Map, string, string, time.Duration, error) {
|
||||
rawLoginID := strings.ReplaceAll(loginID, "-", "")
|
||||
rawLoginID = strings.ReplaceAll(rawLoginID, " ", "")
|
||||
if rawLoginID == "" || strings.Contains(rawLoginID, "@") {
|
||||
return nil, "", "", 0, errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "phone-based loginId is required")
|
||||
}
|
||||
|
||||
lookupLoginID := normalizePhoneForLoginID(rawLoginID)
|
||||
if h.IdpProvider == nil {
|
||||
return nil, "", "", 0, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
|
||||
}
|
||||
exists, err := h.IdpProvider.UserExists(lookupLoginID)
|
||||
if err != nil {
|
||||
slog.Warn("[HeadlessLink] IDP user lookup failed", "loginID", rawLoginID, "error", err)
|
||||
return nil, "", "", 0, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
|
||||
}
|
||||
if !exists {
|
||||
slog.Warn("[HeadlessLink] User not found", "loginID", rawLoginID)
|
||||
return nil, "", "", 0, errorJSON(c, fiber.StatusNotFound, "User not registered")
|
||||
}
|
||||
|
||||
userfrontURL := h.resolveUserfrontURL(c)
|
||||
if init, err := h.IdpProvider.InitiateLinkLogin(lookupLoginID, userfrontURL); err == nil && init != nil && init.Mode != "" {
|
||||
keyLoginID := lookupLoginID
|
||||
if init.LoginID != "" {
|
||||
keyLoginID = init.LoginID
|
||||
}
|
||||
if init.FlowID != "" {
|
||||
_ = h.RedisService.Set(prefixLoginCode+keyLoginID, init.FlowID, loginCodeExpiration)
|
||||
}
|
||||
pendingRef := GenerateSecureToken(3)
|
||||
sessionData, _ := json.Marshal(map[string]string{
|
||||
"status": statusPending,
|
||||
"loginId": keyLoginID,
|
||||
})
|
||||
_ = h.RedisService.Set(prefixSession+pendingRef, string(sessionData), loginCodeExpiration)
|
||||
h.storeLoginMeta(pendingRef, rawLoginID, "sms", loginFlowLink, loginFlowCode, loginCodeExpiration)
|
||||
_ = h.RedisService.Set(prefixLoginCodePending+keyLoginID, pendingRef, loginCodeExpiration)
|
||||
if keyLoginID != lookupLoginID {
|
||||
_ = h.RedisService.Set(prefixLoginCodeSmsTarget+keyLoginID, lookupLoginID, loginCodeExpiration)
|
||||
_ = h.RedisService.Set(prefixLoginCodeSmsLookup+lookupLoginID, keyLoginID, loginCodeExpiration)
|
||||
}
|
||||
expiresIn := int(loginCodeExpiration.Seconds())
|
||||
if !init.ExpiresAt.IsZero() {
|
||||
if seconds := int(time.Until(init.ExpiresAt).Seconds()); seconds > 0 {
|
||||
expiresIn = seconds
|
||||
}
|
||||
}
|
||||
return fiber.Map{
|
||||
"pendingRef": pendingRef,
|
||||
"status": "pending",
|
||||
"mode": init.Mode,
|
||||
"provider": h.IdpProvider.Name(),
|
||||
"expiresIn": expiresIn,
|
||||
"interval": int(minPollInterval.Seconds()),
|
||||
"resendAfter": int(linkResendCooldown.Seconds()),
|
||||
}, pendingRef, keyLoginID, loginCodeExpiration, nil
|
||||
} else if err != nil && !errors.Is(err, domain.ErrNotSupported) {
|
||||
slog.Error("[HeadlessLink] Link login init failed", "provider", h.IdpProvider.Name(), "error", err)
|
||||
return nil, "", "", 0, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
|
||||
}
|
||||
|
||||
if h.SmsService == nil {
|
||||
return nil, "", "", 0, errorJSON(c, fiber.StatusInternalServerError, "SMS service not configured")
|
||||
}
|
||||
|
||||
token := GenerateSecureToken(3)
|
||||
pendingRef := GenerateSecureToken(3)
|
||||
sessionData, _ := json.Marshal(map[string]string{
|
||||
"status": statusPending,
|
||||
"loginId": lookupLoginID,
|
||||
})
|
||||
_ = h.RedisService.Set(prefixSession+pendingRef, string(sessionData), defaultExpiration)
|
||||
_ = h.RedisService.Set(prefixToken+token, fmt.Sprintf(`{"pendingRef":"%s","loginId":"%s"}`, pendingRef, lookupLoginID), defaultExpiration)
|
||||
h.storeLoginMeta(pendingRef, rawLoginID, "sms", loginFlowLink, loginFlowLink, defaultExpiration)
|
||||
|
||||
link := fmt.Sprintf("%s/verify/%s", strings.TrimRight(userfrontURL, "/"), token)
|
||||
content := fmt.Sprintf("[Baron 로그인] 로그인 링크: %s", link)
|
||||
if err := h.SmsService.SendSms(rawLoginID, content); err != nil {
|
||||
slog.Error("[HeadlessLink] SMS send failed", "error", err)
|
||||
return nil, "", "", 0, errorJSON(c, fiber.StatusInternalServerError, "Failed to send SMS")
|
||||
}
|
||||
|
||||
return fiber.Map{
|
||||
"pendingRef": pendingRef,
|
||||
"status": "pending",
|
||||
"provider": h.IdpProvider.Name(),
|
||||
"expiresIn": int(defaultExpiration.Seconds()),
|
||||
"interval": int(minPollInterval.Seconds()),
|
||||
"resendAfter": int(linkResendCooldown.Seconds()),
|
||||
}, pendingRef, lookupLoginID, defaultExpiration, nil
|
||||
}
|
||||
|
||||
func (h *AuthHandler) HeadlessLinkInit(c *fiber.Ctx) error {
|
||||
var req struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientAssertion string `json:"client_assertion"`
|
||||
LoginID string `json:"loginId"`
|
||||
LoginChallenge string `json:"login_challenge"`
|
||||
}
|
||||
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "Invalid request body")
|
||||
}
|
||||
|
||||
clientID := strings.TrimSpace(req.ClientID)
|
||||
loginChallenge := strings.TrimSpace(req.LoginChallenge)
|
||||
loginID := strings.TrimSpace(req.LoginID)
|
||||
if clientID == "" || loginChallenge == "" || loginID == "" {
|
||||
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_id, client_assertion, loginId and login_challenge are required")
|
||||
}
|
||||
if h.Hydra == nil {
|
||||
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
|
||||
}
|
||||
|
||||
loginReq, err := h.Hydra.GetLoginRequest(c.Context(), loginChallenge)
|
||||
if err != nil {
|
||||
slog.Error("failed to get hydra login request for headless link init", "error", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request")
|
||||
}
|
||||
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, pendingRef, resolvedLoginID, ttl, err := h.startHeadlessPhoneLink(c, loginID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.storeHeadlessLinkState(pendingRef, headlessLinkState{
|
||||
ClientID: clientID,
|
||||
LoginChallenge: loginChallenge,
|
||||
LoginID: resolvedLoginID,
|
||||
}, ttl)
|
||||
return c.JSON(resp)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) HeadlessLinkPoll(c *fiber.Ctx) error {
|
||||
var req struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientAssertion string `json:"client_assertion"`
|
||||
PendingRef string `json:"pendingRef"`
|
||||
}
|
||||
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "Invalid request body")
|
||||
}
|
||||
|
||||
clientID := strings.TrimSpace(req.ClientID)
|
||||
pendingRef := strings.TrimSpace(req.PendingRef)
|
||||
if clientID == "" || pendingRef == "" {
|
||||
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_id, client_assertion and pendingRef are required")
|
||||
}
|
||||
|
||||
state, ok := h.loadHeadlessLinkState(pendingRef)
|
||||
if !ok {
|
||||
return c.JSON(fiber.Map{
|
||||
"error": "expired_token",
|
||||
"code": "expired_token",
|
||||
})
|
||||
}
|
||||
if state.ClientID != clientID {
|
||||
return fiber.NewError(fiber.StatusForbidden, "The client application is not allowed to use this pending login.")
|
||||
}
|
||||
if state.RedirectTo != "" {
|
||||
return c.JSON(fiber.Map{
|
||||
"redirectTo": state.RedirectTo,
|
||||
"status": "ok",
|
||||
})
|
||||
}
|
||||
if h.Hydra == nil {
|
||||
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
|
||||
}
|
||||
|
||||
loginReq, err := h.Hydra.GetLoginRequest(c.Context(), state.LoginChallenge)
|
||||
if err != nil {
|
||||
slog.Error("failed to get hydra login request for headless link poll", "error", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request")
|
||||
}
|
||||
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val, err := h.RedisService.Get(prefixSession + pendingRef)
|
||||
if err != nil || val == "" {
|
||||
return c.JSON(fiber.Map{
|
||||
"error": "expired_token",
|
||||
"code": "expired_token",
|
||||
})
|
||||
}
|
||||
|
||||
var session map[string]string
|
||||
_ = json.Unmarshal([]byte(val), &session)
|
||||
if session["status"] == statusPending {
|
||||
return c.JSON(fiber.Map{
|
||||
"error": "authorization_pending",
|
||||
"code": "authorization_pending",
|
||||
"interval": int(minPollInterval.Seconds()),
|
||||
})
|
||||
}
|
||||
|
||||
loginID := state.LoginID
|
||||
if session["status"] == "approved" {
|
||||
completedLoginID, _, err := h.completeApprovedLinkLogin(c, pendingRef)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
loginID = completedLoginID
|
||||
}
|
||||
|
||||
if loginID == "" {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "Failed to resolve approved user identity")
|
||||
}
|
||||
subject, err := h.resolveKratosIdentityIDFromLoginID(c.Context(), loginID)
|
||||
if err != nil || subject == "" {
|
||||
slog.Error("failed to resolve kratos identity for headless link poll", "loginID", loginID, "error", err)
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "Failed to resolve user identity")
|
||||
}
|
||||
|
||||
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), state.LoginChallenge, subject)
|
||||
if err != nil {
|
||||
slog.Error("failed to accept hydra login request in headless link poll", "error", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request")
|
||||
}
|
||||
|
||||
state.RedirectTo = acceptResp.RedirectTo
|
||||
h.storeHeadlessLinkState(pendingRef, state, defaultExpiration)
|
||||
logOidcRedirectSummary("headless_link_poll", acceptResp.RedirectTo)
|
||||
return c.JSON(fiber.Map{
|
||||
"redirectTo": acceptResp.RedirectTo,
|
||||
"status": "ok",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
|
||||
startTime := time.Now()
|
||||
ale := logger.NewAuditLogEntry(c, "login")
|
||||
@@ -1704,28 +2275,16 @@ func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
|
||||
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
|
||||
}
|
||||
|
||||
authInfo, err := h.IdpProvider.SignIn(loginID, req.Password)
|
||||
authInfo, err := h.authenticatePasswordLogin(c.Context(), loginID, req.Password)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
return errorJSONCode(c, fiber.StatusNotImplemented, "not_supported", "Login method not supported")
|
||||
}
|
||||
ale.Status = fiber.StatusUnauthorized
|
||||
status, code, message := passwordLoginErrorSpec(err)
|
||||
ale.Status = status
|
||||
ale.LatencyMs = time.Since(startTime)
|
||||
ale.ProviderError = err.Error()
|
||||
ale.Log(slog.LevelWarn, "IDP sign-in failed", slog.String("provider", h.IdpProvider.Name()))
|
||||
if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "identity") {
|
||||
return errorJSONCode(c, fiber.StatusNotFound, "not_found", "User not registered")
|
||||
}
|
||||
return errorJSONCode(c, fiber.StatusUnauthorized, "password_or_email_mismatch", "Invalid credentials")
|
||||
return errorJSONCode(c, status, code, message)
|
||||
}
|
||||
|
||||
subject, resolveErr := h.resolveKratosIdentityIDFromLoginID(c.Context(), loginID)
|
||||
if resolveErr != nil || subject == "" {
|
||||
slog.Error("Failed to resolve kratos identity after login", "loginID", loginID, "error", resolveErr)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "Failed to resolve user identity")
|
||||
}
|
||||
authInfo.Subject = subject
|
||||
|
||||
ale.Status = fiber.StatusOK
|
||||
ale.LatencyMs = time.Since(startTime)
|
||||
setSessionIDLocal(c, authInfo.SessionToken)
|
||||
@@ -1746,7 +2305,7 @@ func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
|
||||
}
|
||||
}
|
||||
|
||||
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), req.LoginChallenge, subject)
|
||||
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), req.LoginChallenge, authInfo.Subject)
|
||||
if err != nil {
|
||||
slog.Error("failed to accept hydra login request", "error", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request")
|
||||
|
||||
Reference in New Issue
Block a user