forked from baron/baron-sso
168 lines
4.7 KiB
Go
168 lines
4.7 KiB
Go
package idp
|
|
|
|
import (
|
|
"baron-sso-backend/internal/domain"
|
|
"errors"
|
|
"net/http"
|
|
"reflect"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
type stubProvider struct {
|
|
name string
|
|
metadata []string
|
|
createErr error
|
|
initiateErr error
|
|
verifyErr error
|
|
updateErr error
|
|
signInErr error
|
|
userExistsErr error
|
|
issueErr error
|
|
linkInitErr error
|
|
verifyCodeErr error
|
|
policyErr error
|
|
initiateCalls int
|
|
verifyCalls int
|
|
updateCalls int
|
|
signInCalls int
|
|
createCalls int
|
|
userExistsCalls int
|
|
issueCalls int
|
|
linkInitCalls int
|
|
verifyCodeCalls int
|
|
policyCalls int
|
|
verifyResponse *domain.AuthInfo
|
|
userExists bool
|
|
policy *domain.PasswordPolicy
|
|
}
|
|
|
|
// Name reports the identifier this stub was configured with.
func (s *stubProvider) Name() string {
	return s.name
}
|
|
|
|
func (s *stubProvider) GetMetadata() (*domain.IDPMetadata, error) {
|
|
return &domain.IDPMetadata{SupportedFields: s.metadata}, nil
|
|
}
|
|
|
|
func (s *stubProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
|
|
s.createCalls++
|
|
if s.createErr != nil {
|
|
return "", s.createErr
|
|
}
|
|
return "created-id", nil
|
|
}
|
|
|
|
func (s *stubProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
|
|
s.signInCalls++
|
|
if s.signInErr != nil {
|
|
return nil, s.signInErr
|
|
}
|
|
return &domain.AuthInfo{Subject: "subject-123"}, nil
|
|
}
|
|
|
|
func (s *stubProvider) UserExists(loginID string) (bool, error) {
|
|
s.userExistsCalls++
|
|
if s.userExistsErr != nil {
|
|
return false, s.userExistsErr
|
|
}
|
|
return s.userExists, nil
|
|
}
|
|
|
|
func (s *stubProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
|
|
s.issueCalls++
|
|
if s.issueErr != nil {
|
|
return nil, s.issueErr
|
|
}
|
|
return &domain.AuthInfo{Subject: "issue-subject"}, nil
|
|
}
|
|
|
|
func (s *stubProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
|
s.linkInitCalls++
|
|
if s.linkInitErr != nil {
|
|
return nil, s.linkInitErr
|
|
}
|
|
return &domain.LinkLoginInit{FlowID: "flow-123", Mode: "cookie"}, nil
|
|
}
|
|
|
|
func (s *stubProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
|
s.verifyCodeCalls++
|
|
if s.verifyCodeErr != nil {
|
|
return nil, s.verifyCodeErr
|
|
}
|
|
return &domain.AuthInfo{Subject: "verify-code-subject"}, nil
|
|
}
|
|
|
|
func (s *stubProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
|
|
s.policyCalls++
|
|
if s.policyErr != nil {
|
|
return nil, s.policyErr
|
|
}
|
|
return s.policy, nil
|
|
}
|
|
|
|
func (s *stubProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
|
|
s.initiateCalls++
|
|
return s.initiateErr
|
|
}
|
|
|
|
func (s *stubProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
|
|
s.verifyCalls++
|
|
if s.verifyErr != nil {
|
|
return nil, s.verifyErr
|
|
}
|
|
if s.verifyResponse != nil {
|
|
return s.verifyResponse, nil
|
|
}
|
|
return &domain.AuthInfo{}, nil
|
|
}
|
|
|
|
// UpdateUserPassword counts the call and reports updateErr; a nil updateErr
// means the update "succeeded". None of the arguments, including the
// request, are inspected.
func (s *stubProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
	result := s.updateErr
	s.updateCalls++
	return result
}
|
|
|
|
// TestChainedProviderMetadataUnion checks that the chained provider reports
// the union of its members' supported fields. The expected slice lists them in
// first-seen order, with "email" (present in both stubs) appearing only once.
func TestChainedProviderMetadataUnion(t *testing.T) {
	p1 := &stubProvider{name: "primary", metadata: []string{"id", "email"}}
	p2 := &stubProvider{name: "backup", metadata: []string{"email", "phone_number", "grade"}}

	chain := newChainedProvider([]domain.IdentityProvider{p1, p2})
	meta, err := chain.GetMetadata()
	if err != nil {
		t.Fatalf("GetMetadata returned error: %v", err)
	}
	// Guard against a (nil, nil) result so a regression fails the test cleanly
	// instead of panicking on meta.SupportedFields below.
	if meta == nil {
		t.Fatal("GetMetadata returned nil metadata without an error")
	}

	expected := []string{"id", "email", "phone_number", "grade"}
	if !reflect.DeepEqual(meta.SupportedFields, expected) {
		t.Fatalf("metadata mismatch: got %v, want %v", meta.SupportedFields, expected)
	}
}
|
|
|
|
// TestChainedProviderUpdateUserPasswordFallback verifies two things: when the
// first provider fails to update the password, the chain moves on to the
// second one and reports overall success; and each provider is called
// exactly once.
func TestChainedProviderUpdateUserPasswordFallback(t *testing.T) {
	primary := &stubProvider{name: "primary", metadata: []string{"id"}, updateErr: errors.New("boom")}
	backup := &stubProvider{name: "backup", metadata: []string{"id"}}
	chain := newChainedProvider([]domain.IdentityProvider{primary, backup})

	err := chain.UpdateUserPassword("user@example.com", "Sup3r!Pass123", nil)
	if err != nil {
		t.Fatalf("expected fallback to succeed, got error: %v", err)
	}

	firstCalls, secondCalls := primary.updateCalls, backup.updateCalls
	if firstCalls != 1 || secondCalls != 1 {
		t.Fatalf("unexpected call counts: p1=%d p2=%d", firstCalls, secondCalls)
	}
}
|
|
|
|
// TestChainedProviderUpdateUserPasswordAllFail verifies the case where every
// provider rejects the update. The chain must surface an error containing
// "all IDP providers failed", and it must have tried each provider exactly
// once.
func TestChainedProviderUpdateUserPasswordAllFail(t *testing.T) {
	primary := &stubProvider{name: "primary", metadata: []string{"id"}, updateErr: errors.New("fail1")}
	backup := &stubProvider{name: "backup", metadata: []string{"id"}, updateErr: errors.New("fail2")}
	chain := newChainedProvider([]domain.IdentityProvider{primary, backup})

	err := chain.UpdateUserPassword("user@example.com", "Sup3r!Pass123", nil)
	// Cases are evaluated in order, so err.Error() is only reached when err is non-nil.
	switch {
	case err == nil:
		t.Fatalf("expected error when all providers fail")
	case !strings.Contains(err.Error(), "all IDP providers failed"):
		t.Fatalf("unexpected error: %v", err)
	}

	firstCalls, secondCalls := primary.updateCalls, backup.updateCalls
	if firstCalls != 1 || secondCalls != 1 {
		t.Fatalf("unexpected call counts: p1=%d p2=%d", firstCalls, secondCalls)
	}
}
|