package service

import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"

	"github.com/go-jose/go-jose/v4"
	josejwt "github.com/go-jose/go-jose/v4/jwt"
)

// backchannelLogoutEventURI is the member of the "events" claim that marks a
// JWT as an OpenID Connect Back-Channel Logout token.
const backchannelLogoutEventURI = "http://schemas.openid.net/event/backchannel-logout"

const (
	// backchannelLogoutTokenTTL is the lifetime written into the "exp" claim
	// of every logout token; kept short because logout tokens are one-shot.
	backchannelLogoutTokenTTL = 2 * time.Minute

	// backchannelLogoutMaxDrain caps how many response-body bytes are read
	// (and discarded) before closing, so a misbehaving endpoint cannot make
	// us read an unbounded body just to enable connection reuse.
	backchannelLogoutMaxDrain = 64 << 10
)

// fallbackBackchannelHTTPClient is used by services that were not created via
// NewBackchannelLogoutService and have no HTTPClient override, so that
// SendLogoutToken never dereferences a nil client.
var fallbackBackchannelHTTPClient = newBackchannelHTTPClient()

// BackchannelLogoutService builds RS256-signed OpenID Connect Back-Channel
// Logout tokens and POSTs them to relying-party logout endpoints.
type BackchannelLogoutService struct {
	issuer    string
	keyID     string
	signer    jose.Signer
	publicJWK jose.JSONWebKey
	client    *http.Client

	// HTTPClient, when non-nil, replaces the built-in client used by
	// SendLogoutToken.
	HTTPClient *http.Client
}

// NewBackchannelLogoutService generates a fresh 2048-bit RSA signing key,
// derives a key ID for it, resolves the issuer from the environment and
// returns a ready-to-use service.
//
// The key exists only in memory: each call produces a different key, so the
// JWKS published by one instance does not verify tokens signed by another.
func NewBackchannelLogoutService() (*BackchannelLogoutService, error) {
	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, fmt.Errorf("failed to generate backchannel logout key: %w", err)
	}

	keyID := randomBackchannelKeyID()
	if keyID == "" {
		// Randomness unavailable: fall back to a time-based key ID.
		keyID = fmt.Sprintf("bcl-%d", time.Now().UnixNano())
	}

	signer, err := jose.NewSigner(jose.SigningKey{
		Algorithm: jose.RS256,
		Key: jose.JSONWebKey{
			Key:       privateKey,
			KeyID:     keyID,
			Algorithm: string(jose.RS256),
			Use:       "sig",
		},
	}, (&jose.SignerOptions{}).WithType("JWT"))
	if err != nil {
		return nil, fmt.Errorf("failed to initialize backchannel logout signer: %w", err)
	}

	return &BackchannelLogoutService{
		issuer: resolveBackchannelLogoutIssuer(),
		keyID:  keyID,
		signer: signer,
		publicJWK: jose.JSONWebKey{
			Key:       &privateKey.PublicKey,
			KeyID:     keyID,
			Algorithm: string(jose.RS256),
			Use:       "sig",
		},
		client: newBackchannelHTTPClient(),
	}, nil
}

// newBackchannelHTTPClient builds the default client for delivering logout
// tokens: a 5s overall timeout, 3s dial/TLS timeouts, and no redirect
// following.
func newBackchannelHTTPClient() *http.Client {
	return &http.Client{
		Timeout: 5 * time.Second,
		Transport: &http.Transport{
			DialContext: (&net.Dialer{
				Timeout:   3 * time.Second,
				KeepAlive: 30 * time.Second,
			}).DialContext,
			TLSHandshakeTimeout: 3 * time.Second,
		},
		// On a 301/302/303 net/http would re-issue the POST as a GET without
		// the form body, dropping the logout token; if the redirect target
		// then answered 200 the logout would be misreported as successful.
		// Returning ErrUseLastResponse hands the 3xx back to the caller,
		// where it is reported as a non-2xx failure.
		CheckRedirect: func(*http.Request, []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}
}

// randomBackchannelKeyID returns 8 random bytes hex-encoded (16 characters),
// or "" if the system random source fails.
func randomBackchannelKeyID() string {
	return randomBackchannelHex(8)
}

// randomBackchannelHex returns size random bytes from crypto/rand,
// hex-encoded, or "" if the random source fails.
func randomBackchannelHex(size int) string {
	buf := make([]byte, size)
	if _, err := rand.Read(buf); err != nil {
		return ""
	}
	return hex.EncodeToString(buf)
}

// resolveBackchannelLogoutIssuer picks the token issuer from the first
// non-blank environment variable, in priority order:
//
//	BACKCHANNEL_LOGOUT_ISSUER  used as-is
//	HYDRA_PUBLIC_URL           used as-is
//	OATHKEEPER_PUBLIC_URL      with "/oidc" appended
//	USERFRONT_URL              with "/oidc" appended
//
// Trailing slashes are stripped before any suffix is added. When none is set
// it returns "http://localhost:5000/oidc".
func resolveBackchannelLogoutIssuer() string {
	sources := []struct {
		env    string
		suffix string
	}{
		{env: "BACKCHANNEL_LOGOUT_ISSUER"},
		{env: "HYDRA_PUBLIC_URL"},
		{env: "OATHKEEPER_PUBLIC_URL", suffix: "/oidc"},
		{env: "USERFRONT_URL", suffix: "/oidc"},
	}
	for _, src := range sources {
		if value := strings.TrimSpace(os.Getenv(src.env)); value != "" {
			return strings.TrimRight(value, "/") + src.suffix
		}
	}
	return "http://localhost:5000/oidc"
}

// Issuer returns the "iss" value placed in logout tokens, or "" for a nil
// service.
func (s *BackchannelLogoutService) Issuer() string {
	if s == nil {
		return ""
	}
	return s.issuer
}

// PublicJWKS returns a JWKS document ({"keys": [...]}) holding the public half
// of the signing key. A nil service, or one without a key, yields an empty
// key set rather than an unmarshalable empty JWK.
func (s *BackchannelLogoutService) PublicJWKS() map[string]any {
	if s == nil || s.publicJWK.Key == nil {
		return map[string]any{"keys": []any{}}
	}
	return map[string]any{
		"keys": []jose.JSONWebKey{s.publicJWK.Public()},
	}
}

// BuildLogoutToken returns a signed compact-serialized logout token addressed
// to clientID. Inputs are whitespace-trimmed; clientID is required, and at
// least one of subject ("sub") or sessionID ("sid") must be non-empty.
//
// The token carries iss, aud, iat, exp, jti, the back-channel logout event,
// and sub/sid when provided.
func (s *BackchannelLogoutService) BuildLogoutToken(clientID, subject, sessionID string) (string, error) {
	if s == nil || s.signer == nil {
		return "", fmt.Errorf("backchannel logout service is unavailable")
	}
	clientID = strings.TrimSpace(clientID)
	subject = strings.TrimSpace(subject)
	sessionID = strings.TrimSpace(sessionID)
	if clientID == "" {
		return "", fmt.Errorf("client id is required")
	}
	if subject == "" && sessionID == "" {
		return "", fmt.Errorf("subject or session id is required")
	}

	now := time.Now().UTC()
	claims := josejwt.Claims{
		Issuer:   s.issuer,
		Subject:  subject, // "sub" is omitted from the JSON when empty
		Audience: josejwt.Audience{clientID},
		IssuedAt: josejwt.NewNumericDate(now),
		// "exp" is REQUIRED for logout tokens by Back-Channel Logout 1.0.
		Expiry: josejwt.NewNumericDate(now.Add(backchannelLogoutTokenTTL)),
		ID:     s.newTokenID(now),
	}

	extra := map[string]any{
		"events": map[string]any{
			backchannelLogoutEventURI: map[string]any{},
		},
	}
	if sessionID != "" {
		extra["sid"] = sessionID
	}

	return josejwt.Signed(s.signer).Claims(claims).Claims(extra).Serialize()
}

// newTokenID returns a "jti" value for a new token. A random 128-bit value is
// used so that tokens built within the same clock tick never share an ID
// (relying parties that track jti could otherwise treat one as a replay); the
// previous keyID-timestamp format is kept as a fallback if randomness fails.
func (s *BackchannelLogoutService) newTokenID(now time.Time) string {
	if id := randomBackchannelHex(16); id != "" {
		return id
	}
	return fmt.Sprintf("%s-%d", s.keyID, now.UnixNano())
}

// httpClient selects the delivery client: the exported HTTPClient override,
// then the constructor-built client, then the shared fallback.
func (s *BackchannelLogoutService) httpClient() *http.Client {
	switch {
	case s.HTTPClient != nil:
		return s.HTTPClient
	case s.client != nil:
		return s.client
	default:
		return fallbackBackchannelHTTPClient
	}
}

// SendLogoutToken POSTs logoutToken as the form field "logout_token" to
// endpoint and returns the HTTP status code. Any non-2xx status (including a
// 3xx when the built-in client is used) is returned together with an error.
// Transport-level failures return status 0.
func (s *BackchannelLogoutService) SendLogoutToken(ctx context.Context, endpoint, logoutToken string) (int, error) {
	if s == nil {
		return 0, fmt.Errorf("backchannel logout service is unavailable")
	}
	endpoint = strings.TrimSpace(endpoint)
	if endpoint == "" {
		return 0, fmt.Errorf("backchannel logout endpoint is required")
	}
	if logoutToken == "" {
		return 0, fmt.Errorf("logout token is required")
	}

	form := url.Values{}
	form.Set("logout_token", logoutToken)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
	if err != nil {
		return 0, fmt.Errorf("building backchannel logout request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	resp, err := s.httpClient().Do(req)
	if err != nil {
		return 0, fmt.Errorf("sending backchannel logout request: %w", err)
	}
	defer resp.Body.Close()
	// Best-effort drain of a bounded amount of the body so the transport can
	// reuse the connection; the content is not needed and a read error here
	// does not change the outcome.
	_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, backchannelLogoutMaxDrain))

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return resp.StatusCode, fmt.Errorf("backchannel logout endpoint returned status %d", resp.StatusCode)
	}
	return resp.StatusCode, nil
}

// MarshalPublicJWKS returns PublicJWKS encoded as JSON.
func (s *BackchannelLogoutService) MarshalPublicJWKS() ([]byte, error) {
	return json.Marshal(s.PublicJWKS())
}