forked from baron/baron-sso
109 lines
3.5 KiB
Go
109 lines
3.5 KiB
Go
package service
|
|
|
|
import (
|
|
"baron-sso-backend/internal/domain"
|
|
"context"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
)
|
|
|
|
// MockTenantRepository is a mock implementation of repository.TenantRepository
|
|
type MockTenantRepository struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockTenantRepository) Create(ctx context.Context, tenant *domain.Tenant) error {
|
|
return m.Called(ctx, tenant).Error(0)
|
|
}
|
|
|
|
func (m *MockTenantRepository) Update(ctx context.Context, tenant *domain.Tenant) error {
|
|
return m.Called(ctx, tenant).Error(0)
|
|
}
|
|
|
|
func (m *MockTenantRepository) FindByID(ctx context.Context, id string) (*domain.Tenant, error) {
|
|
args := m.Called(ctx, id)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*domain.Tenant), args.Error(1)
|
|
}
|
|
|
|
func (m *MockTenantRepository) FindBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
|
|
args := m.Called(ctx, slug)
|
|
return args.Get(0).(*domain.Tenant), args.Error(1)
|
|
}
|
|
|
|
func (m *MockTenantRepository) FindByName(ctx context.Context, name string) (*domain.Tenant, error) {
|
|
args := m.Called(ctx, name)
|
|
return args.Get(0).(*domain.Tenant), args.Error(1)
|
|
}
|
|
|
|
func (m *MockTenantRepository) FindByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
|
|
args := m.Called(ctx, domainName)
|
|
return args.Get(0).(*domain.Tenant), args.Error(1)
|
|
}
|
|
|
|
func (m *MockTenantRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.Tenant, error) {
|
|
args := m.Called(ctx, ids)
|
|
return args.Get(0).([]domain.Tenant), args.Error(1)
|
|
}
|
|
|
|
func (m *MockTenantRepository) AddDomain(ctx context.Context, tenantID string, domainName string) error {
|
|
return m.Called(ctx, tenantID, domainName).Error(0)
|
|
}
|
|
|
|
func TestTenantService_ListManageableTenants_Inheritance(t *testing.T) {
|
|
mockRepo := new(MockTenantRepository)
|
|
mockKeto := new(MockKetoService)
|
|
svc := &tenantService{
|
|
repo: mockRepo,
|
|
keto: mockKeto,
|
|
}
|
|
|
|
userID := "user-123"
|
|
ctx := context.Background()
|
|
|
|
// 1. Mock direct tenant management (admins relation)
|
|
mockKeto.On("ListObjects", ctx, "Tenant", "admins", userID).Return([]string{"t-direct-1"}, nil)
|
|
|
|
// 2. Mock group management (admins of a group)
|
|
mockKeto.On("ListObjects", ctx, "TenantGroup", "admins", userID).Return([]string{"g-1"}, nil)
|
|
|
|
// 3. Mock tenants belonging to group g-1
|
|
mockKeto.On("ListRelations", ctx, "Tenant", "", "parent_group", "TenantGroup:g-1").Return([]RelationTuple{
|
|
{Object: "t-inherited-1", Relation: "parent_group", SubjectID: "TenantGroup:g-1"},
|
|
{Object: "t-inherited-2", Relation: "parent_group", SubjectID: "TenantGroup:g-1"},
|
|
}, nil)
|
|
|
|
// 4. Expect repository to fetch all unique IDs: t-direct-1, t-inherited-1, t-inherited-2
|
|
expectedIDs := []string{"t-direct-1", "t-inherited-1", "t-inherited-2"}
|
|
mockRepo.On("FindByIDs", ctx, mock.MatchedBy(func(ids []string) bool {
|
|
// Check if all expected IDs are present (order doesn't matter since we dedup via map)
|
|
foundCount := 0
|
|
for _, eid := range expectedIDs {
|
|
for _, id := range ids {
|
|
if id == eid {
|
|
foundCount++
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return foundCount == len(expectedIDs) && len(ids) == len(expectedIDs)
|
|
})).Return([]domain.Tenant{
|
|
{ID: "t-direct-1", Name: "Direct Tenant"},
|
|
{ID: "t-inherited-1", Name: "Inherited Tenant 1"},
|
|
{ID: "t-inherited-2", Name: "Inherited Tenant 2"},
|
|
}, nil)
|
|
|
|
// Execute
|
|
tenants, err := svc.ListManageableTenants(ctx, userID)
|
|
|
|
// Verify
|
|
assert.NoError(t, err)
|
|
assert.Len(t, tenants, 3)
|
|
mockKeto.AssertExpectations(t)
|
|
mockRepo.AssertExpectations(t)
|
|
}
|