forked from baron/baron-sso
feat(auth): add trusted rp headless login flows
This commit is contained in:
@@ -526,6 +526,9 @@ func main() {
|
|||||||
auth.Post("/login/code/verify", authHandler.VerifyLoginCode)
|
auth.Post("/login/code/verify", authHandler.VerifyLoginCode)
|
||||||
auth.Post("/login/code/verify-short", authHandler.VerifyLoginShortCode)
|
auth.Post("/login/code/verify-short", authHandler.VerifyLoginShortCode)
|
||||||
auth.Post("/password/login", authHandler.PasswordLogin)
|
auth.Post("/password/login", authHandler.PasswordLogin)
|
||||||
|
auth.Post("/headless/password/login", authHandler.HeadlessPasswordLogin)
|
||||||
|
auth.Post("/headless/link/init", authHandler.HeadlessLinkInit)
|
||||||
|
auth.Post("/headless/link/poll", authHandler.HeadlessLinkPoll)
|
||||||
auth.Get("/tenant-info", authHandler.GetTenantInfo)
|
auth.Get("/tenant-info", authHandler.GetTenantInfo)
|
||||||
auth.Get("/consent", authHandler.GetConsentRequest)
|
auth.Get("/consent", authHandler.GetConsentRequest)
|
||||||
auth.Post("/consent/accept", authHandler.AcceptConsentRequest)
|
auth.Post("/consent/accept", authHandler.AcceptConsentRequest)
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-jose/go-jose/v4"
|
||||||
|
josejwt "github.com/go-jose/go-jose/v4/jwt"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -46,6 +48,7 @@ const (
|
|||||||
prefixLoginCodeSmsOnly = "login_code_sms_only:"
|
prefixLoginCodeSmsOnly = "login_code_sms_only:"
|
||||||
prefixLoginCodeQrPending = "login_code_qr_pending:"
|
prefixLoginCodeQrPending = "login_code_qr_pending:"
|
||||||
prefixLoginCodeQr = "login_code_qr:"
|
prefixLoginCodeQr = "login_code_qr:"
|
||||||
|
prefixHeadlessLinkState = "headless_link_state:"
|
||||||
prefixPollMeta = "poll_meta:"
|
prefixPollMeta = "poll_meta:"
|
||||||
prefixQrRef = "qr_ref:"
|
prefixQrRef = "qr_ref:"
|
||||||
prefixQrMeta = "qr_meta:"
|
prefixQrMeta = "qr_meta:"
|
||||||
@@ -75,6 +78,7 @@ const (
|
|||||||
loginCodeExpiration = 10 * time.Minute
|
loginCodeExpiration = 10 * time.Minute
|
||||||
linkResendCooldown = 60 * time.Second
|
linkResendCooldown = 60 * time.Second
|
||||||
prefixDrySend = "dry_send:"
|
prefixDrySend = "dry_send:"
|
||||||
|
headlessJWKSFetchTTL = 5 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
type AuthHandler struct {
|
type AuthHandler struct {
|
||||||
@@ -100,6 +104,40 @@ type signupState struct {
|
|||||||
ExpiresAt int64 `json:"expires_at"` // Unix timestamp
|
ExpiresAt int64 `json:"expires_at"` // Unix timestamp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type headlessLinkState struct {
|
||||||
|
ClientID string `json:"clientId"`
|
||||||
|
LoginChallenge string `json:"loginChallenge"`
|
||||||
|
LoginID string `json:"loginId"`
|
||||||
|
RedirectTo string `json:"redirectTo,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type headlessClientAssertionClaims struct {
|
||||||
|
Issuer string `json:"iss"`
|
||||||
|
Subject string `json:"sub"`
|
||||||
|
Audience headlessAssertionAud `json:"aud"`
|
||||||
|
ExpiresAt int64 `json:"exp"`
|
||||||
|
IssuedAt int64 `json:"iat,omitempty"`
|
||||||
|
NotBefore int64 `json:"nbf,omitempty"`
|
||||||
|
ID string `json:"jti,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type headlessAssertionAud []string
|
||||||
|
|
||||||
|
func (a *headlessAssertionAud) UnmarshalJSON(data []byte) error {
|
||||||
|
var single string
|
||||||
|
if err := json.Unmarshal(data, &single); err == nil {
|
||||||
|
*a = []string{single}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var list []string
|
||||||
|
if err := json.Unmarshal(data, &list); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*a = list
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GenerateSecureToken - Helper to generate secure random strings
|
// GenerateSecureToken - Helper to generate secure random strings
|
||||||
func GenerateSecureToken(length int) string {
|
func GenerateSecureToken(length int) string {
|
||||||
b := make([]byte, length)
|
b := make([]byte, length)
|
||||||
@@ -1224,87 +1262,9 @@ func (h *AuthHandler) PollEnchantedLink(c *fiber.Ctx) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if data["status"] == "approved" {
|
if data["status"] == "approved" {
|
||||||
loginID := data["loginId"]
|
_, authInfo, err := h.completeApprovedLinkLogin(c, req.PendingRef)
|
||||||
if loginID == "" {
|
if err != nil {
|
||||||
loginID = data["login_id"]
|
return err
|
||||||
}
|
|
||||||
if loginID == "" {
|
|
||||||
slog.Warn("[Poll] Approved but missing loginId", "pendingRef", req.PendingRef)
|
|
||||||
return errorJSON(c, fiber.StatusBadRequest, "Invalid session reference")
|
|
||||||
}
|
|
||||||
if h.IdpProvider == nil {
|
|
||||||
return errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
|
|
||||||
}
|
|
||||||
|
|
||||||
loginStrategy := h.loadLoginStrategy(req.PendingRef)
|
|
||||||
if loginStrategy == "" {
|
|
||||||
loginStrategy = loginFlowLink
|
|
||||||
}
|
|
||||||
|
|
||||||
var authInfo *domain.AuthInfo
|
|
||||||
var err error
|
|
||||||
if loginStrategy == loginFlowCode {
|
|
||||||
code, _ := h.RedisService.Get(prefixLoginCodeValue + req.PendingRef)
|
|
||||||
code = normalizeLoginCode(code)
|
|
||||||
if code == "" {
|
|
||||||
slog.Warn("[Poll] Missing login code for approved flow", "pendingRef", req.PendingRef)
|
|
||||||
return errorJSON(c, fiber.StatusBadRequest, "Login code expired")
|
|
||||||
}
|
|
||||||
flowID, _ := h.RedisService.Get(prefixLoginCode + loginID)
|
|
||||||
if flowID == "" {
|
|
||||||
return errorJSON(c, fiber.StatusNotFound, "Login flow expired")
|
|
||||||
}
|
|
||||||
authInfo, err = h.IdpProvider.VerifyLoginCode(loginID, flowID, code)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, domain.ErrNotSupported) {
|
|
||||||
return errorJSON(c, fiber.StatusNotImplemented, "Login method not supported")
|
|
||||||
}
|
|
||||||
slog.Error("[Poll] IDP code verify failed", "error", err)
|
|
||||||
return errorJSON(c, fiber.StatusInternalServerError, "Failed to verify login code")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
authInfo, err = h.IdpProvider.IssueSession(loginID)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, domain.ErrNotSupported) {
|
|
||||||
return errorJSON(c, fiber.StatusNotImplemented, "Login method not supported")
|
|
||||||
}
|
|
||||||
slog.Error("[Poll] IDP session issue failed", "error", err)
|
|
||||||
return errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if authInfo == nil || authInfo.SessionToken == nil || authInfo.SessionToken.JWT == "" {
|
|
||||||
return errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session")
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Locals("login_id", loginID)
|
|
||||||
setSessionIDLocal(c, authInfo.SessionToken)
|
|
||||||
sessionID := extractSessionIDFromToken(authInfo.SessionToken)
|
|
||||||
if sessionID == "" && authInfo.SessionToken != nil && authInfo.SessionToken.JWT != "" {
|
|
||||||
if resolved, err := h.getKratosSessionID(authInfo.SessionToken.JWT); err == nil && resolved != "" {
|
|
||||||
sessionID = resolved
|
|
||||||
authInfo.SessionToken.SessionID = resolved
|
|
||||||
setSessionIDLocal(c, authInfo.SessionToken)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionData := map[string]string{
|
|
||||||
"status": statusSuccess,
|
|
||||||
"jwt": authInfo.SessionToken.JWT,
|
|
||||||
}
|
|
||||||
if sessionID != "" {
|
|
||||||
sessionData["session_id"] = sessionID
|
|
||||||
}
|
|
||||||
sessionDataJSON, _ := json.Marshal(sessionData)
|
|
||||||
h.RedisService.Set(prefixSession+req.PendingRef, string(sessionDataJSON), defaultExpiration)
|
|
||||||
|
|
||||||
h.writeLinkAuditLog(loginID, req.PendingRef, authInfo.SessionToken, c)
|
|
||||||
h.clearLoginMeta(req.PendingRef)
|
|
||||||
if loginStrategy == loginFlowCode {
|
|
||||||
h.RedisService.Delete(prefixLoginCode + loginID)
|
|
||||||
h.RedisService.Delete(prefixLoginCodePending + loginID)
|
|
||||||
h.RedisService.Delete(prefixLoginCodeSmsTarget + loginID)
|
|
||||||
h.RedisService.Delete(prefixLoginCodeSmsLookup + loginID)
|
|
||||||
h.RedisService.Delete(prefixLoginCodeValue + req.PendingRef)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(fiber.Map{
|
return c.JSON(fiber.Map{
|
||||||
@@ -1671,6 +1631,617 @@ func logOidcRedirectSummary(source, redirectTo string) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) authenticatePasswordLogin(ctx context.Context, loginID, password string) (*domain.AuthInfo, error) {
|
||||||
|
if h.IdpProvider == nil {
|
||||||
|
return nil, fmt.Errorf("authentication service not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
authInfo, err := h.IdpProvider.SignIn(loginID, password)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
subject, resolveErr := h.resolveKratosIdentityIDFromLoginID(ctx, loginID)
|
||||||
|
if resolveErr != nil || subject == "" {
|
||||||
|
slog.Error("Failed to resolve kratos identity after login", "loginID", loginID, "error", resolveErr)
|
||||||
|
return nil, fmt.Errorf("failed to resolve user identity")
|
||||||
|
}
|
||||||
|
|
||||||
|
authInfo.Subject = subject
|
||||||
|
return authInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func passwordLoginErrorSpec(err error) (int, string, string) {
|
||||||
|
if err == nil {
|
||||||
|
return fiber.StatusOK, "", ""
|
||||||
|
}
|
||||||
|
if errors.Is(err, domain.ErrNotSupported) {
|
||||||
|
return fiber.StatusNotImplemented, "not_supported", "Login method not supported"
|
||||||
|
}
|
||||||
|
if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "identity") {
|
||||||
|
return fiber.StatusNotFound, "not_found", "User not registered"
|
||||||
|
}
|
||||||
|
if strings.Contains(err.Error(), "failed to resolve user identity") {
|
||||||
|
return fiber.StatusInternalServerError, "internal_error", "Failed to resolve user identity"
|
||||||
|
}
|
||||||
|
return fiber.StatusUnauthorized, "password_or_email_mismatch", "Invalid credentials"
|
||||||
|
}
|
||||||
|
|
||||||
|
func headlessAssertionAudiences(c *fiber.Ctx) []string {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
path := strings.TrimSpace(c.Path())
|
||||||
|
if path == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
base := strings.TrimRight(strings.TrimSpace(c.BaseURL()), "/")
|
||||||
|
if base == "" {
|
||||||
|
return []string{path}
|
||||||
|
}
|
||||||
|
|
||||||
|
return []string{base + path, path}
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsHeadlessAudience(expected []string, actual headlessAssertionAud) bool {
|
||||||
|
for _, audience := range actual {
|
||||||
|
for _, candidate := range expected {
|
||||||
|
if strings.TrimSpace(audience) == strings.TrimSpace(candidate) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) loadHeadlessJWKS(ctx context.Context, client domain.HydraClient) (*jose.JSONWebKeySet, error) {
|
||||||
|
var raw []byte
|
||||||
|
switch {
|
||||||
|
case client.JWKS != nil:
|
||||||
|
data, err := json.Marshal(client.JWKS)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to encode jwks: %w", err)
|
||||||
|
}
|
||||||
|
raw = data
|
||||||
|
case strings.TrimSpace(client.JWKSUri) != "":
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, strings.TrimSpace(client.JWKSUri), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to build jwks request: %w", err)
|
||||||
|
}
|
||||||
|
client := &http.Client{Timeout: headlessJWKSFetchTTL}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to fetch jwks: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode >= 300 {
|
||||||
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||||
|
return nil, fmt.Errorf("failed to fetch jwks status=%d body=%s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read jwks response: %w", err)
|
||||||
|
}
|
||||||
|
raw = body
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("trusted rp public key is not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
var keySet jose.JSONWebKeySet
|
||||||
|
if err := json.Unmarshal(raw, &keySet); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode jwks: %w", err)
|
||||||
|
}
|
||||||
|
if len(keySet.Keys) == 0 {
|
||||||
|
return nil, fmt.Errorf("trusted rp jwks has no keys")
|
||||||
|
}
|
||||||
|
return &keySet, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateHeadlessClientAssertionClaims(c *fiber.Ctx, claims headlessClientAssertionClaims, clientID string) error {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
if claims.Issuer != clientID || claims.Subject != clientID {
|
||||||
|
return fmt.Errorf("client assertion iss/sub mismatch")
|
||||||
|
}
|
||||||
|
if claims.ExpiresAt == 0 || claims.ExpiresAt <= now {
|
||||||
|
return fmt.Errorf("client assertion expired")
|
||||||
|
}
|
||||||
|
if claims.NotBefore != 0 && claims.NotBefore > now {
|
||||||
|
return fmt.Errorf("client assertion not active yet")
|
||||||
|
}
|
||||||
|
if claims.IssuedAt != 0 && claims.IssuedAt > now+60 {
|
||||||
|
return fmt.Errorf("client assertion issued in the future")
|
||||||
|
}
|
||||||
|
if !containsHeadlessAudience(headlessAssertionAudiences(c), claims.Audience) {
|
||||||
|
return fmt.Errorf("client assertion audience mismatch")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.HydraClient, clientID, clientAssertion string) error {
|
||||||
|
assertion := strings.TrimSpace(clientAssertion)
|
||||||
|
if assertion == "" {
|
||||||
|
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_assertion is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
keySet, err := h.loadHeadlessJWKS(c.Context(), client)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to load jwks for headless client assertion", "clientID", clientID, "error", err)
|
||||||
|
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := josejwt.ParseSigned(assertion, []jose.SignatureAlgorithm{
|
||||||
|
jose.RS256, jose.RS384, jose.RS512,
|
||||||
|
jose.PS256, jose.PS384, jose.PS512,
|
||||||
|
jose.ES256, jose.ES384, jose.ES512,
|
||||||
|
jose.EdDSA,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedKid := ""
|
||||||
|
if len(token.Headers) > 0 {
|
||||||
|
expectedKid = strings.TrimSpace(token.Headers[0].KeyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, key := range keySet.Keys {
|
||||||
|
if expectedKid != "" && key.KeyID != "" && key.KeyID != expectedKid {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims headlessClientAssertionClaims
|
||||||
|
if err := token.Claims(key.Key, &claims); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := validateHeadlessClientAssertionClaims(c, claims, clientID); err != nil {
|
||||||
|
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) storeHeadlessLinkState(pendingRef string, state headlessLinkState, ttl time.Duration) {
|
||||||
|
if h.RedisService == nil || pendingRef == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
raw, err := json.Marshal(state)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = h.RedisService.Set(prefixHeadlessLinkState+pendingRef, string(raw), ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) loadHeadlessLinkState(pendingRef string) (headlessLinkState, bool) {
|
||||||
|
if h.RedisService == nil || pendingRef == "" {
|
||||||
|
return headlessLinkState{}, false
|
||||||
|
}
|
||||||
|
raw, err := h.RedisService.Get(prefixHeadlessLinkState + pendingRef)
|
||||||
|
if err != nil || raw == "" {
|
||||||
|
return headlessLinkState{}, false
|
||||||
|
}
|
||||||
|
var state headlessLinkState
|
||||||
|
if err := json.Unmarshal([]byte(raw), &state); err != nil {
|
||||||
|
return headlessLinkState{}, false
|
||||||
|
}
|
||||||
|
return state, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) completeApprovedLinkLogin(c *fiber.Ctx, pendingRef string) (string, *domain.AuthInfo, error) {
|
||||||
|
val, err := h.RedisService.Get(prefixSession + pendingRef)
|
||||||
|
if err != nil || val == "" {
|
||||||
|
return "", nil, errorJSON(c, fiber.StatusBadRequest, "Invalid session reference")
|
||||||
|
}
|
||||||
|
|
||||||
|
var data map[string]string
|
||||||
|
_ = json.Unmarshal([]byte(val), &data)
|
||||||
|
loginID := data["loginId"]
|
||||||
|
if loginID == "" {
|
||||||
|
loginID = data["login_id"]
|
||||||
|
}
|
||||||
|
if loginID == "" {
|
||||||
|
slog.Warn("[Poll] Approved but missing loginId", "pendingRef", pendingRef)
|
||||||
|
return "", nil, errorJSON(c, fiber.StatusBadRequest, "Invalid session reference")
|
||||||
|
}
|
||||||
|
if h.IdpProvider == nil {
|
||||||
|
return "", nil, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
|
||||||
|
}
|
||||||
|
|
||||||
|
loginStrategy := h.loadLoginStrategy(pendingRef)
|
||||||
|
if loginStrategy == "" {
|
||||||
|
loginStrategy = loginFlowLink
|
||||||
|
}
|
||||||
|
|
||||||
|
var authInfo *domain.AuthInfo
|
||||||
|
if loginStrategy == loginFlowCode {
|
||||||
|
code, _ := h.RedisService.Get(prefixLoginCodeValue + pendingRef)
|
||||||
|
code = normalizeLoginCode(code)
|
||||||
|
if code == "" {
|
||||||
|
slog.Warn("[Poll] Missing login code for approved flow", "pendingRef", pendingRef)
|
||||||
|
return "", nil, errorJSON(c, fiber.StatusBadRequest, "Login code expired")
|
||||||
|
}
|
||||||
|
flowID, _ := h.RedisService.Get(prefixLoginCode + loginID)
|
||||||
|
if flowID == "" {
|
||||||
|
return "", nil, errorJSON(c, fiber.StatusNotFound, "Login flow expired")
|
||||||
|
}
|
||||||
|
authInfo, err = h.IdpProvider.VerifyLoginCode(loginID, flowID, code)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, domain.ErrNotSupported) {
|
||||||
|
return "", nil, errorJSON(c, fiber.StatusNotImplemented, "Login method not supported")
|
||||||
|
}
|
||||||
|
slog.Error("[Poll] IDP code verify failed", "error", err)
|
||||||
|
return "", nil, errorJSON(c, fiber.StatusInternalServerError, "Failed to verify login code")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
authInfo, err = h.IdpProvider.IssueSession(loginID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, domain.ErrNotSupported) {
|
||||||
|
return "", nil, errorJSON(c, fiber.StatusNotImplemented, "Login method not supported")
|
||||||
|
}
|
||||||
|
slog.Error("[Poll] IDP session issue failed", "error", err)
|
||||||
|
return "", nil, errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if authInfo == nil || authInfo.SessionToken == nil || authInfo.SessionToken.JWT == "" {
|
||||||
|
return "", nil, errorJSON(c, fiber.StatusInternalServerError, "Failed to issue session")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Locals("login_id", loginID)
|
||||||
|
setSessionIDLocal(c, authInfo.SessionToken)
|
||||||
|
sessionID := extractSessionIDFromToken(authInfo.SessionToken)
|
||||||
|
if sessionID == "" && authInfo.SessionToken != nil && authInfo.SessionToken.JWT != "" {
|
||||||
|
if resolved, err := h.getKratosSessionID(authInfo.SessionToken.JWT); err == nil && resolved != "" {
|
||||||
|
sessionID = resolved
|
||||||
|
authInfo.SessionToken.SessionID = resolved
|
||||||
|
setSessionIDLocal(c, authInfo.SessionToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionData := map[string]string{
|
||||||
|
"status": statusSuccess,
|
||||||
|
"jwt": authInfo.SessionToken.JWT,
|
||||||
|
}
|
||||||
|
if sessionID != "" {
|
||||||
|
sessionData["session_id"] = sessionID
|
||||||
|
}
|
||||||
|
sessionDataJSON, _ := json.Marshal(sessionData)
|
||||||
|
_ = h.RedisService.Set(prefixSession+pendingRef, string(sessionDataJSON), defaultExpiration)
|
||||||
|
|
||||||
|
h.writeLinkAuditLog(loginID, pendingRef, authInfo.SessionToken, c)
|
||||||
|
h.clearLoginMeta(pendingRef)
|
||||||
|
if loginStrategy == loginFlowCode {
|
||||||
|
_ = h.RedisService.Delete(prefixLoginCode + loginID)
|
||||||
|
_ = h.RedisService.Delete(prefixLoginCodePending + loginID)
|
||||||
|
_ = h.RedisService.Delete(prefixLoginCodeSmsTarget + loginID)
|
||||||
|
_ = h.RedisService.Delete(prefixLoginCodeSmsLookup + loginID)
|
||||||
|
_ = h.RedisService.Delete(prefixLoginCodeValue + pendingRef)
|
||||||
|
}
|
||||||
|
|
||||||
|
return loginID, authInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) validateHeadlessPasswordLoginClient(loginReq *domain.HydraLoginRequest, clientID string) error {
|
||||||
|
if loginReq == nil {
|
||||||
|
return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request")
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(loginReq.Client.ClientID) != strings.TrimSpace(clientID) {
|
||||||
|
return fiber.NewError(fiber.StatusForbidden, "The client application is not allowed to use this login request.")
|
||||||
|
}
|
||||||
|
|
||||||
|
if metadata := loginReq.Client.Metadata; metadata != nil {
|
||||||
|
if status, ok := metadata["status"].(string); ok && strings.ToLower(status) == "inactive" {
|
||||||
|
return fiber.NewError(fiber.StatusForbidden, "The client application is disabled.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !loginReq.Client.IsHeadlessLoginEnabled() {
|
||||||
|
return fiber.NewError(fiber.StatusForbidden, "The client application is not allowed to use headless password login.")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) HeadlessPasswordLogin(c *fiber.Ctx) error {
|
||||||
|
var req struct {
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
ClientAssertion string `json:"client_assertion"`
|
||||||
|
LoginID string `json:"loginId"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
LoginChallenge string `json:"login_challenge"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.BodyParser(&req); err != nil {
|
||||||
|
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "Invalid request body")
|
||||||
|
}
|
||||||
|
|
||||||
|
clientID := strings.TrimSpace(req.ClientID)
|
||||||
|
loginID := strings.TrimSpace(req.LoginID)
|
||||||
|
loginChallenge := strings.TrimSpace(req.LoginChallenge)
|
||||||
|
if clientID == "" || loginID == "" || strings.TrimSpace(req.Password) == "" || loginChallenge == "" {
|
||||||
|
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_id, loginId, password and login_challenge are required")
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.IdpProvider == nil || h.Hydra == nil {
|
||||||
|
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
loginReq, err := h.Hydra.GetLoginRequest(c.Context(), loginChallenge)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to get hydra login request for headless password login", "error", err)
|
||||||
|
return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request")
|
||||||
|
}
|
||||||
|
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
authInfo, authErr := h.authenticatePasswordLogin(c.Context(), loginID, req.Password)
|
||||||
|
if authErr != nil {
|
||||||
|
status, code, message := passwordLoginErrorSpec(authErr)
|
||||||
|
return errorJSONCode(c, status, code, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
setSessionIDLocal(c, authInfo.SessionToken)
|
||||||
|
|
||||||
|
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), loginChallenge, authInfo.Subject)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to accept hydra login request in headless password login", "error", err)
|
||||||
|
return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request")
|
||||||
|
}
|
||||||
|
|
||||||
|
logOidcRedirectSummary("headless_password_login", acceptResp.RedirectTo)
|
||||||
|
return c.JSON(fiber.Map{
|
||||||
|
"redirectTo": acceptResp.RedirectTo,
|
||||||
|
"status": "ok",
|
||||||
|
"provider": h.IdpProvider.Name(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) startHeadlessPhoneLink(c *fiber.Ctx, loginID string) (fiber.Map, string, string, time.Duration, error) {
|
||||||
|
rawLoginID := strings.ReplaceAll(loginID, "-", "")
|
||||||
|
rawLoginID = strings.ReplaceAll(rawLoginID, " ", "")
|
||||||
|
if rawLoginID == "" || strings.Contains(rawLoginID, "@") {
|
||||||
|
return nil, "", "", 0, errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "phone-based loginId is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
lookupLoginID := normalizePhoneForLoginID(rawLoginID)
|
||||||
|
if h.IdpProvider == nil {
|
||||||
|
return nil, "", "", 0, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
|
||||||
|
}
|
||||||
|
exists, err := h.IdpProvider.UserExists(lookupLoginID)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("[HeadlessLink] IDP user lookup failed", "loginID", rawLoginID, "error", err)
|
||||||
|
return nil, "", "", 0, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
slog.Warn("[HeadlessLink] User not found", "loginID", rawLoginID)
|
||||||
|
return nil, "", "", 0, errorJSON(c, fiber.StatusNotFound, "User not registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
userfrontURL := h.resolveUserfrontURL(c)
|
||||||
|
if init, err := h.IdpProvider.InitiateLinkLogin(lookupLoginID, userfrontURL); err == nil && init != nil && init.Mode != "" {
|
||||||
|
keyLoginID := lookupLoginID
|
||||||
|
if init.LoginID != "" {
|
||||||
|
keyLoginID = init.LoginID
|
||||||
|
}
|
||||||
|
if init.FlowID != "" {
|
||||||
|
_ = h.RedisService.Set(prefixLoginCode+keyLoginID, init.FlowID, loginCodeExpiration)
|
||||||
|
}
|
||||||
|
pendingRef := GenerateSecureToken(3)
|
||||||
|
sessionData, _ := json.Marshal(map[string]string{
|
||||||
|
"status": statusPending,
|
||||||
|
"loginId": keyLoginID,
|
||||||
|
})
|
||||||
|
_ = h.RedisService.Set(prefixSession+pendingRef, string(sessionData), loginCodeExpiration)
|
||||||
|
h.storeLoginMeta(pendingRef, rawLoginID, "sms", loginFlowLink, loginFlowCode, loginCodeExpiration)
|
||||||
|
_ = h.RedisService.Set(prefixLoginCodePending+keyLoginID, pendingRef, loginCodeExpiration)
|
||||||
|
if keyLoginID != lookupLoginID {
|
||||||
|
_ = h.RedisService.Set(prefixLoginCodeSmsTarget+keyLoginID, lookupLoginID, loginCodeExpiration)
|
||||||
|
_ = h.RedisService.Set(prefixLoginCodeSmsLookup+lookupLoginID, keyLoginID, loginCodeExpiration)
|
||||||
|
}
|
||||||
|
expiresIn := int(loginCodeExpiration.Seconds())
|
||||||
|
if !init.ExpiresAt.IsZero() {
|
||||||
|
if seconds := int(time.Until(init.ExpiresAt).Seconds()); seconds > 0 {
|
||||||
|
expiresIn = seconds
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fiber.Map{
|
||||||
|
"pendingRef": pendingRef,
|
||||||
|
"status": "pending",
|
||||||
|
"mode": init.Mode,
|
||||||
|
"provider": h.IdpProvider.Name(),
|
||||||
|
"expiresIn": expiresIn,
|
||||||
|
"interval": int(minPollInterval.Seconds()),
|
||||||
|
"resendAfter": int(linkResendCooldown.Seconds()),
|
||||||
|
}, pendingRef, keyLoginID, loginCodeExpiration, nil
|
||||||
|
} else if err != nil && !errors.Is(err, domain.ErrNotSupported) {
|
||||||
|
slog.Error("[HeadlessLink] Link login init failed", "provider", h.IdpProvider.Name(), "error", err)
|
||||||
|
return nil, "", "", 0, errorJSON(c, fiber.StatusServiceUnavailable, "Identity provider unavailable")
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.SmsService == nil {
|
||||||
|
return nil, "", "", 0, errorJSON(c, fiber.StatusInternalServerError, "SMS service not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
token := GenerateSecureToken(3)
|
||||||
|
pendingRef := GenerateSecureToken(3)
|
||||||
|
sessionData, _ := json.Marshal(map[string]string{
|
||||||
|
"status": statusPending,
|
||||||
|
"loginId": lookupLoginID,
|
||||||
|
})
|
||||||
|
_ = h.RedisService.Set(prefixSession+pendingRef, string(sessionData), defaultExpiration)
|
||||||
|
_ = h.RedisService.Set(prefixToken+token, fmt.Sprintf(`{"pendingRef":"%s","loginId":"%s"}`, pendingRef, lookupLoginID), defaultExpiration)
|
||||||
|
h.storeLoginMeta(pendingRef, rawLoginID, "sms", loginFlowLink, loginFlowLink, defaultExpiration)
|
||||||
|
|
||||||
|
link := fmt.Sprintf("%s/verify/%s", strings.TrimRight(userfrontURL, "/"), token)
|
||||||
|
content := fmt.Sprintf("[Baron 로그인] 로그인 링크: %s", link)
|
||||||
|
if err := h.SmsService.SendSms(rawLoginID, content); err != nil {
|
||||||
|
slog.Error("[HeadlessLink] SMS send failed", "error", err)
|
||||||
|
return nil, "", "", 0, errorJSON(c, fiber.StatusInternalServerError, "Failed to send SMS")
|
||||||
|
}
|
||||||
|
|
||||||
|
return fiber.Map{
|
||||||
|
"pendingRef": pendingRef,
|
||||||
|
"status": "pending",
|
||||||
|
"provider": h.IdpProvider.Name(),
|
||||||
|
"expiresIn": int(defaultExpiration.Seconds()),
|
||||||
|
"interval": int(minPollInterval.Seconds()),
|
||||||
|
"resendAfter": int(linkResendCooldown.Seconds()),
|
||||||
|
}, pendingRef, lookupLoginID, defaultExpiration, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) HeadlessLinkInit(c *fiber.Ctx) error {
|
||||||
|
var req struct {
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
ClientAssertion string `json:"client_assertion"`
|
||||||
|
LoginID string `json:"loginId"`
|
||||||
|
LoginChallenge string `json:"login_challenge"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.BodyParser(&req); err != nil {
|
||||||
|
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "Invalid request body")
|
||||||
|
}
|
||||||
|
|
||||||
|
clientID := strings.TrimSpace(req.ClientID)
|
||||||
|
loginChallenge := strings.TrimSpace(req.LoginChallenge)
|
||||||
|
loginID := strings.TrimSpace(req.LoginID)
|
||||||
|
if clientID == "" || loginChallenge == "" || loginID == "" {
|
||||||
|
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_id, client_assertion, loginId and login_challenge are required")
|
||||||
|
}
|
||||||
|
if h.Hydra == nil {
|
||||||
|
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
loginReq, err := h.Hydra.GetLoginRequest(c.Context(), loginChallenge)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to get hydra login request for headless link init", "error", err)
|
||||||
|
return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request")
|
||||||
|
}
|
||||||
|
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, pendingRef, resolvedLoginID, ttl, err := h.startHeadlessPhoneLink(c, loginID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.storeHeadlessLinkState(pendingRef, headlessLinkState{
|
||||||
|
ClientID: clientID,
|
||||||
|
LoginChallenge: loginChallenge,
|
||||||
|
LoginID: resolvedLoginID,
|
||||||
|
}, ttl)
|
||||||
|
return c.JSON(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) HeadlessLinkPoll(c *fiber.Ctx) error {
|
||||||
|
var req struct {
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
ClientAssertion string `json:"client_assertion"`
|
||||||
|
PendingRef string `json:"pendingRef"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.BodyParser(&req); err != nil {
|
||||||
|
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "Invalid request body")
|
||||||
|
}
|
||||||
|
|
||||||
|
clientID := strings.TrimSpace(req.ClientID)
|
||||||
|
pendingRef := strings.TrimSpace(req.PendingRef)
|
||||||
|
if clientID == "" || pendingRef == "" {
|
||||||
|
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_id, client_assertion and pendingRef are required")
|
||||||
|
}
|
||||||
|
|
||||||
|
state, ok := h.loadHeadlessLinkState(pendingRef)
|
||||||
|
if !ok {
|
||||||
|
return c.JSON(fiber.Map{
|
||||||
|
"error": "expired_token",
|
||||||
|
"code": "expired_token",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if state.ClientID != clientID {
|
||||||
|
return fiber.NewError(fiber.StatusForbidden, "The client application is not allowed to use this pending login.")
|
||||||
|
}
|
||||||
|
if state.RedirectTo != "" {
|
||||||
|
return c.JSON(fiber.Map{
|
||||||
|
"redirectTo": state.RedirectTo,
|
||||||
|
"status": "ok",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if h.Hydra == nil {
|
||||||
|
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
loginReq, err := h.Hydra.GetLoginRequest(c.Context(), state.LoginChallenge)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to get hydra login request for headless link poll", "error", err)
|
||||||
|
return fiber.NewError(fiber.StatusInternalServerError, "Failed to load OIDC login request")
|
||||||
|
}
|
||||||
|
if err := h.validateHeadlessPasswordLoginClient(loginReq, clientID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := h.verifyHeadlessClientAssertion(c, loginReq.Client, clientID, req.ClientAssertion); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
val, err := h.RedisService.Get(prefixSession + pendingRef)
|
||||||
|
if err != nil || val == "" {
|
||||||
|
return c.JSON(fiber.Map{
|
||||||
|
"error": "expired_token",
|
||||||
|
"code": "expired_token",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
var session map[string]string
|
||||||
|
_ = json.Unmarshal([]byte(val), &session)
|
||||||
|
if session["status"] == statusPending {
|
||||||
|
return c.JSON(fiber.Map{
|
||||||
|
"error": "authorization_pending",
|
||||||
|
"code": "authorization_pending",
|
||||||
|
"interval": int(minPollInterval.Seconds()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
loginID := state.LoginID
|
||||||
|
if session["status"] == "approved" {
|
||||||
|
completedLoginID, _, err := h.completeApprovedLinkLogin(c, pendingRef)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
loginID = completedLoginID
|
||||||
|
}
|
||||||
|
|
||||||
|
if loginID == "" {
|
||||||
|
return errorJSON(c, fiber.StatusInternalServerError, "Failed to resolve approved user identity")
|
||||||
|
}
|
||||||
|
subject, err := h.resolveKratosIdentityIDFromLoginID(c.Context(), loginID)
|
||||||
|
if err != nil || subject == "" {
|
||||||
|
slog.Error("failed to resolve kratos identity for headless link poll", "loginID", loginID, "error", err)
|
||||||
|
return errorJSON(c, fiber.StatusInternalServerError, "Failed to resolve user identity")
|
||||||
|
}
|
||||||
|
|
||||||
|
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), state.LoginChallenge, subject)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to accept hydra login request in headless link poll", "error", err)
|
||||||
|
return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request")
|
||||||
|
}
|
||||||
|
|
||||||
|
state.RedirectTo = acceptResp.RedirectTo
|
||||||
|
h.storeHeadlessLinkState(pendingRef, state, defaultExpiration)
|
||||||
|
logOidcRedirectSummary("headless_link_poll", acceptResp.RedirectTo)
|
||||||
|
return c.JSON(fiber.Map{
|
||||||
|
"redirectTo": acceptResp.RedirectTo,
|
||||||
|
"status": "ok",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
|
func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
ale := logger.NewAuditLogEntry(c, "login")
|
ale := logger.NewAuditLogEntry(c, "login")
|
||||||
@@ -1704,28 +2275,16 @@ func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
|
|||||||
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
|
return errorJSONCode(c, fiber.StatusInternalServerError, "service_unavailable", "Authentication service not configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
authInfo, err := h.IdpProvider.SignIn(loginID, req.Password)
|
authInfo, err := h.authenticatePasswordLogin(c.Context(), loginID, req.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, domain.ErrNotSupported) {
|
status, code, message := passwordLoginErrorSpec(err)
|
||||||
return errorJSONCode(c, fiber.StatusNotImplemented, "not_supported", "Login method not supported")
|
ale.Status = status
|
||||||
}
|
|
||||||
ale.Status = fiber.StatusUnauthorized
|
|
||||||
ale.LatencyMs = time.Since(startTime)
|
ale.LatencyMs = time.Since(startTime)
|
||||||
ale.ProviderError = err.Error()
|
ale.ProviderError = err.Error()
|
||||||
ale.Log(slog.LevelWarn, "IDP sign-in failed", slog.String("provider", h.IdpProvider.Name()))
|
ale.Log(slog.LevelWarn, "IDP sign-in failed", slog.String("provider", h.IdpProvider.Name()))
|
||||||
if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "identity") {
|
return errorJSONCode(c, status, code, message)
|
||||||
return errorJSONCode(c, fiber.StatusNotFound, "not_found", "User not registered")
|
|
||||||
}
|
|
||||||
return errorJSONCode(c, fiber.StatusUnauthorized, "password_or_email_mismatch", "Invalid credentials")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
subject, resolveErr := h.resolveKratosIdentityIDFromLoginID(c.Context(), loginID)
|
|
||||||
if resolveErr != nil || subject == "" {
|
|
||||||
slog.Error("Failed to resolve kratos identity after login", "loginID", loginID, "error", resolveErr)
|
|
||||||
return fiber.NewError(fiber.StatusInternalServerError, "Failed to resolve user identity")
|
|
||||||
}
|
|
||||||
authInfo.Subject = subject
|
|
||||||
|
|
||||||
ale.Status = fiber.StatusOK
|
ale.Status = fiber.StatusOK
|
||||||
ale.LatencyMs = time.Since(startTime)
|
ale.LatencyMs = time.Since(startTime)
|
||||||
setSessionIDLocal(c, authInfo.SessionToken)
|
setSessionIDLocal(c, authInfo.SessionToken)
|
||||||
@@ -1746,7 +2305,7 @@ func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), req.LoginChallenge, subject)
|
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), req.LoginChallenge, authInfo.Subject)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("failed to accept hydra login request", "error", err)
|
slog.Error("failed to accept hydra login request", "error", err)
|
||||||
return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request")
|
return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request")
|
||||||
|
|||||||
@@ -2,14 +2,17 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"baron-sso-backend/internal/domain"
|
"baron-sso-backend/internal/domain"
|
||||||
|
"baron-sso-backend/internal/service"
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Mock services
|
// Mock services
|
||||||
@@ -21,6 +24,14 @@ type mockSmsService struct{}
|
|||||||
|
|
||||||
func (m *mockSmsService) SendSms(to, content string) error { return nil }
|
func (m *mockSmsService) SendSms(to, content string) error { return nil }
|
||||||
|
|
||||||
|
func newHeadlessLinkTestApp(h *AuthHandler) *fiber.App {
|
||||||
|
app := fiber.New()
|
||||||
|
app.Post("/api/v1/auth/headless/link/init", h.HeadlessLinkInit)
|
||||||
|
app.Post("/api/v1/auth/headless/link/poll", h.HeadlessLinkPoll)
|
||||||
|
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
|
||||||
|
return app
|
||||||
|
}
|
||||||
|
|
||||||
func TestEnchantedLinkFlow_Email_Success(t *testing.T) {
|
func TestEnchantedLinkFlow_Email_Success(t *testing.T) {
|
||||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||||
// Force "Not Supported" for InitiateLinkLogin only to trigger custom Enchanted Link logic
|
// Force "Not Supported" for InitiateLinkLogin only to trigger custom Enchanted Link logic
|
||||||
@@ -144,3 +155,162 @@ func TestPollEnchantedLink_ExpiredToken_ReturnsCode(t *testing.T) {
|
|||||||
assert.Equal(t, "expired_token", got["error"])
|
assert.Equal(t, "expired_token", got["error"])
|
||||||
assert.Equal(t, "expired_token", got["code"])
|
assert.Equal(t, "expired_token", got["code"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHeadlessLinkInit_TrustedClientSuccess(t *testing.T) {
|
||||||
|
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||||
|
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||||
|
|
||||||
|
idp := &mockIdpProvider{
|
||||||
|
userExists: true,
|
||||||
|
initiateLinkErr: domain.ErrNotSupported,
|
||||||
|
}
|
||||||
|
|
||||||
|
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet {
|
||||||
|
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||||
|
Challenge: "challenge-123",
|
||||||
|
Client: domain.HydraClient{
|
||||||
|
ClientID: "trusted-rp",
|
||||||
|
TokenEndpointAuthMethod: "private_key_jwt",
|
||||||
|
JWKS: jwks,
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"status": "active",
|
||||||
|
"headless_login_enabled": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.NotFound(w, r)
|
||||||
|
})
|
||||||
|
|
||||||
|
h := &AuthHandler{
|
||||||
|
RedisService: redis,
|
||||||
|
IdpProvider: idp,
|
||||||
|
SmsService: &mockSmsService{},
|
||||||
|
Hydra: &service.HydraAdminService{
|
||||||
|
AdminURL: "http://hydra.test",
|
||||||
|
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
app := newHeadlessLinkTestApp(h)
|
||||||
|
t.Setenv("USERFRONT_URL", "http://userfront.test")
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]string{
|
||||||
|
"client_id": "trusted-rp",
|
||||||
|
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "trusted-rp", "http://example.com/api/v1/auth/headless/link/init"),
|
||||||
|
"loginId": "010-1234-5678",
|
||||||
|
"login_challenge": "challenge-123",
|
||||||
|
})
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
resp, _ := app.Test(req, -1)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
var got map[string]interface{}
|
||||||
|
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||||
|
assert.NotEmpty(t, got["pendingRef"])
|
||||||
|
_, hasUserCode := got["userCode"]
|
||||||
|
assert.False(t, hasUserCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeadlessLinkPoll_AfterApprovalReturnsRedirect(t *testing.T) {
|
||||||
|
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||||
|
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||||
|
|
||||||
|
idp := &mockIdpProvider{
|
||||||
|
userExists: true,
|
||||||
|
initiateLinkErr: domain.ErrNotSupported,
|
||||||
|
}
|
||||||
|
|
||||||
|
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
|
||||||
|
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||||
|
Challenge: "challenge-123",
|
||||||
|
Client: domain.HydraClient{
|
||||||
|
ClientID: "trusted-rp",
|
||||||
|
TokenEndpointAuthMethod: "private_key_jwt",
|
||||||
|
JWKS: jwks,
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"status": "active",
|
||||||
|
"headless_login_enabled": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.NotFound(w, r)
|
||||||
|
})
|
||||||
|
|
||||||
|
mockKratos := new(MockKratosAdminService)
|
||||||
|
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "+821012345678").Return("kratos-identity-id", nil)
|
||||||
|
|
||||||
|
h := &AuthHandler{
|
||||||
|
RedisService: redis,
|
||||||
|
IdpProvider: idp,
|
||||||
|
SmsService: &mockSmsService{},
|
||||||
|
KratosAdmin: mockKratos,
|
||||||
|
Hydra: &service.HydraAdminService{
|
||||||
|
AdminURL: "http://hydra.test",
|
||||||
|
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
app := newHeadlessLinkTestApp(h)
|
||||||
|
t.Setenv("USERFRONT_URL", "http://userfront.test")
|
||||||
|
|
||||||
|
initBody, _ := json.Marshal(map[string]string{
|
||||||
|
"client_id": "trusted-rp",
|
||||||
|
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "trusted-rp", "http://example.com/api/v1/auth/headless/link/init"),
|
||||||
|
"loginId": "010-1234-5678",
|
||||||
|
"login_challenge": "challenge-123",
|
||||||
|
})
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(initBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
resp, _ := app.Test(req, -1)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
var initResp map[string]interface{}
|
||||||
|
_ = json.NewDecoder(resp.Body).Decode(&initResp)
|
||||||
|
pendingRef := initResp["pendingRef"].(string)
|
||||||
|
assert.NotEmpty(t, pendingRef)
|
||||||
|
|
||||||
|
var token string
|
||||||
|
for k := range redis.data {
|
||||||
|
if len(k) > 16 && k[:16] == "enchanted_token:" {
|
||||||
|
token = k[16:]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.NotEmpty(t, token)
|
||||||
|
|
||||||
|
verifyBody, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"token": token,
|
||||||
|
"verifyOnly": true,
|
||||||
|
})
|
||||||
|
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
resp, _ = app.Test(req, -1)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
pollBody, _ := json.Marshal(map[string]string{
|
||||||
|
"client_id": "trusted-rp",
|
||||||
|
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "trusted-rp", "http://example.com/api/v1/auth/headless/link/poll"),
|
||||||
|
"pendingRef": pendingRef,
|
||||||
|
})
|
||||||
|
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/poll", bytes.NewReader(pollBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
resp, _ = app.Test(req, -1)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
var pollResp map[string]interface{}
|
||||||
|
_ = json.NewDecoder(resp.Body).Decode(&pollResp)
|
||||||
|
assert.Equal(t, "http://rp/cb", pollResp["redirectTo"])
|
||||||
|
assert.Equal(t, "ok", pollResp["status"])
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ import (
|
|||||||
"baron-sso-backend/internal/service"
|
"baron-sso-backend/internal/service"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
@@ -17,7 +19,10 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-jose/go-jose/v4"
|
||||||
|
josejwt "github.com/go-jose/go-jose/v4/jwt"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
)
|
)
|
||||||
@@ -121,6 +126,79 @@ func newAuthLoginTestApp(h *AuthHandler) *fiber.App {
|
|||||||
return app
|
return app
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newHeadlessPasswordLoginTestApp(h *AuthHandler) *fiber.App {
|
||||||
|
app := fiber.New()
|
||||||
|
app.Post("/api/v1/auth/headless/password/login", h.HeadlessPasswordLogin)
|
||||||
|
return app
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustHeadlessRSAJWK(t *testing.T) (*rsa.PrivateKey, map[string]any) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to generate rsa key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
keySet := jose.JSONWebKeySet{
|
||||||
|
Keys: []jose.JSONWebKey{
|
||||||
|
{
|
||||||
|
Key: &privateKey.PublicKey,
|
||||||
|
KeyID: "test-kid",
|
||||||
|
Use: "sig",
|
||||||
|
Algorithm: string(jose.RS256),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
raw, err := json.Marshal(keySet)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to marshal jwks: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var jwks map[string]any
|
||||||
|
if err := json.Unmarshal(raw, &jwks); err != nil {
|
||||||
|
t.Fatalf("failed to decode jwks map: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return privateKey, jwks
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustHeadlessClientAssertion(t *testing.T, privateKey *rsa.PrivateKey, clientID, audience string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
signer, err := jose.NewSigner(jose.SigningKey{
|
||||||
|
Algorithm: jose.RS256,
|
||||||
|
Key: jose.JSONWebKey{
|
||||||
|
Key: privateKey,
|
||||||
|
KeyID: "test-kid",
|
||||||
|
Use: "sig",
|
||||||
|
Algorithm: string(jose.RS256),
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create signer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
raw, err := josejwt.Signed(signer).Claims(josejwt.Claims{
|
||||||
|
Issuer: clientID,
|
||||||
|
Subject: clientID,
|
||||||
|
Audience: josejwt.Audience{audience},
|
||||||
|
Expiry: josejwt.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
IssuedAt: josejwt.NewNumericDate(now),
|
||||||
|
NotBefore: josejwt.NewNumericDate(
|
||||||
|
now.Add(-1 * time.Minute),
|
||||||
|
),
|
||||||
|
ID: "assertion-1",
|
||||||
|
}).Serialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to sign client assertion: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
|
||||||
// mockHydraTransport simulates Hydra API responses
|
// mockHydraTransport simulates Hydra API responses
|
||||||
func mockHydraTransport(handler http.Handler) http.RoundTripper {
|
func mockHydraTransport(handler http.Handler) http.RoundTripper {
|
||||||
return roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
return roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
@@ -206,6 +284,342 @@ func TestPasswordLogin_OIDC_Success(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHeadlessPasswordLogin_TrustedClientSuccess(t *testing.T) {
|
||||||
|
mockIdp := new(MockIdentityProvider)
|
||||||
|
mockIdp.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{
|
||||||
|
SessionToken: &domain.Token{JWT: "valid-jwt"},
|
||||||
|
Subject: "kratos-identity-id",
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||||
|
jwksBody, _ := json.Marshal(jwks)
|
||||||
|
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write(jwksBody)
|
||||||
|
}))
|
||||||
|
defer jwksServer.Close()
|
||||||
|
|
||||||
|
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
|
||||||
|
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||||
|
Challenge: "challenge-123",
|
||||||
|
Client: domain.HydraClient{
|
||||||
|
ClientID: "trusted-rp",
|
||||||
|
TokenEndpointAuthMethod: "private_key_jwt",
|
||||||
|
JWKSUri: jwksServer.URL + "/.well-known/jwks.json",
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"status": "active",
|
||||||
|
"headless_login_enabled": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
mockKratos := new(MockKratosAdminService)
|
||||||
|
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil)
|
||||||
|
|
||||||
|
h := &AuthHandler{
|
||||||
|
IdpProvider: mockIdp,
|
||||||
|
KratosAdmin: mockKratos,
|
||||||
|
Hydra: &service.HydraAdminService{
|
||||||
|
AdminURL: "http://hydra.test",
|
||||||
|
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
app := newHeadlessPasswordLoginTestApp(h)
|
||||||
|
|
||||||
|
clientAssertion := mustHeadlessClientAssertion(
|
||||||
|
t,
|
||||||
|
privateKey,
|
||||||
|
"trusted-rp",
|
||||||
|
"http://example.com/api/v1/auth/headless/password/login",
|
||||||
|
)
|
||||||
|
body, _ := json.Marshal(map[string]string{
|
||||||
|
"client_id": "trusted-rp",
|
||||||
|
"client_assertion": clientAssertion,
|
||||||
|
"loginId": "employee001",
|
||||||
|
"password": "password",
|
||||||
|
"login_challenge": "challenge-123",
|
||||||
|
})
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("expected 200, got %d, body: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
var got map[string]interface{}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||||
|
t.Fatalf("failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
if got["redirectTo"] != "http://rp/cb" {
|
||||||
|
t.Fatalf("expected redirectTo http://rp/cb, got %v", got["redirectTo"])
|
||||||
|
}
|
||||||
|
if _, ok := got["sessionJwt"]; ok {
|
||||||
|
t.Fatalf("expected headless response to omit sessionJwt, got %v", got["sessionJwt"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeadlessPasswordLogin_MissingClientAssertionRejected(t *testing.T) {
|
||||||
|
mockIdp := new(MockIdentityProvider)
|
||||||
|
mockIdp.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{
|
||||||
|
SessionToken: &domain.Token{JWT: "valid-jwt"},
|
||||||
|
Subject: "kratos-identity-id",
|
||||||
|
}, nil)
|
||||||
|
mockKratos := new(MockKratosAdminService)
|
||||||
|
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil)
|
||||||
|
|
||||||
|
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
|
||||||
|
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||||
|
Challenge: "challenge-123",
|
||||||
|
Client: domain.HydraClient{
|
||||||
|
ClientID: "trusted-rp",
|
||||||
|
TokenEndpointAuthMethod: "private_key_jwt",
|
||||||
|
JWKS: map[string]any{
|
||||||
|
"keys": []map[string]any{},
|
||||||
|
},
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"status": "active",
|
||||||
|
"headless_login_enabled": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.NotFound(w, r)
|
||||||
|
})
|
||||||
|
|
||||||
|
h := &AuthHandler{
|
||||||
|
IdpProvider: mockIdp,
|
||||||
|
KratosAdmin: mockKratos,
|
||||||
|
Hydra: &service.HydraAdminService{
|
||||||
|
AdminURL: "http://hydra.test",
|
||||||
|
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
app := newHeadlessPasswordLoginTestApp(h)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]string{
|
||||||
|
"client_id": "trusted-rp",
|
||||||
|
"loginId": "employee001",
|
||||||
|
"password": "password",
|
||||||
|
"login_challenge": "challenge-123",
|
||||||
|
})
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("expected 400, got %d, body: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeadlessPasswordLogin_InvalidClientAssertionRejected(t *testing.T) {
|
||||||
|
mockIdp := new(MockIdentityProvider)
|
||||||
|
mockIdp.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{
|
||||||
|
SessionToken: &domain.Token{JWT: "valid-jwt"},
|
||||||
|
Subject: "kratos-identity-id",
|
||||||
|
}, nil)
|
||||||
|
mockKratos := new(MockKratosAdminService)
|
||||||
|
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil)
|
||||||
|
|
||||||
|
validKey, jwks := mustHeadlessRSAJWK(t)
|
||||||
|
invalidKey, _ := mustHeadlessRSAJWK(t)
|
||||||
|
_ = validKey
|
||||||
|
|
||||||
|
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
|
||||||
|
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||||
|
Challenge: "challenge-123",
|
||||||
|
Client: domain.HydraClient{
|
||||||
|
ClientID: "trusted-rp",
|
||||||
|
TokenEndpointAuthMethod: "private_key_jwt",
|
||||||
|
JWKS: jwks,
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"status": "active",
|
||||||
|
"headless_login_enabled": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.NotFound(w, r)
|
||||||
|
})
|
||||||
|
|
||||||
|
h := &AuthHandler{
|
||||||
|
IdpProvider: mockIdp,
|
||||||
|
KratosAdmin: mockKratos,
|
||||||
|
Hydra: &service.HydraAdminService{
|
||||||
|
AdminURL: "http://hydra.test",
|
||||||
|
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
app := newHeadlessPasswordLoginTestApp(h)
|
||||||
|
|
||||||
|
clientAssertion := mustHeadlessClientAssertion(
|
||||||
|
t,
|
||||||
|
invalidKey,
|
||||||
|
"trusted-rp",
|
||||||
|
"http://example.com/api/v1/auth/headless/password/login",
|
||||||
|
)
|
||||||
|
body, _ := json.Marshal(map[string]string{
|
||||||
|
"client_id": "trusted-rp",
|
||||||
|
"client_assertion": clientAssertion,
|
||||||
|
"loginId": "employee001",
|
||||||
|
"password": "password",
|
||||||
|
"login_challenge": "challenge-123",
|
||||||
|
})
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("expected 401, got %d, body: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeadlessPasswordLogin_HeadlessDisabledRejected(t *testing.T) {
|
||||||
|
mockIdp := new(MockIdentityProvider)
|
||||||
|
|
||||||
|
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet {
|
||||||
|
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||||
|
Challenge: "challenge-123",
|
||||||
|
Client: domain.HydraClient{
|
||||||
|
ClientID: "trusted-rp",
|
||||||
|
TokenEndpointAuthMethod: "private_key_jwt",
|
||||||
|
JWKSUri: "https://rp.example.com/.well-known/jwks.json",
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"status": "active",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.NotFound(w, r)
|
||||||
|
})
|
||||||
|
|
||||||
|
h := &AuthHandler{
|
||||||
|
IdpProvider: mockIdp,
|
||||||
|
Hydra: &service.HydraAdminService{
|
||||||
|
AdminURL: "http://hydra.test",
|
||||||
|
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
app := newHeadlessPasswordLoginTestApp(h)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]string{
|
||||||
|
"client_id": "trusted-rp",
|
||||||
|
"loginId": "employee001",
|
||||||
|
"password": "password",
|
||||||
|
"login_challenge": "challenge-123",
|
||||||
|
})
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusForbidden {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("expected 403, got %d, body: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeadlessPasswordLogin_ClientIDMismatchRejected(t *testing.T) {
|
||||||
|
mockIdp := new(MockIdentityProvider)
|
||||||
|
|
||||||
|
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet {
|
||||||
|
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||||
|
Challenge: "challenge-123",
|
||||||
|
Client: domain.HydraClient{
|
||||||
|
ClientID: "other-rp",
|
||||||
|
TokenEndpointAuthMethod: "private_key_jwt",
|
||||||
|
JWKSUri: "https://rp.example.com/.well-known/jwks.json",
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"status": "active",
|
||||||
|
"headless_login_enabled": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.NotFound(w, r)
|
||||||
|
})
|
||||||
|
|
||||||
|
h := &AuthHandler{
|
||||||
|
IdpProvider: mockIdp,
|
||||||
|
Hydra: &service.HydraAdminService{
|
||||||
|
AdminURL: "http://hydra.test",
|
||||||
|
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
app := newHeadlessPasswordLoginTestApp(h)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]string{
|
||||||
|
"client_id": "trusted-rp",
|
||||||
|
"loginId": "employee001",
|
||||||
|
"password": "password",
|
||||||
|
"login_challenge": "challenge-123",
|
||||||
|
})
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusForbidden {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("expected 403, got %d, body: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestPasswordLogin_OIDC_InactiveClient(t *testing.T) {
|
func TestPasswordLogin_OIDC_InactiveClient(t *testing.T) {
|
||||||
mockIdp := new(MockIdentityProvider)
|
mockIdp := new(MockIdentityProvider)
|
||||||
mockIdp.On("SignIn", "user@example.com", "password").Return(&domain.AuthInfo{
|
mockIdp.On("SignIn", "user@example.com", "password").Return(&domain.AuthInfo{
|
||||||
|
|||||||
Reference in New Issue
Block a user