package idp

import (
	"errors"
	"net/http"
	"reflect"
	"strings"
	"testing"

	"baron-sso-backend/internal/domain"
)

// Compile-time check that the stub satisfies the interface that
// newChainedProvider accepts in the tests below.
var _ domain.IdentityProvider = (*stubProvider)(nil)

// stubProvider is a configurable test double for domain.IdentityProvider.
// For every method except Name and GetMetadata, the matching *Calls counter
// is incremented on each invocation, and the matching *Err field (when
// non-nil) is returned as the failure. Tests use the counters to assert
// which providers in a chain were actually tried.
type stubProvider struct {
	name     string
	metadata []string

	// Injected failures, one per method.
	createErr     error
	initiateErr   error
	verifyErr     error
	updateErr     error
	signInErr     error
	userExistsErr error
	issueErr      error
	linkInitErr   error
	verifyCodeErr error
	policyErr     error

	// Invocation counters, one per method.
	initiateCalls   int
	verifyCalls     int
	updateCalls     int
	signInCalls     int
	createCalls     int
	userExistsCalls int
	issueCalls      int
	linkInitCalls   int
	verifyCodeCalls int
	policyCalls     int

	// Canned successful responses.
	verifyResponse *domain.AuthInfo
	userExists     bool
	policy         *domain.PasswordPolicy
}

// Name returns the configured provider name.
func (s *stubProvider) Name() string {
	return s.name
}

// GetMetadata reports the configured metadata slice as SupportedFields and
// never fails. The slice is returned as-is, without copying.
func (s *stubProvider) GetMetadata() (*domain.IDPMetadata, error) {
	return &domain.IDPMetadata{SupportedFields: s.metadata}, nil
}

// CreateUser returns createErr if set, otherwise the fixed ID "created-id".
func (s *stubProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
	s.createCalls++
	if s.createErr != nil {
		return "", s.createErr
	}
	return "created-id", nil
}

// SignIn returns signInErr if set, otherwise AuthInfo with Subject "subject-123".
func (s *stubProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
	s.signInCalls++
	if s.signInErr != nil {
		return nil, s.signInErr
	}
	return &domain.AuthInfo{Subject: "subject-123"}, nil
}

// UserExists returns userExistsErr if set, otherwise the configured userExists flag.
func (s *stubProvider) UserExists(loginID string) (bool, error) {
	s.userExistsCalls++
	if s.userExistsErr != nil {
		return false, s.userExistsErr
	}
	return s.userExists, nil
}

// IssueSession returns issueErr if set, otherwise AuthInfo with Subject "issue-subject".
func (s *stubProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
	s.issueCalls++
	if s.issueErr != nil {
		return nil, s.issueErr
	}
	return &domain.AuthInfo{Subject: "issue-subject"}, nil
}

// InitiateLinkLogin returns linkInitErr if set, otherwise a fixed
// LinkLoginInit with FlowID "flow-123" and Mode "cookie".
func (s *stubProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
	s.linkInitCalls++
	if s.linkInitErr != nil {
		return nil, s.linkInitErr
	}
	return &domain.LinkLoginInit{FlowID: "flow-123", Mode: "cookie"}, nil
}

// VerifyLoginCode returns verifyCodeErr if set, otherwise AuthInfo with
// Subject "verify-code-subject".
func (s *stubProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
	s.verifyCodeCalls++
	if s.verifyCodeErr != nil {
		return nil, s.verifyCodeErr
	}
	return &domain.AuthInfo{Subject: "verify-code-subject"}, nil
}

// GetPasswordPolicy returns policyErr if set, otherwise the configured
// policy (which may be nil).
func (s *stubProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
	s.policyCalls++
	if s.policyErr != nil {
		return nil, s.policyErr
	}
	return s.policy, nil
}

// InitiatePasswordReset records the call and returns initiateErr (nil on success).
func (s *stubProvider) InitiatePasswordReset(loginID, redirectURL string) error {
	s.initiateCalls++
	return s.initiateErr
}

// VerifyPasswordResetToken returns verifyErr if set; otherwise
// verifyResponse when configured, or an empty AuthInfo.
func (s *stubProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
	s.verifyCalls++
	if s.verifyErr != nil {
		return nil, s.verifyErr
	}
	if s.verifyResponse != nil {
		return s.verifyResponse, nil
	}
	return &domain.AuthInfo{}, nil
}

// UpdateUserPassword records the call and returns updateErr (nil on success).
func (s *stubProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
	s.updateCalls++
	return s.updateErr
}

// TestChainedProviderMetadataUnion checks that the chain merges every
// provider's SupportedFields in first-seen order, listing the shared
// "email" field only once.
func TestChainedProviderMetadataUnion(t *testing.T) {
	p1 := &stubProvider{name: "primary", metadata: []string{"id", "email"}}
	p2 := &stubProvider{name: "backup", metadata: []string{"email", "phone_number", "grade"}}
	chain := newChainedProvider([]domain.IdentityProvider{p1, p2})

	meta, err := chain.GetMetadata()
	if err != nil {
		t.Fatalf("GetMetadata returned error: %v", err)
	}
	// Guard so a (nil, nil) result fails the test instead of panicking below.
	if meta == nil {
		t.Fatal("GetMetadata returned nil metadata")
	}

	expected := []string{"id", "email", "phone_number", "grade"}
	if !reflect.DeepEqual(meta.SupportedFields, expected) {
		t.Fatalf("metadata mismatch: got %v, want %v", meta.SupportedFields, expected)
	}
}

// TestChainedProviderUpdateUserPasswordFallback checks that a failure in
// the primary provider falls through to the backup, and that the chain
// reports success once the backup succeeds.
func TestChainedProviderUpdateUserPasswordFallback(t *testing.T) {
	p1 := &stubProvider{name: "primary", metadata: []string{"id"}, updateErr: errors.New("boom")}
	p2 := &stubProvider{name: "backup", metadata: []string{"id"}}
	chain := newChainedProvider([]domain.IdentityProvider{p1, p2})

	if err := chain.UpdateUserPassword("user@example.com", "Sup3r!Pass123", nil); err != nil {
		t.Fatalf("expected fallback to succeed, got error: %v", err)
	}
	if p1.updateCalls != 1 || p2.updateCalls != 1 {
		t.Fatalf("unexpected call counts: p1=%d p2=%d", p1.updateCalls, p2.updateCalls)
	}
}

// TestChainedProviderUpdateUserPasswordAllFail checks that when every
// provider fails, each is tried exactly once and the chain returns an
// error whose message contains "all IDP providers failed".
func TestChainedProviderUpdateUserPasswordAllFail(t *testing.T) {
	p1 := &stubProvider{name: "primary", metadata: []string{"id"}, updateErr: errors.New("fail1")}
	p2 := &stubProvider{name: "backup", metadata: []string{"id"}, updateErr: errors.New("fail2")}
	chain := newChainedProvider([]domain.IdentityProvider{p1, p2})

	err := chain.UpdateUserPassword("user@example.com", "Sup3r!Pass123", nil)
	if err == nil {
		t.Fatal("expected error when all providers fail")
	}
	if !strings.Contains(err.Error(), "all IDP providers failed") {
		t.Fatalf("unexpected error: %v", err)
	}
	if p1.updateCalls != 1 || p2.updateCalls != 1 {
		t.Fatalf("unexpected call counts: p1=%d p2=%d", p1.updateCalls, p2.updateCalls)
	}
}