forked from baron/baron-sso
ory-hosting 기본구동
This commit is contained in:
@@ -3,12 +3,26 @@ package idp
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Ory 계열(kratos/hydra)와 Descope 등 공급자 문자열을 정규화하기 위한 매핑.
|
||||
var providerAliases = map[string]string{
|
||||
"ory": "ory",
|
||||
"hydra": "ory",
|
||||
"kratos": "ory",
|
||||
"ory-kratos": "ory",
|
||||
"ory_hydra": "ory",
|
||||
"ory_kratos": "ory",
|
||||
"descope": "descope",
|
||||
"descope_sso": "descope",
|
||||
}
|
||||
|
||||
// getEnv는 환경 변수를 읽거나 대체 값을 반환하는 헬퍼 함수입니다.
|
||||
func getEnv(key, fallback string) string {
|
||||
if value, ok := os.LookupEnv(key); ok {
|
||||
@@ -22,42 +36,182 @@ func getEnv(key, fallback string) string {
|
||||
func InitializeProvider() (domain.IdentityProvider, error) {
|
||||
rawProviders := getEnv("IDP_PROVIDER", "descope") // 기본값은 descope입니다.
|
||||
providers := strings.Split(rawProviders, ",")
|
||||
slog.Info("Initializing IDP", "providers", rawProviders)
|
||||
slog.Info("Initializing IDP chain", "providers", rawProviders)
|
||||
|
||||
var initialized []domain.IdentityProvider
|
||||
for _, p := range providers {
|
||||
providerName := strings.TrimSpace(strings.ToLower(p))
|
||||
if canonical, ok := providerAliases[providerName]; ok {
|
||||
providerName = canonical
|
||||
}
|
||||
|
||||
switch providerName {
|
||||
case "ory":
|
||||
// Kratos/Hydra 주 공급자
|
||||
oryProvider := service.NewOryProvider()
|
||||
initialized = append(initialized, oryProvider)
|
||||
|
||||
case "descope":
|
||||
descopeProjectID := getEnv("DESCOPE_PROJECT_ID", "")
|
||||
descopeManagementKey := getEnv("DESCOPE_MANAGEMENT_KEY", "")
|
||||
// 선택된 공급자에 대한 키가 설정되었는지 확인하기 위한 기본 유효성 검사
|
||||
if descopeProjectID == "" || descopeManagementKey == "" {
|
||||
return nil, fmt.Errorf("DESCOPE_PROJECT_ID and DESCOPE_MANAGEMENT_KEY must be set for the 'descope' provider")
|
||||
slog.Warn("Skipping Descope provider due to missing credentials")
|
||||
continue
|
||||
}
|
||||
return service.NewDescopeProvider(descopeProjectID, descopeManagementKey), nil
|
||||
initialized = append(initialized, service.NewDescopeProvider(descopeProjectID, descopeManagementKey))
|
||||
|
||||
// --- 향후 공급자 구현 ---
|
||||
// case "ory":
|
||||
// // oryURL := getEnv("ORY_URL", "")
|
||||
// // if oryURL == "" {
|
||||
// // return nil, fmt.Errorf("ORY_URL must be set for the 'ory' provider")
|
||||
// // }
|
||||
// // return service.NewOryProvider(oryURL), nil
|
||||
// // return nil, fmt.Errorf(\"'ory' provider is not yet implemented\")
|
||||
|
||||
// case "keycloak":
|
||||
// // keycloakURL := getEnv("KEYCLOAK_URL", "")
|
||||
// // keycloakRealm := getEnv("KEYCLOAK_REALM", "")
|
||||
// // if keycloakURL == "" || keycloakRealm == "" {
|
||||
// // return nil, fmt.Errorf("KEYCLOAK_URL and KEYCLOAK_REALM must be set for the 'keycloak' provider")
|
||||
// // }
|
||||
// // return service.NewKeycloakProvider(keycloakURL, keycloakRealm), nil
|
||||
// // return nil, fmt.Errorf(\"'keycloak' provider is not yet implemented\")
|
||||
default:
|
||||
// 알 수 없는 공급자는 건너뛰고 다음 후보를 시도
|
||||
slog.Warn("Skipping unsupported IDP provider entry", "provider", providerName)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported or unknown IDP_PROVIDER specified: %s", rawProviders)
|
||||
if len(initialized) == 0 {
|
||||
return nil, fmt.Errorf("no valid IDP_PROVIDER entries configured from: %s", rawProviders)
|
||||
}
|
||||
|
||||
if len(initialized) == 1 {
|
||||
slog.Info("Initialized IDP provider", "provider", initialized[0].Name())
|
||||
return initialized[0], nil
|
||||
}
|
||||
|
||||
chain := newChainedProvider(initialized)
|
||||
slog.Info("Initialized IDP provider chain", "providers", chain.Name())
|
||||
return chain, nil
|
||||
}
|
||||
|
||||
// newChainedProvider는 우선순위 순으로 IDP를 시도하는 체인을 생성합니다.
|
||||
func newChainedProvider(providers []domain.IdentityProvider) domain.IdentityProvider {
|
||||
names := make([]string, len(providers))
|
||||
for i, p := range providers {
|
||||
names[i] = p.Name()
|
||||
}
|
||||
return &chainedProvider{
|
||||
providers: providers,
|
||||
names: names,
|
||||
}
|
||||
}
|
||||
|
||||
// chainedProvider는 다중 IDP를 우선순위대로 호출하며 실패 시 폴백합니다.
|
||||
type chainedProvider struct {
|
||||
providers []domain.IdentityProvider
|
||||
names []string
|
||||
}
|
||||
|
||||
func (c *chainedProvider) Name() string {
|
||||
return strings.Join(c.names, " > ")
|
||||
}
|
||||
|
||||
func (c *chainedProvider) GetMetadata() (*domain.IDPMetadata, error) {
|
||||
supported := make([]string, 0)
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, p := range c.providers {
|
||||
meta, err := p.GetMetadata()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch metadata from %s: %w", p.Name(), err)
|
||||
}
|
||||
for _, field := range meta.SupportedFields {
|
||||
if !seen[field] {
|
||||
seen[field] = true
|
||||
supported = append(supported, field)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &domain.IDPMetadata{SupportedFields: supported}, nil
|
||||
}
|
||||
|
||||
func (c *chainedProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
|
||||
var errs []error
|
||||
for idx, p := range c.providers {
|
||||
id, err := p.CreateUser(user, password)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "CreateUser", "error", err)
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
slog.Info("IDP fallback succeeded", "operation", "CreateUser", "provider", p.Name())
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return "", fmt.Errorf("no IDP providers available for CreateUser")
|
||||
}
|
||||
return "", fmt.Errorf("all IDP providers failed for CreateUser: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *chainedProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
|
||||
var errs []error
|
||||
for idx, p := range c.providers {
|
||||
info, err := p.SignIn(loginID, password)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "SignIn", "error", err)
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
slog.Info("IDP fallback succeeded", "operation", "SignIn", "provider", p.Name())
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return nil, fmt.Errorf("no IDP providers available for SignIn")
|
||||
}
|
||||
return nil, fmt.Errorf("all IDP providers failed for SignIn: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *chainedProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
|
||||
return c.tryProviders("InitiatePasswordReset", func(p domain.IdentityProvider) error {
|
||||
return p.InitiatePasswordReset(loginID, redirectUrl)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *chainedProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
|
||||
var errs []error
|
||||
for idx, p := range c.providers {
|
||||
info, err := p.VerifyPasswordResetToken(token)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
slog.Warn("IDP VerifyPasswordResetToken failed", "provider", p.Name(), "error", err)
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
slog.Info("IDP fallback succeeded", "operation", "VerifyPasswordResetToken", "provider", p.Name())
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
|
||||
if len(errs) == 0 {
|
||||
return nil, fmt.Errorf("no IDP providers available for VerifyPasswordResetToken")
|
||||
}
|
||||
return nil, fmt.Errorf("all IDP providers failed for VerifyPasswordResetToken: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *chainedProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
|
||||
return c.tryProviders("UpdateUserPassword", func(p domain.IdentityProvider) error {
|
||||
return p.UpdateUserPassword(loginID, newPassword, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *chainedProvider) tryProviders(operation string, fn func(domain.IdentityProvider) error) error {
|
||||
var errs []error
|
||||
for idx, p := range c.providers {
|
||||
if err := fn(p); err != nil {
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", operation, "error", err)
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
slog.Info("IDP fallback succeeded", "operation", operation, "provider", p.Name())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(errs) == 0 {
|
||||
return fmt.Errorf("no IDP providers available for %s", operation)
|
||||
}
|
||||
return fmt.Errorf("all IDP providers failed for %s: %w", operation, errors.Join(errs...))
|
||||
}
|
||||
|
||||
115
backend/internal/idp/factory_test.go
Normal file
115
backend/internal/idp/factory_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"errors"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type stubProvider struct {
|
||||
name string
|
||||
metadata []string
|
||||
createErr error
|
||||
initiateErr error
|
||||
verifyErr error
|
||||
updateErr error
|
||||
signInErr error
|
||||
initiateCalls int
|
||||
verifyCalls int
|
||||
updateCalls int
|
||||
signInCalls int
|
||||
createCalls int
|
||||
verifyResponse *domain.AuthInfo
|
||||
}
|
||||
|
||||
func (s *stubProvider) Name() string { return s.name }
|
||||
|
||||
func (s *stubProvider) GetMetadata() (*domain.IDPMetadata, error) {
|
||||
return &domain.IDPMetadata{SupportedFields: s.metadata}, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
|
||||
s.createCalls++
|
||||
if s.createErr != nil {
|
||||
return "", s.createErr
|
||||
}
|
||||
return "created-id", nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
|
||||
s.signInCalls++
|
||||
if s.signInErr != nil {
|
||||
return nil, s.signInErr
|
||||
}
|
||||
return &domain.AuthInfo{Subject: "subject-123"}, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
|
||||
s.initiateCalls++
|
||||
return s.initiateErr
|
||||
}
|
||||
|
||||
func (s *stubProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
|
||||
s.verifyCalls++
|
||||
if s.verifyErr != nil {
|
||||
return nil, s.verifyErr
|
||||
}
|
||||
if s.verifyResponse != nil {
|
||||
return s.verifyResponse, nil
|
||||
}
|
||||
return &domain.AuthInfo{}, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
|
||||
s.updateCalls++
|
||||
return s.updateErr
|
||||
}
|
||||
|
||||
func TestChainedProviderMetadataUnion(t *testing.T) {
|
||||
p1 := &stubProvider{name: "primary", metadata: []string{"id", "email"}}
|
||||
p2 := &stubProvider{name: "backup", metadata: []string{"email", "phone_number", "grade"}}
|
||||
|
||||
chain := newChainedProvider([]domain.IdentityProvider{p1, p2})
|
||||
meta, err := chain.GetMetadata()
|
||||
if err != nil {
|
||||
t.Fatalf("GetMetadata returned error: %v", err)
|
||||
}
|
||||
|
||||
expected := []string{"id", "email", "phone_number", "grade"}
|
||||
if !reflect.DeepEqual(meta.SupportedFields, expected) {
|
||||
t.Fatalf("metadata mismatch: got %v, want %v", meta.SupportedFields, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainedProviderUpdateUserPasswordFallback(t *testing.T) {
|
||||
p1 := &stubProvider{name: "primary", metadata: []string{"id"}, updateErr: errors.New("boom")}
|
||||
p2 := &stubProvider{name: "backup", metadata: []string{"id"}}
|
||||
|
||||
chain := newChainedProvider([]domain.IdentityProvider{p1, p2})
|
||||
if err := chain.UpdateUserPassword("user@example.com", "Sup3r!Pass123", nil); err != nil {
|
||||
t.Fatalf("expected fallback to succeed, got error: %v", err)
|
||||
}
|
||||
if p1.updateCalls != 1 || p2.updateCalls != 1 {
|
||||
t.Fatalf("unexpected call counts: p1=%d p2=%d", p1.updateCalls, p2.updateCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainedProviderUpdateUserPasswordAllFail(t *testing.T) {
|
||||
p1 := &stubProvider{name: "primary", metadata: []string{"id"}, updateErr: errors.New("fail1")}
|
||||
p2 := &stubProvider{name: "backup", metadata: []string{"id"}, updateErr: errors.New("fail2")}
|
||||
|
||||
chain := newChainedProvider([]domain.IdentityProvider{p1, p2})
|
||||
err := chain.UpdateUserPassword("user@example.com", "Sup3r!Pass123", nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error when all providers fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "all IDP providers failed") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if p1.updateCalls != 1 || p2.updateCalls != 1 {
|
||||
t.Fatalf("unexpected call counts: p1=%d p2=%d", p1.updateCalls, p2.updateCalls)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user