package service

import (
	"context"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"

	"baron-sso-backend/internal/domain"
)

// MockTenantRepository is a mock implementation of repository.TenantRepository.
//
// Lookup methods accept a nil first return value (e.g. `.Return(nil, err)`)
// and translate it into a nil result. Without that, the type assertion on an
// untyped nil interface would panic and make error paths untestable.
type MockTenantRepository struct {
	mock.Mock
}

// Create records the call and returns the error configured for it.
func (m *MockTenantRepository) Create(ctx context.Context, tenant *domain.Tenant) error {
	return m.Called(ctx, tenant).Error(0)
}

// Update records the call and returns the error configured for it.
func (m *MockTenantRepository) Update(ctx context.Context, tenant *domain.Tenant) error {
	return m.Called(ctx, tenant).Error(0)
}

// FindByID returns the configured tenant (nil if configured as nil) and error.
func (m *MockTenantRepository) FindByID(ctx context.Context, id string) (*domain.Tenant, error) {
	args := m.Called(ctx, id)
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).(*domain.Tenant), args.Error(1)
}

// FindBySlug returns the configured tenant (nil if configured as nil) and error.
func (m *MockTenantRepository) FindBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
	args := m.Called(ctx, slug)
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).(*domain.Tenant), args.Error(1)
}

// FindByName returns the configured tenant (nil if configured as nil) and error.
func (m *MockTenantRepository) FindByName(ctx context.Context, name string) (*domain.Tenant, error) {
	args := m.Called(ctx, name)
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).(*domain.Tenant), args.Error(1)
}

// FindByDomain returns the configured tenant (nil if configured as nil) and error.
func (m *MockTenantRepository) FindByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
	args := m.Called(ctx, domainName)
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).(*domain.Tenant), args.Error(1)
}

// FindByIDs returns the configured tenants (nil if configured as nil) and error.
func (m *MockTenantRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.Tenant, error) {
	args := m.Called(ctx, ids)
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).([]domain.Tenant), args.Error(1)
}

// AddDomain records the call and returns the error configured for it.
func (m *MockTenantRepository) AddDomain(ctx context.Context, tenantID string, domainName string) error {
	return m.Called(ctx, tenantID, domainName).Error(0)
}

// MockKetoService is a mock implementation of KetoService.
type MockKetoService struct {
	mock.Mock
}

// CheckPermission returns the configured permission decision and error.
func (m *MockKetoService) CheckPermission(ctx context.Context, subject, namespace, object, relation string) (bool, error) {
	args := m.Called(ctx, subject, namespace, object, relation)
	return args.Bool(0), args.Error(1)
}
func (m *MockKetoService) CreateRelation(ctx context.Context, namespace, object, relation, subject string) error { return m.Called(ctx, namespace, object, relation, subject).Error(0) } func (m *MockKetoService) DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error { return m.Called(ctx, namespace, object, relation, subject).Error(0) } func (m *MockKetoService) ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]RelationTuple, error) { args := m.Called(ctx, namespace, object, relation, subject) return args.Get(0).([]RelationTuple), args.Error(1) } func (m *MockKetoService) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) { args := m.Called(ctx, namespace, relation, subject) return args.Get(0).([]string), args.Error(1) } func TestTenantService_ListManageableTenants_Inheritance(t *testing.T) { mockRepo := new(MockTenantRepository) mockKeto := new(MockKetoService) svc := &tenantService{ repo: mockRepo, keto: mockKeto, } userID := "user-123" ctx := context.Background() // 1. Mock direct tenant management (admins relation) mockKeto.On("ListObjects", ctx, "Tenant", "admins", userID).Return([]string{"t-direct-1"}, nil) // 2. Mock group management (admins of a group) mockKeto.On("ListObjects", ctx, "TenantGroup", "admins", userID).Return([]string{"g-1"}, nil) // 3. Mock tenants belonging to group g-1 mockKeto.On("ListRelations", ctx, "Tenant", "", "parent_group", "TenantGroup:g-1").Return([]RelationTuple{ {Object: "t-inherited-1", Relation: "parent_group", SubjectID: "TenantGroup:g-1"}, {Object: "t-inherited-2", Relation: "parent_group", SubjectID: "TenantGroup:g-1"}, }, nil) // 4. 
Expect repository to fetch all unique IDs: t-direct-1, t-inherited-1, t-inherited-2 expectedIDs := []string{"t-direct-1", "t-inherited-1", "t-inherited-2"} mockRepo.On("FindByIDs", ctx, mock.MatchedBy(func(ids []string) bool { // Check if all expected IDs are present (order doesn't matter since we dedup via map) foundCount := 0 for _, eid := range expectedIDs { for _, id := range ids { if id == eid { foundCount++ break } } } return foundCount == len(expectedIDs) && len(ids) == len(expectedIDs) })).Return([]domain.Tenant{ {ID: "t-direct-1", Name: "Direct Tenant"}, {ID: "t-inherited-1", Name: "Inherited Tenant 1"}, {ID: "t-inherited-2", Name: "Inherited Tenant 2"}, }, nil) // Execute tenants, err := svc.ListManageableTenants(ctx, userID) // Verify assert.NoError(t, err) assert.Len(t, tenants, 3) mockKeto.AssertExpectations(t) mockRepo.AssertExpectations(t) }