forked from baron/baron-sso
OIDC back-channel logout 백엔드 전송 기능 추가
This commit is contained in:
192
backend/internal/service/backchannel_logout_service.go
Normal file
192
backend/internal/service/backchannel_logout_service.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
josejwt "github.com/go-jose/go-jose/v4/jwt"
|
||||
)
|
||||
|
||||
const backchannelLogoutEventURI = "http://schemas.openid.net/event/backchannel-logout"
|
||||
|
||||
type BackchannelLogoutService struct {
|
||||
issuer string
|
||||
keyID string
|
||||
signer jose.Signer
|
||||
publicJWK jose.JSONWebKey
|
||||
client *http.Client
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
func NewBackchannelLogoutService() (*BackchannelLogoutService, error) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate backchannel logout key: %w", err)
|
||||
}
|
||||
|
||||
keyID := randomBackchannelKeyID()
|
||||
if keyID == "" {
|
||||
keyID = fmt.Sprintf("bcl-%d", time.Now().UnixNano())
|
||||
}
|
||||
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.RS256,
|
||||
Key: jose.JSONWebKey{
|
||||
Key: privateKey,
|
||||
KeyID: keyID,
|
||||
Algorithm: string(jose.RS256),
|
||||
Use: "sig",
|
||||
},
|
||||
}, (&jose.SignerOptions{}).WithType("JWT"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize backchannel logout signer: %w", err)
|
||||
}
|
||||
|
||||
return &BackchannelLogoutService{
|
||||
issuer: resolveBackchannelLogoutIssuer(),
|
||||
keyID: keyID,
|
||||
signer: signer,
|
||||
publicJWK: jose.JSONWebKey{
|
||||
Key: &privateKey.PublicKey,
|
||||
KeyID: keyID,
|
||||
Algorithm: string(jose.RS256),
|
||||
Use: "sig",
|
||||
},
|
||||
client: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 3 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: 3 * time.Second,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func randomBackchannelKeyID() string {
|
||||
buf := make([]byte, 8)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return ""
|
||||
}
|
||||
return hex.EncodeToString(buf)
|
||||
}
|
||||
|
||||
func resolveBackchannelLogoutIssuer() string {
|
||||
if explicit := strings.TrimSpace(os.Getenv("BACKCHANNEL_LOGOUT_ISSUER")); explicit != "" {
|
||||
return strings.TrimRight(explicit, "/")
|
||||
}
|
||||
|
||||
if hydraPublic := strings.TrimSpace(os.Getenv("HYDRA_PUBLIC_URL")); hydraPublic != "" {
|
||||
return strings.TrimRight(hydraPublic, "/")
|
||||
}
|
||||
|
||||
if oathkeeperPublic := strings.TrimSpace(os.Getenv("OATHKEEPER_PUBLIC_URL")); oathkeeperPublic != "" {
|
||||
return strings.TrimRight(oathkeeperPublic, "/") + "/oidc"
|
||||
}
|
||||
|
||||
if userfrontURL := strings.TrimSpace(os.Getenv("USERFRONT_URL")); userfrontURL != "" {
|
||||
return strings.TrimRight(userfrontURL, "/") + "/oidc"
|
||||
}
|
||||
|
||||
return "http://localhost:5000/oidc"
|
||||
}
|
||||
|
||||
func (s *BackchannelLogoutService) Issuer() string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return s.issuer
|
||||
}
|
||||
|
||||
func (s *BackchannelLogoutService) PublicJWKS() map[string]any {
|
||||
if s == nil {
|
||||
return map[string]any{"keys": []any{}}
|
||||
}
|
||||
return map[string]any{
|
||||
"keys": []jose.JSONWebKey{s.publicJWK.Public()},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackchannelLogoutService) BuildLogoutToken(clientID, subject, sessionID string) (string, error) {
|
||||
if s == nil || s.signer == nil {
|
||||
return "", fmt.Errorf("backchannel logout service is unavailable")
|
||||
}
|
||||
clientID = strings.TrimSpace(clientID)
|
||||
subject = strings.TrimSpace(subject)
|
||||
sessionID = strings.TrimSpace(sessionID)
|
||||
if clientID == "" {
|
||||
return "", fmt.Errorf("client id is required")
|
||||
}
|
||||
if subject == "" && sessionID == "" {
|
||||
return "", fmt.Errorf("subject or session id is required")
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
claims := josejwt.Claims{
|
||||
Issuer: s.issuer,
|
||||
Audience: josejwt.Audience{clientID},
|
||||
IssuedAt: josejwt.NewNumericDate(now),
|
||||
ID: fmt.Sprintf("%s-%d", s.keyID, now.UnixNano()),
|
||||
}
|
||||
if subject != "" {
|
||||
claims.Subject = subject
|
||||
}
|
||||
|
||||
extra := map[string]any{
|
||||
"events": map[string]any{
|
||||
backchannelLogoutEventURI: map[string]any{},
|
||||
},
|
||||
}
|
||||
if sessionID != "" {
|
||||
extra["sid"] = sessionID
|
||||
}
|
||||
|
||||
return josejwt.Signed(s.signer).Claims(claims).Claims(extra).Serialize()
|
||||
}
|
||||
|
||||
func (s *BackchannelLogoutService) SendLogoutToken(ctx context.Context, endpoint, logoutToken string) (int, error) {
|
||||
if s == nil {
|
||||
return 0, fmt.Errorf("backchannel logout service is unavailable")
|
||||
}
|
||||
form := url.Values{}
|
||||
form.Set("logout_token", logoutToken)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := s.client
|
||||
if s.HTTPClient != nil {
|
||||
client = s.HTTPClient
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return resp.StatusCode, fmt.Errorf("backchannel logout endpoint returned status %d", resp.StatusCode)
|
||||
}
|
||||
return resp.StatusCode, nil
|
||||
}
|
||||
|
||||
func (s *BackchannelLogoutService) MarshalPublicJWKS() ([]byte, error) {
|
||||
return json.Marshal(s.PublicJWKS())
|
||||
}
|
||||
Reference in New Issue
Block a user