diff --git a/backend/internal/domain/models.go b/backend/internal/domain/models.go index d5b749b8..fe911ac2 100644 --- a/backend/internal/domain/models.go +++ b/backend/internal/domain/models.go @@ -32,3 +32,13 @@ type AuditCursor struct { Timestamp time.Time EventID string } + +// RedisRepository defines interface for KV storage (Redis) +type RedisRepository interface { + Set(key string, value string, expiration time.Duration) error + Get(key string) (string, error) + Delete(key string) error + StoreVerificationCode(phone, code string) error + GetVerificationCode(phone string) (string, error) + DeleteVerificationCode(phone string) error +} diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 9941f64d..5bc2dc16 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -80,7 +80,7 @@ const ( type AuthHandler struct { SmsService domain.SmsService EmailService domain.EmailService - RedisService *service.RedisService + RedisService domain.RedisRepository KratosAdmin *service.KratosAdminService IdpProvider domain.IdentityProvider AuditRepo domain.AuditRepository @@ -132,7 +132,7 @@ func GenerateUserCode() string { ) } -func checkPollInterval(redis *service.RedisService, key string, interval time.Duration) (bool, int) { +func checkPollInterval(redis domain.RedisRepository, key string, interval time.Duration) (bool, int) { now := time.Now().UnixMilli() val, err := redis.Get(key) if err == nil && val != "" { @@ -147,7 +147,7 @@ func checkPollInterval(redis *service.RedisService, key string, interval time.Du return false, int(interval.Seconds()) } -func NewAuthHandler(redisService *service.RedisService, idpProvider domain.IdentityProvider, auditRepo domain.AuditRepository, oathkeeperRepo domain.OathkeeperLogRepository, tenantService service.TenantService, ketoService service.KetoService, userRepo repository.UserRepository, consentRepo repository.ClientConsentRepository) *AuthHandler { +func NewAuthHandler(redisService domain.RedisRepository, idpProvider domain.IdentityProvider, auditRepo domain.AuditRepository, oathkeeperRepo domain.OathkeeperLogRepository, tenantService service.TenantService, ketoService service.KetoService, userRepo repository.UserRepository, consentRepo repository.ClientConsentRepository) *AuthHandler { return &AuthHandler{ SmsService: service.NewSmsService(), EmailService: service.NewEmailService(), diff --git a/backend/internal/handler/common_test.go b/backend/internal/handler/common_test.go new file mode 100644 index 00000000..7934299b --- /dev/null +++ b/backend/internal/handler/common_test.go @@ -0,0 +1,149 @@ +package handler + +import ( + "baron-sso-backend/internal/domain" + "bytes" + "context" + "encoding/json" + "io" + "net/http" +) + +// --- Mock IDP Provider --- + +type mockIdpProvider struct { + userExists bool + name string + signInInfo *domain.AuthInfo + issueSession *domain.AuthInfo + verifyCodeInfo *domain.AuthInfo + err error + initiateLinkErr error +} + +func (m *mockIdpProvider) Name() string { + if m.name != "" { + return m.name + } + return "mock-idp" +} + +func (m *mockIdpProvider) GetMetadata() (*domain.IDPMetadata, error) { return nil, m.err } +func (m *mockIdpProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) { + return "mock-user-id", m.err +} +func (m *mockIdpProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) { + return m.signInInfo, m.err +} +func (m *mockIdpProvider) UserExists(loginID string) (bool, error) { return m.userExists, m.err } +func (m *mockIdpProvider) IssueSession(loginID string) (*domain.AuthInfo, error) { + if m.issueSession != nil { + return m.issueSession, m.err + } + return &domain.AuthInfo{ + SessionToken: &domain.Token{JWT: "valid-jwt", SessionID: "valid-sid"}, + }, m.err +} +func (m *mockIdpProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) { + if m.initiateLinkErr != nil { + return nil, m.initiateLinkErr + } + return &domain.LinkLoginInit{FlowID: "mock-flow-id", Mode: "code"}, m.err +} +func (m *mockIdpProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) { + return m.verifyCodeInfo, m.err +} +func (m *mockIdpProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) { return nil, m.err } +func (m *mockIdpProvider) InitiatePasswordReset(loginID, redirectUrl string) error { return m.err } +func (m *mockIdpProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) { + return nil, m.err +} +func (m *mockIdpProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error { + return m.err +} + +// --- Mock Audit Repository --- + +type mockAuditRepo struct { + logs []domain.AuditLog +} + +func (m *mockAuditRepo) Create(log *domain.AuditLog) error { + m.logs = append(m.logs, *log) + return nil +} +func (m *mockAuditRepo) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor) ([]domain.AuditLog, error) { + return m.logs, nil +} +func (m *mockAuditRepo) FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]domain.AuditLog, error) { + var results []domain.AuditLog + for _, log := range m.logs { + if log.UserID == userID { + for _, et := range eventTypes { + if log.EventType == et { + results = append(results, log) + break + } + } + } + } + return results, nil +} +func (m *mockAuditRepo) Ping(ctx context.Context) error { return nil } + +// --- Mock Consent Repository --- + +type mockConsentRepo struct { + consents []domain.ClientConsent +} + +func (m *mockConsentRepo) Upsert(ctx context.Context, consent *domain.ClientConsent) error { + m.consents = append(m.consents, *consent) + return nil +} +func (m *mockConsentRepo) ListBySubject(ctx context.Context, subject string) ([]domain.ClientConsent, error) { + var results []domain.ClientConsent + for _, c := range m.consents { + if c.Subject == subject { + results = append(results, c) + } + } + return results, nil +} +func (m *mockConsentRepo) Delete(ctx context.Context, clientID, subject string) error { return nil } +func (m *mockConsentRepo) List(ctx context.Context, clientID string, limit, offset int) ([]domain.ClientConsentWithTenantInfo, int64, error) { + return nil, 0, nil +} +func (m *mockConsentRepo) ListByTenant(ctx context.Context, clientID, tenantID string, limit, offset int) ([]domain.ClientConsentWithTenantInfo, int64, error) { + return nil, 0, nil +} + + +// --- HTTP Mock Helpers --- + +type roundTripFunc func(req *http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func httpResponse(r *http.Request, code int, body string) *http.Response { + return &http.Response{ + StatusCode: code, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewBufferString(body)), + Request: r, + } +} + +func httpJSONAny(r *http.Request, code int, data any) *http.Response { + body, _ := json.Marshal(data) + return &http.Response{ + StatusCode: code, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + }, + Body: io.NopCloser(bytes.NewBuffer(body)), + Request: r, + } +} diff --git a/backend/internal/handler/dev_handler.go b/backend/internal/handler/dev_handler.go index 0d05d87d..6514d741 100644 --- a/backend/internal/handler/dev_handler.go +++ b/backend/internal/handler/dev_handler.go @@ -18,13 +18,13 @@ import ( type DevHandler struct { Hydra *service.HydraAdminService - Redis *service.RedisService + Redis domain.RedisRepository SecretRepo domain.ClientSecretRepository KratosAdmin *service.KratosAdminService ConsentRepo repository.ClientConsentRepository } -func NewDevHandler(redis *service.RedisService, secretRepo domain.ClientSecretRepository, consentRepo repository.ClientConsentRepository) *DevHandler { +func NewDevHandler(redis domain.RedisRepository, secretRepo domain.ClientSecretRepository, consentRepo repository.ClientConsentRepository) *DevHandler { return &DevHandler{ Hydra: service.NewHydraAdminService(), Redis: redis,