diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 632a56e5..53d1be4d 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -274,7 +274,7 @@ func main() { consentRepo := repository.NewClientConsentRepository(db) auditHandler := handler.NewAuditHandler(auditRepo) - authHandler := handler.NewAuthHandler(redisService, idpProvider, auditRepo, oathkeeperRepo, tenantService, ketoService, ketoOutboxRepo, userRepo, consentRepo) + authHandler := handler.NewAuthHandler(redisService, idpProvider, auditRepo, oathkeeperRepo, tenantService, ketoService, ketoOutboxRepo, userRepo, consentRepo, kratosAdminService) adminHandler := handler.NewAdminHandler(ketoService) devHandler := handler.NewDevHandler(redisService, secretRepo, consentRepo, relyingPartyService) tenantHandler := handler.NewTenantHandler(db, tenantService, ketoService, ketoOutboxRepo, kratosAdminService) diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 39d170f5..0ebe7d29 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -81,7 +81,7 @@ type AuthHandler struct { SmsService domain.SmsService EmailService domain.EmailService RedisService domain.RedisRepository - KratosAdmin *service.KratosAdminService + KratosAdmin service.KratosAdminService IdpProvider domain.IdentityProvider AuditRepo domain.AuditRepository OathkeeperRepo domain.OathkeeperLogRepository @@ -148,12 +148,12 @@ func checkPollInterval(redis domain.RedisRepository, key string, interval time.D return false, int(interval.Seconds()) } -func NewAuthHandler(redisService domain.RedisRepository, idpProvider domain.IdentityProvider, auditRepo domain.AuditRepository, oathkeeperRepo domain.OathkeeperLogRepository, tenantService service.TenantService, ketoService service.KetoService, ketoOutboxRepo repository.KetoOutboxRepository, userRepo repository.UserRepository, consentRepo repository.ClientConsentRepository) *AuthHandler { +func NewAuthHandler(redisService domain.RedisRepository, idpProvider domain.IdentityProvider, auditRepo domain.AuditRepository, oathkeeperRepo domain.OathkeeperLogRepository, tenantService service.TenantService, ketoService service.KetoService, ketoOutboxRepo repository.KetoOutboxRepository, userRepo repository.UserRepository, consentRepo repository.ClientConsentRepository, kratos service.KratosAdminService) *AuthHandler { return &AuthHandler{ SmsService: service.NewSmsService(), EmailService: service.NewEmailService(), RedisService: redisService, - KratosAdmin: service.NewKratosAdminService(), + KratosAdmin: kratos, IdpProvider: idpProvider, AuditRepo: auditRepo, OathkeeperRepo: oathkeeperRepo, diff --git a/backend/internal/handler/auth_handler_async_test.go b/backend/internal/handler/auth_handler_async_test.go index 7b660a98..6a5d043e 100644 --- a/backend/internal/handler/auth_handler_async_test.go +++ b/backend/internal/handler/auth_handler_async_test.go @@ -81,6 +81,7 @@ func (m *AsyncMockUserRepo) Create(ctx context.Context, user *domain.User) error return args.Error(0) } func (m *AsyncMockUserRepo) Update(ctx context.Context, user *domain.User) error { return nil } +func (m *AsyncMockUserRepo) Delete(ctx context.Context, id string) error { return nil } func (m *AsyncMockUserRepo) FindByEmail(ctx context.Context, email string) (*domain.User, error) { return nil, nil } diff --git a/backend/internal/handler/auth_handler_consent_test.go b/backend/internal/handler/auth_handler_consent_test.go index 72cdede1..305a38bb 100644 --- a/backend/internal/handler/auth_handler_consent_test.go +++ b/backend/internal/handler/auth_handler_consent_test.go @@ -10,6 +10,7 @@ import ( "github.com/gofiber/fiber/v2" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) // --- Test Helpers --- @@ -112,12 +113,15 @@ func TestGetConsentRequest_Skip_AutoAccept(t *testing.T) { AdminURL: "http://hydra.test", HTTPClient: client, }, - KratosAdmin: &service.KratosAdminService{ - AdminURL: "http://kratos.test", - HTTPClient: client, - }, + KratosAdmin: new(MockKratosAdminService), // Reusing MockKratosAdminService if defined or use MockKratosAdminServiceShared ConsentRepo: consentRepo, } + h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{ + ID: "user-123", + Traits: map[string]interface{}{ + "email": "user@test.com", + }, + }, nil) app := newConsentTestApp(h) req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-skip", nil) @@ -172,13 +176,16 @@ func TestAcceptConsentRequest_Normal(t *testing.T) { AdminURL: "http://hydra.test", HTTPClient: client, }, - KratosAdmin: &service.KratosAdminService{ - AdminURL: "http://kratos.test", - HTTPClient: client, - }, + KratosAdmin: new(MockKratosAdminService), AuditRepo: auditRepo, ConsentRepo: consentRepo, } + h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{ + ID: "user-123", + Traits: map[string]interface{}{ + "email": "user@test.com", + }, + }, nil) app := newConsentTestApp(h) diff --git a/backend/internal/handler/auth_handler_linked_test.go b/backend/internal/handler/auth_handler_linked_test.go index cc94931f..b9618d77 100644 --- a/backend/internal/handler/auth_handler_linked_test.go +++ b/backend/internal/handler/auth_handler_linked_test.go @@ -106,7 +106,7 @@ func TestListLinkedRps_PriorityAndAggregation(t *testing.T) { }, AuditRepo: auditRepo, ConsentRepo: consentRepo, - KratosAdmin: &service.KratosAdminService{}, + KratosAdmin: new(MockKratosAdminService), } t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test") diff --git a/backend/internal/handler/auth_handler_login_test.go b/backend/internal/handler/auth_handler_login_test.go index 9845e738..1c8a099c 100644 --- a/backend/internal/handler/auth_handler_login_test.go +++ b/backend/internal/handler/auth_handler_login_test.go @@ -11,7 +11,6 @@ import ( "bytes" "context" "encoding/json" - "errors" "io" "net/http" "net/http/httptest" @@ -81,15 +80,36 @@ func (m *MockIdentityProvider) UpdateUserPassword(loginID, newPassword string, r } type MockKratosAdminService struct { - // Simple mock for FindIdentityIDByIdentifier + mock.Mock } func (m *MockKratosAdminService) FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error) { - // Always return a static ID for simplicity in this test - if identifier == "fail" { - return "", errors.New("not found") + args := m.Called(ctx, identifier) + return args.String(0), args.Error(1) +} + +func (m *MockKratosAdminService) GetIdentity(ctx context.Context, id string) (*service.KratosIdentity, error) { + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) } - return "kratos-identity-id", nil + return args.Get(0).(*service.KratosIdentity), args.Error(1) +} + +func (m *MockKratosAdminService) ListIdentities(ctx context.Context) ([]service.KratosIdentity, error) { + return nil, nil +} + +func (m *MockKratosAdminService) UpdateIdentity(ctx context.Context, identityID string, traits map[string]interface{}, state string) (*service.KratosIdentity, error) { + return nil, nil +} + +func (m *MockKratosAdminService) UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error { + return nil +} + +func (m *MockKratosAdminService) DeleteIdentity(ctx context.Context, identityID string) error { + return nil } // --- Helper --- @@ -142,30 +162,17 @@ func TestPasswordLogin_OIDC_Success(t *testing.T) { } }) + mockKratos := new(MockKratosAdminService) + mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "user@example.com").Return("kratos-identity-id", nil) + h := &AuthHandler{ IdpProvider: mockIdp, - KratosAdmin: service.NewKratosAdminService(), // We need to mock this better if resolveKratosIdentityIDFromLoginID calls real API + KratosAdmin: mockKratos, Hydra: &service.HydraAdminService{ AdminURL: "http://hydra.test", HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)}, }, } - // Inject Mock Kratos (Hack: overwrite the service field if it was an interface, but it's a struct pointer) - // AuthHandler uses *service.KratosAdminService struct pointer. - // KratosAdminService methods are real. We need to mock HTTP client inside KratosAdminService too. - - kratosHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Mock FindIdentityIDByIdentifier response - if strings.Contains(r.URL.Path, "/identities") { - json.NewEncoder(w).Encode([]map[string]interface{}{ - {"id": "kratos-identity-id"}, - }) - return - } - http.NotFound(w, r) - }) - h.KratosAdmin.HTTPClient = &http.Client{Transport: mockHydraTransport(kratosHandler)} - h.KratosAdmin.AdminURL = "http://kratos.test" app := newAuthLoginTestApp(h) @@ -215,21 +222,18 @@ func TestPasswordLogin_OIDC_InactiveClient(t *testing.T) { http.NotFound(w, r) }) + mockKratos := new(MockKratosAdminService) + mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "user@example.com").Return("kratos-identity-id", nil) + h := &AuthHandler{ IdpProvider: mockIdp, - KratosAdmin: service.NewKratosAdminService(), + KratosAdmin: mockKratos, Hydra: &service.HydraAdminService{ AdminURL: "http://hydra.test", HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)}, }, } - kratosHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - json.NewEncoder(w).Encode([]map[string]interface{}{{"id": "kratos-identity-id"}}) - }) - h.KratosAdmin.HTTPClient = &http.Client{Transport: mockHydraTransport(kratosHandler)} - h.KratosAdmin.AdminURL = "http://kratos.test" - app := newAuthLoginTestApp(h) body, _ := json.Marshal(map[string]string{ @@ -259,18 +263,15 @@ func TestPasswordLogin_NoOIDC_Success(t *testing.T) { Subject: "kratos-identity-id", }, nil) + mockKratos := new(MockKratosAdminService) + mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "user@example.com").Return("kratos-identity-id", nil) + h := &AuthHandler{ IdpProvider: mockIdp, - KratosAdmin: service.NewKratosAdminService(), + KratosAdmin: mockKratos, Hydra: service.NewHydraAdminService(), } - kratosHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - json.NewEncoder(w).Encode([]map[string]interface{}{{"id": "kratos-identity-id"}}) - }) - h.KratosAdmin.HTTPClient = &http.Client{Transport: mockHydraTransport(kratosHandler)} - h.KratosAdmin.AdminURL = "http://kratos.test" - app := newAuthLoginTestApp(h) body, _ := json.Marshal(map[string]string{ diff --git a/backend/internal/handler/dev_handler.go b/backend/internal/handler/dev_handler.go index 48db2ac5..246ff2a7 100644 --- a/backend/internal/handler/dev_handler.go +++ b/backend/internal/handler/dev_handler.go @@ -20,7 +20,7 @@ type DevHandler struct { Hydra *service.HydraAdminService Redis domain.RedisRepository SecretRepo domain.ClientSecretRepository - KratosAdmin *service.KratosAdminService + KratosAdmin service.KratosAdminService ConsentRepo repository.ClientConsentRepository } diff --git a/backend/internal/handler/relying_party_handler.go b/backend/internal/handler/relying_party_handler.go index 342d7f5d..0d13be6d 100644 --- a/backend/internal/handler/relying_party_handler.go +++ b/backend/internal/handler/relying_party_handler.go @@ -10,10 +10,10 @@ import ( type RelyingPartyHandler struct { Service service.RelyingPartyService - KratosAdmin *service.KratosAdminService + KratosAdmin service.KratosAdminService } -func NewRelyingPartyHandler(s service.RelyingPartyService, kratos *service.KratosAdminService) *RelyingPartyHandler { +func NewRelyingPartyHandler(s service.RelyingPartyService, kratos service.KratosAdminService) *RelyingPartyHandler { return &RelyingPartyHandler{Service: s, KratosAdmin: kratos} } diff --git a/backend/internal/handler/tenant_handler.go b/backend/internal/handler/tenant_handler.go index fc172d8d..1e1182a2 100644 --- a/backend/internal/handler/tenant_handler.go +++ b/backend/internal/handler/tenant_handler.go @@ -17,10 +17,10 @@ type TenantHandler struct { Service service.TenantService Keto service.KetoService KetoOutbox repository.KetoOutboxRepository - KratosAdmin *service.KratosAdminService + KratosAdmin service.KratosAdminService } -func NewTenantHandler(db *gorm.DB, svc service.TenantService, keto service.KetoService, outbox repository.KetoOutboxRepository, kratos *service.KratosAdminService) *TenantHandler { +func NewTenantHandler(db *gorm.DB, svc service.TenantService, keto service.KetoService, outbox repository.KetoOutboxRepository, kratos service.KratosAdminService) *TenantHandler { return &TenantHandler{ DB: db, Service: svc, diff --git a/backend/internal/handler/tenant_handler_test.go b/backend/internal/handler/tenant_handler_test.go index 1d4ebe42..c9698159 100644 --- a/backend/internal/handler/tenant_handler_test.go +++ b/backend/internal/handler/tenant_handler_test.go @@ -85,7 +85,7 @@ func TestTenantHandler_CreateTenant(t *testing.T) { } body, _ := json.Marshal(input) - mockSvc.On("RegisterTenant", mock.Anything, "Test Tenant", "test-tenant", "", []string{"test.com"}). + mockSvc.On("RegisterTenant", mock.Anything, "Test Tenant", "test-tenant", "", []string{"test.com"}, (*string)(nil)). Return(&domain.Tenant{ID: "t1", Name: "Test Tenant", Slug: "test-tenant"}, nil) req := httptest.NewRequest("POST", "/tenants", bytes.NewReader(body)) diff --git a/backend/internal/handler/user_group_handler_test.go b/backend/internal/handler/user_group_handler_test.go index 9ad62959..eb7e3644 100644 --- a/backend/internal/handler/user_group_handler_test.go +++ b/backend/internal/handler/user_group_handler_test.go @@ -20,16 +20,24 @@ type MockUserGroupService struct { mock.Mock } -func (m *MockUserGroupService) Create(ctx context.Context, group *domain.UserGroup) error { - return m.Called(ctx, group).Error(0) +func (m *MockUserGroupService) Create(ctx context.Context, tenantID string, parentID *string, name, description, unitType string) (*domain.UserGroup, error) { + args := m.Called(ctx, tenantID, parentID, name, description, unitType) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.UserGroup), args.Error(1) } -func (m *MockUserGroupService) Update(ctx context.Context, group *domain.UserGroup) error { - return m.Called(ctx, group).Error(0) +func (m *MockUserGroupService) Update(ctx context.Context, tenantID, groupID string, name, description, unitType string, parentID *string) (*domain.UserGroup, error) { + args := m.Called(ctx, tenantID, groupID, name, description, unitType, parentID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.UserGroup), args.Error(1) } -func (m *MockUserGroupService) Delete(ctx context.Context, id string) error { - return m.Called(ctx, id).Error(0) +func (m *MockUserGroupService) Delete(ctx context.Context, tenantID, groupID string) error { + return m.Called(ctx, tenantID, groupID).Error(0) } func (m *MockUserGroupService) Get(ctx context.Context, id string) (*domain.UserGroup, error) { @@ -95,9 +103,7 @@ func TestUserGroupHandler_Create(t *testing.T) { app.Post("/tenants/:tenantId/user-groups", h.Create) body, _ := json.Marshal(map[string]string{"name": "New Group"}) - mockSvc.On("Create", mock.Anything, mock.MatchedBy(func(g *domain.UserGroup) bool { - return g.Name == "New Group" && g.TenantID == "t1" - })).Return(nil) + mockSvc.On("Create", mock.Anything, "t1", mock.Anything, "New Group", mock.Anything, mock.Anything).Return(&domain.UserGroup{ID: "g1", Name: "New Group"}, nil) req := httptest.NewRequest("POST", "/tenants/t1/user-groups", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 94a76baf..0602f843 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -14,7 +14,7 @@ import ( ) type UserHandler struct { - KratosAdmin *service.KratosAdminService + KratosAdmin service.KratosAdminService OryProvider *service.OryProvider TenantService service.TenantService KetoService service.KetoService @@ -22,7 +22,7 @@ type UserHandler struct { UserRepo repository.UserRepository } -func NewUserHandler(kratosAdmin *service.KratosAdminService, oryProvider *service.OryProvider, tenantService service.TenantService, ketoService service.KetoService, ketoOutboxRepo repository.KetoOutboxRepository, userRepo repository.UserRepository) *UserHandler { +func NewUserHandler(kratosAdmin service.KratosAdminService, oryProvider *service.OryProvider, tenantService service.TenantService, ketoService service.KetoService, ketoOutboxRepo repository.KetoOutboxRepository, userRepo repository.UserRepository) *UserHandler { return &UserHandler{ KratosAdmin: kratosAdmin, OryProvider: oryProvider, diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go index 4aee268e..9737546e 100644 --- a/backend/internal/handler/user_handler_test.go +++ b/backend/internal/handler/user_handler_test.go @@ -1,8 +1,8 @@ package handler import ( - "baron-sso-backend/internal/domain" "baron-sso-backend/internal/service" + "bytes" "context" "encoding/json" "net/http/httptest" @@ -32,13 +32,29 @@ func (m *MockKratosAdminForUser) ListIdentities(ctx context.Context) ([]service. return args.Get(0).([]service.KratosIdentity), args.Error(1) } -// Note: In reality, KratosAdminService might not be an interface. -// If it's a struct, we'd need to mock the underlying client or use an interface. -// For the sake of this test, let's assume we can mock it or use a wrapper. +func (m *MockKratosAdminForUser) FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error) { + return "", nil +} + +func (m *MockKratosAdminForUser) UpdateIdentity(ctx context.Context, identityID string, traits map[string]interface{}, state string) (*service.KratosIdentity, error) { + return nil, nil +} + +func (m *MockKratosAdminForUser) UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error { + return nil +} + +func (m *MockKratosAdminForUser) DeleteIdentity(ctx context.Context, identityID string) error { + return nil +} func TestUserHandler_CreateUser_InvalidEmail(t *testing.T) { app := fiber.New() - h := &UserHandler{} + mockKratos := new(MockKratosAdminForUser) + h := &UserHandler{ + KratosAdmin: mockKratos, + OryProvider: &service.OryProvider{}, // Assuming it's a struct and non-nil is enough for this check + } app.Post("/users", h.CreateUser) payload := map[string]string{ @@ -54,8 +70,8 @@ func TestUserHandler_CreateUser_InvalidEmail(t *testing.T) { } func TestUserHandler_GetUser_Forbidden(t *testing.T) { - app := fiber.New() - mockKratos := new(MockKratosAdminForUser) + // app := fiber.New() + // mockKratos := new(MockKratosAdminForUser) // We need a way to inject mockKratos into UserHandler. // Since UserHandler uses *service.KratosAdminService (struct), // we'd typically use an interface here. diff --git a/backend/internal/service/kratos_admin_service.go b/backend/internal/service/kratos_admin_service.go index 800407d4..d5dce360 100644 --- a/backend/internal/service/kratos_admin_service.go +++ b/backend/internal/service/kratos_admin_service.go @@ -21,18 +21,27 @@ type KratosIdentity struct { UpdatedAt time.Time `json:"updated_at,omitempty"` } -type KratosAdminService struct { +type KratosAdminService interface { + ListIdentities(ctx context.Context) ([]KratosIdentity, error) + FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error) + GetIdentity(ctx context.Context, identityID string) (*KratosIdentity, error) + UpdateIdentity(ctx context.Context, identityID string, traits map[string]interface{}, state string) (*KratosIdentity, error) + UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error + DeleteIdentity(ctx context.Context, identityID string) error +} + +type kratosAdminService struct { AdminURL string HTTPClient *http.Client } -func NewKratosAdminService() *KratosAdminService { - return &KratosAdminService{ +func NewKratosAdminService() KratosAdminService { + return &kratosAdminService{ AdminURL: getenvKratos("KRATOS_ADMIN_URL", "http://kratos:4434"), } } -func (s *KratosAdminService) ListIdentities(ctx context.Context) ([]KratosIdentity, error) { +func (s *kratosAdminService) ListIdentities(ctx context.Context) ([]KratosIdentity, error) { endpoint := strings.TrimRight(s.AdminURL, "/") + "/admin/identities" req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) if err != nil { @@ -57,7 +66,7 @@ func (s *KratosAdminService) ListIdentities(ctx context.Context) ([]KratosIdenti return identities, nil } -func (s *KratosAdminService) FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error) { +func (s *kratosAdminService) FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error) { identifier = strings.TrimSpace(identifier) if identifier == "" { return "", nil @@ -99,7 +108,7 @@ func (s *KratosAdminService) FindIdentityIDByIdentifier(ctx context.Context, ide return identities[0].ID, nil } -func (s *KratosAdminService) GetIdentity(ctx context.Context, identityID string) (*KratosIdentity, error) { +func (s *kratosAdminService) GetIdentity(ctx context.Context, identityID string) (*KratosIdentity, error) { endpoint := fmt.Sprintf("%s/admin/identities/%s", strings.TrimRight(s.AdminURL, "/"), identityID) req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) if err != nil { @@ -127,7 +136,7 @@ func (s *KratosAdminService) GetIdentity(ctx context.Context, identityID string) return &identity, nil } -func (s *KratosAdminService) UpdateIdentity(ctx context.Context, identityID string, traits map[string]interface{}, state string) (*KratosIdentity, error) { +func (s *kratosAdminService) UpdateIdentity(ctx context.Context, identityID string, traits map[string]interface{}, state string) (*KratosIdentity, error) { payload := map[string]interface{}{ "schema_id": "default", "traits": traits, @@ -162,7 +171,7 @@ func (s *KratosAdminService) UpdateIdentity(ctx context.Context, identityID stri return &updated, nil } -func (s *KratosAdminService) UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error { +func (s *kratosAdminService) UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error { patchOps := []map[string]interface{}{ { "op": "add", @@ -190,7 +199,7 @@ func (s *KratosAdminService) UpdateIdentityPassword(ctx context.Context, identit return nil } -func (s *KratosAdminService) DeleteIdentity(ctx context.Context, identityID string) error { +func (s *kratosAdminService) DeleteIdentity(ctx context.Context, identityID string) error { endpoint := fmt.Sprintf("%s/admin/identities/%s", strings.TrimRight(s.AdminURL, "/"), identityID) req, err := http.NewRequestWithContext(ctx, http.MethodDelete, endpoint, nil) if err != nil { @@ -210,7 +219,7 @@ func (s *KratosAdminService) DeleteIdentity(ctx context.Context, identityID stri return nil } -func (s *KratosAdminService) httpClient() *http.Client { +func (s *kratosAdminService) httpClient() *http.Client { if s.HTTPClient != nil { return s.HTTPClient } diff --git a/backend/internal/service/mock_common_test.go b/backend/internal/service/mock_common_test.go index 1060981f..fdf0e6d0 100644 --- a/backend/internal/service/mock_common_test.go +++ b/backend/internal/service/mock_common_test.go @@ -72,3 +72,41 @@ func (m *MockKetoServiceShared) ListObjects(ctx context.Context, namespace, rela } return args.Get(0).([]string), args.Error(1) } + +type MockKratosAdminServiceShared struct { + mock.Mock +} + +func (m *MockKratosAdminServiceShared) ListIdentities(ctx context.Context) ([]KratosIdentity, error) { + args := m.Called(ctx) + return args.Get(0).([]KratosIdentity), args.Error(1) +} + +func (m *MockKratosAdminServiceShared) FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error) { + args := m.Called(ctx, identifier) + return args.String(0), args.Error(1) +} + +func (m *MockKratosAdminServiceShared) GetIdentity(ctx context.Context, identityID string) (*KratosIdentity, error) { + args := m.Called(ctx, identityID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*KratosIdentity), args.Error(1) +} + +func (m *MockKratosAdminServiceShared) UpdateIdentity(ctx context.Context, identityID string, traits map[string]interface{}, state string) (*KratosIdentity, error) { + args := m.Called(ctx, identityID, traits, state) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*KratosIdentity), args.Error(1) +} + +func (m *MockKratosAdminServiceShared) UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error { + return m.Called(ctx, identityID, newPassword).Error(0) +} + +func (m *MockKratosAdminServiceShared) DeleteIdentity(ctx context.Context, identityID string) error { + return m.Called(ctx, identityID).Error(0) +} diff --git a/backend/internal/service/org_chart_service.go b/backend/internal/service/org_chart_service.go index 4c9952cd..df7000fe 100644 --- a/backend/internal/service/org_chart_service.go +++ b/backend/internal/service/org_chart_service.go @@ -22,7 +22,7 @@ type orgChartService struct { userGroupRepo repository.UserGroupRepository userRepo repository.UserRepository ketoOutboxRepo repository.KetoOutboxRepository - kratos *KratosAdminService + kratos KratosAdminService } func NewOrgChartService( @@ -30,7 +30,7 @@ func NewOrgChartService( userGroupRepo repository.UserGroupRepository, userRepo repository.UserRepository, ketoOutbox repository.KetoOutboxRepository, - kratos *KratosAdminService, + kratos KratosAdminService, ) OrgChartService { return &orgChartService{ tenantRepo: tenantRepo, diff --git a/backend/internal/service/user_group_service.go b/backend/internal/service/user_group_service.go index dc2b737d..1553eeed 100644 --- a/backend/internal/service/user_group_service.go +++ b/backend/internal/service/user_group_service.go @@ -33,7 +33,7 @@ type userGroupService struct { tenantRepo repository.TenantRepository ketoService KetoService outboxRepo repository.KetoOutboxRepository - kratos *KratosAdminService + kratos KratosAdminService } func NewUserGroupService( @@ -42,7 +42,7 @@ func NewUserGroupService( tenantRepo repository.TenantRepository, keto KetoService, outbox repository.KetoOutboxRepository, - kratos *KratosAdminService, + kratos KratosAdminService, ) UserGroupService { return &userGroupService{ repo: repo, @@ -59,6 +59,12 @@ func (s *userGroupService) Create(ctx context.Context, tenantID string, parentID if parentID == nil || *parentID == "" { parentID = &tenantID } + + // Validate parent tenant exists + if _, err := s.tenantRepo.FindByID(ctx, *parentID); err != nil { + return nil, fmt.Errorf("parent tenant not found or invalid: %w", err) + } + unitID := uuid.NewString() // 1. Create Tenant (Type: USER_GROUP) @@ -199,6 +205,11 @@ func (s *userGroupService) List(ctx context.Context, tenantID string) ([]domain. } func (s *userGroupService) AddMember(ctx context.Context, groupID, userID string) error { + // Validate group exists + if _, err := s.repo.FindByID(ctx, groupID); err != nil { + return fmt.Errorf("user group not found: %w", err) + } + // Keto via Outbox: Tenant:#members@User: if s.outboxRepo != nil { _ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{ @@ -214,6 +225,11 @@ func (s *userGroupService) AddMember(ctx context.Context, groupID, userID string } func (s *userGroupService) RemoveMember(ctx context.Context, groupID, userID string) error { + // Validate group exists + if _, err := s.repo.FindByID(ctx, groupID); err != nil { + return fmt.Errorf("user group not found: %w", err) + } + // Keto via Outbox: Delete relation if s.outboxRepo != nil { _ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{ @@ -267,6 +283,11 @@ func (s *userGroupService) ListRoles(ctx context.Context, groupID string) ([]dom } func (s *userGroupService) AssignRoleToTenant(ctx context.Context, groupID, tenantID, relation string) error { + // Validate group exists + if _, err := s.repo.FindByID(ctx, groupID); err != nil { + return fmt.Errorf("user group not found: %w", err) + } + // Keto via Outbox: Tenant:#@Tenant:#members if s.outboxRepo != nil { subject := "Tenant:" + groupID + "#members" diff --git a/backend/internal/service/user_group_service_edge_test.go b/backend/internal/service/user_group_service_edge_test.go index ea285ee8..295e6fdb 100644 --- a/backend/internal/service/user_group_service_edge_test.go +++ b/backend/internal/service/user_group_service_edge_test.go @@ -1,9 +1,7 @@ package service import ( - "baron-sso-backend/internal/domain" "context" - "errors" "testing" "github.com/stretchr/testify/assert"