package idp import ( "baron-sso-backend/internal/domain" "baron-sso-backend/internal/service" "errors" "fmt" "log/slog" "net/http" "os" "strings" ) // Ory 계열(kratos/hydra) 공급자 문자열을 정규화하기 위한 매핑. var providerAliases = map[string]string{ "ory": "ory", "hydra": "ory", "kratos": "ory", "ory-kratos": "ory", "ory_hydra": "ory", "ory_kratos": "ory", } // getEnv는 환경 변수를 읽거나 대체 값을 반환하는 헬퍼 함수입니다. func getEnv(key, fallback string) string { if value, ok := os.LookupEnv(key); ok { return value } return fallback } // InitializeProvider는 환경 설정을 기반으로 IDP 공급자를 생성하고 반환합니다. // 이것은 IdentityProvider 인터페이스의 팩토리 역할을 합니다. func InitializeProvider() (domain.IdentityProvider, error) { rawProviders := getEnv("IDP_PROVIDER", "ory") providers := strings.Split(rawProviders, ",") slog.Info("Initializing IDP chain", "providers", rawProviders) var initialized []domain.IdentityProvider for _, p := range providers { providerName := strings.TrimSpace(strings.ToLower(p)) if canonical, ok := providerAliases[providerName]; ok { providerName = canonical } switch providerName { case "ory": // Kratos/Hydra 주 공급자 oryProvider := service.NewOryProvider() initialized = append(initialized, oryProvider) default: // 알 수 없는 공급자는 건너뛰고 다음 후보를 시도 slog.Warn("Skipping unsupported IDP provider entry", "provider", providerName) } } if len(initialized) == 0 { return nil, fmt.Errorf("no valid IDP_PROVIDER entries configured from: %s", rawProviders) } if len(initialized) == 1 { slog.Info("Initialized IDP provider", "provider", initialized[0].Name()) return initialized[0], nil } chain := newChainedProvider(initialized) slog.Info("Initialized IDP provider chain", "providers", chain.Name()) return chain, nil } // newChainedProvider는 우선순위 순으로 IDP를 시도하는 체인을 생성합니다. func newChainedProvider(providers []domain.IdentityProvider) domain.IdentityProvider { names := make([]string, len(providers)) for i, p := range providers { names[i] = p.Name() } return &chainedProvider{ providers: providers, names: names, } } // chainedProvider는 다중 IDP를 우선순위대로 호출하며 실패 시 폴백합니다. type chainedProvider struct { providers []domain.IdentityProvider names []string } func (c *chainedProvider) Name() string { return strings.Join(c.names, " > ") } func (c *chainedProvider) GetMetadata() (*domain.IDPMetadata, error) { supported := make([]string, 0) seen := make(map[string]bool) for _, p := range c.providers { meta, err := p.GetMetadata() if err != nil { return nil, fmt.Errorf("failed to fetch metadata from %s: %w", p.Name(), err) } for _, field := range meta.SupportedFields { if !seen[field] { seen[field] = true supported = append(supported, field) } } } return &domain.IDPMetadata{SupportedFields: supported}, nil } func (c *chainedProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) { for _, p := range c.providers { id, err := p.CreateUser(user, password) if err != nil { if errors.Is(err, domain.ErrNotSupported) { continue } slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "CreateUser", "error", err) return "", err } return id, nil } return "", domain.ErrNotSupported } func (c *chainedProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) { for _, p := range c.providers { info, err := p.SignIn(loginID, password) if err != nil { if errors.Is(err, domain.ErrNotSupported) { continue } slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "SignIn", "error", err) return nil, err } return info, nil } return nil, domain.ErrNotSupported } func (c *chainedProvider) UserExists(loginID string) (bool, error) { var errs []error for _, p := range c.providers { exists, err := p.UserExists(loginID) if err != nil { if errors.Is(err, domain.ErrNotSupported) { continue } errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err)) continue } if exists { return true, nil } } if len(errs) == 0 { return false, nil } return false, fmt.Errorf("all IDP providers failed for UserExists: %w", errors.Join(errs...)) } func (c *chainedProvider) IssueSession(loginID string) (*domain.AuthInfo, error) { var errs []error for idx, p := range c.providers { info, err := p.IssueSession(loginID) if err != nil { if errors.Is(err, domain.ErrNotSupported) { continue } errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err)) slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "IssueSession", "error", err) continue } if idx > 0 { slog.Info("IDP fallback succeeded", "operation", "IssueSession", "provider", p.Name()) } return info, nil } if len(errs) == 0 { return nil, domain.ErrNotSupported } return nil, fmt.Errorf("all IDP providers failed for IssueSession: %w", errors.Join(errs...)) } func (c *chainedProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) { var errs []error for idx, p := range c.providers { info, err := p.InitiateLinkLogin(loginID, returnTo) if err != nil { if errors.Is(err, domain.ErrNotSupported) { continue } errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err)) slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "InitiateLinkLogin", "error", err) continue } if idx > 0 { slog.Info("IDP fallback succeeded", "operation", "InitiateLinkLogin", "provider", p.Name()) } return info, nil } if len(errs) == 0 { return nil, domain.ErrNotSupported } return nil, fmt.Errorf("all IDP providers failed for InitiateLinkLogin: %w", errors.Join(errs...)) } func (c *chainedProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) { var errs []error for idx, p := range c.providers { info, err := p.VerifyLoginCode(loginID, flowID, code) if err != nil { if errors.Is(err, domain.ErrNotSupported) { continue } errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err)) slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "VerifyLoginCode", "error", err) continue } if idx > 0 { slog.Info("IDP fallback succeeded", "operation", "VerifyLoginCode", "provider", p.Name()) } return info, nil } if len(errs) == 0 { return nil, domain.ErrNotSupported } return nil, fmt.Errorf("all IDP providers failed for VerifyLoginCode: %w", errors.Join(errs...)) } func (c *chainedProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) { var errs []error for _, p := range c.providers { policy, err := p.GetPasswordPolicy() if err != nil { if errors.Is(err, domain.ErrNotSupported) { continue } errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err)) continue } if policy != nil { return policy, nil } } if len(errs) == 0 { return nil, domain.ErrNotSupported } return nil, fmt.Errorf("all IDP providers failed for GetPasswordPolicy: %w", errors.Join(errs...)) } func (c *chainedProvider) InitiatePasswordReset(loginID, redirectUrl string) error { return c.tryProviders("InitiatePasswordReset", func(p domain.IdentityProvider) error { return p.InitiatePasswordReset(loginID, redirectUrl) }) } func (c *chainedProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) { var errs []error for idx, p := range c.providers { info, err := p.VerifyPasswordResetToken(token) if err != nil { errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err)) slog.Warn("IDP VerifyPasswordResetToken failed", "provider", p.Name(), "error", err) continue } if idx > 0 { slog.Info("IDP fallback succeeded", "operation", "VerifyPasswordResetToken", "provider", p.Name()) } return info, nil } if len(errs) == 0 { return nil, fmt.Errorf("no IDP providers available for VerifyPasswordResetToken") } return nil, fmt.Errorf("all IDP providers failed for VerifyPasswordResetToken: %w", errors.Join(errs...)) } func (c *chainedProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error { return c.tryProviders("UpdateUserPassword", func(p domain.IdentityProvider) error { return p.UpdateUserPassword(loginID, newPassword, r) }) } func (c *chainedProvider) tryProviders(operation string, fn func(domain.IdentityProvider) error) error { var errs []error for idx, p := range c.providers { if err := fn(p); err != nil { errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err)) slog.Warn("IDP provider failed", "provider", p.Name(), "operation", operation, "error", err) continue } if idx > 0 { slog.Info("IDP fallback succeeded", "operation", operation, "provider", p.Name()) } return nil } if len(errs) == 0 { return fmt.Errorf("no IDP providers available for %s", operation) } return fmt.Errorf("all IDP providers failed for %s: %w", operation, errors.Join(errs...)) }