// Forked from baron/baron-sso (267 lines, 8.5 KiB, Go).

package service
|
|
|
|
import (
|
|
"baron-sso-backend/internal/domain"
|
|
"context"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
type tenantServiceTenantRepoMock struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *tenantServiceTenantRepoMock) Create(ctx context.Context, tenant *domain.Tenant) error {
|
|
args := m.Called(ctx, tenant)
|
|
return args.Error(0)
|
|
}
|
|
|
|
func (m *tenantServiceTenantRepoMock) Update(ctx context.Context, tenant *domain.Tenant) error {
|
|
args := m.Called(ctx, tenant)
|
|
return args.Error(0)
|
|
}
|
|
|
|
func (m *tenantServiceTenantRepoMock) FindByID(ctx context.Context, id string) (*domain.Tenant, error) {
|
|
args := m.Called(ctx, id)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*domain.Tenant), args.Error(1)
|
|
}
|
|
|
|
func (m *tenantServiceTenantRepoMock) FindBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
|
|
args := m.Called(ctx, slug)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*domain.Tenant), args.Error(1)
|
|
}
|
|
|
|
func (m *tenantServiceTenantRepoMock) FindByName(ctx context.Context, name string) (*domain.Tenant, error) {
|
|
args := m.Called(ctx, name)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*domain.Tenant), args.Error(1)
|
|
}
|
|
|
|
func (m *tenantServiceTenantRepoMock) FindByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
|
|
args := m.Called(ctx, domainName)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*domain.Tenant), args.Error(1)
|
|
}
|
|
|
|
func (m *tenantServiceTenantRepoMock) FindByIDs(ctx context.Context, ids []string) ([]domain.Tenant, error) {
|
|
args := m.Called(ctx, ids)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).([]domain.Tenant), args.Error(1)
|
|
}
|
|
|
|
func (m *tenantServiceTenantRepoMock) AddDomain(ctx context.Context, tenantID string, domainName string, verified bool) error {
|
|
args := m.Called(ctx, tenantID, domainName, verified)
|
|
return args.Error(0)
|
|
}
|
|
|
|
type tenantServiceUserRepoMock struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *tenantServiceUserRepoMock) Create(ctx context.Context, user *domain.User) error {
|
|
args := m.Called(ctx, user)
|
|
return args.Error(0)
|
|
}
|
|
|
|
func (m *tenantServiceUserRepoMock) Update(ctx context.Context, user *domain.User) error {
|
|
args := m.Called(ctx, user)
|
|
return args.Error(0)
|
|
}
|
|
|
|
func (m *tenantServiceUserRepoMock) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
|
|
args := m.Called(ctx, email)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*domain.User), args.Error(1)
|
|
}
|
|
|
|
func (m *tenantServiceUserRepoMock) FindByID(ctx context.Context, id string) (*domain.User, error) {
|
|
args := m.Called(ctx, id)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*domain.User), args.Error(1)
|
|
}
|
|
|
|
func (m *tenantServiceUserRepoMock) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
|
|
args := m.Called(ctx, ids)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).([]domain.User), args.Error(1)
|
|
}
|
|
|
|
func (m *tenantServiceUserRepoMock) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
|
|
args := m.Called(ctx, tenantID)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).([]domain.User), args.Error(1)
|
|
}
|
|
|
|
func (m *tenantServiceUserRepoMock) List(ctx context.Context, offset, limit int, search string) ([]domain.User, int64, error) {
|
|
args := m.Called(ctx, offset, limit, search)
|
|
if args.Get(0) == nil {
|
|
return nil, 0, args.Error(2)
|
|
}
|
|
return args.Get(0).([]domain.User), args.Get(1).(int64), args.Error(2)
|
|
}
|
|
|
|
func TestTenantService_RegisterTenant_AddsDomainsAsVerified(t *testing.T) {
|
|
repo := new(tenantServiceTenantRepoMock)
|
|
userRepo := new(tenantServiceUserRepoMock)
|
|
svc := NewTenantService(repo, userRepo)
|
|
|
|
repo.On("FindBySlug", mock.Anything, "tenant-a").Return(nil, gorm.ErrRecordNotFound).Once()
|
|
repo.On("Create", mock.Anything, mock.MatchedBy(func(tenant *domain.Tenant) bool {
|
|
return tenant.Name == "Tenant A" &&
|
|
tenant.Slug == "tenant-a" &&
|
|
tenant.Status == domain.TenantStatusActive
|
|
})).Run(func(args mock.Arguments) {
|
|
args.Get(1).(*domain.Tenant).ID = "tenant-1"
|
|
}).Return(nil).Once()
|
|
repo.On("AddDomain", mock.Anything, "tenant-1", "a.example.com", true).Return(nil).Once()
|
|
repo.On("AddDomain", mock.Anything, "tenant-1", "a.example.org", true).Return(nil).Once()
|
|
repo.On("FindBySlug", mock.Anything, "tenant-a").Return(&domain.Tenant{
|
|
ID: "tenant-1",
|
|
Name: "Tenant A",
|
|
Slug: "tenant-a",
|
|
Status: domain.TenantStatusActive,
|
|
}, nil).Once()
|
|
|
|
tenant, err := svc.RegisterTenant(context.Background(), "Tenant A", "tenant-a", "desc", []string{"a.example.com", "a.example.org"})
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, tenant)
|
|
assert.Equal(t, "tenant-1", tenant.ID)
|
|
|
|
repo.AssertExpectations(t)
|
|
}
|
|
|
|
func TestTenantService_RequestRegistration_AddsDomainAsUnverified(t *testing.T) {
|
|
repo := new(tenantServiceTenantRepoMock)
|
|
userRepo := new(tenantServiceUserRepoMock)
|
|
svc := NewTenantService(repo, userRepo)
|
|
|
|
repo.On("Create", mock.Anything, mock.MatchedBy(func(tenant *domain.Tenant) bool {
|
|
return tenant.Name == "Tenant B" &&
|
|
tenant.Slug == "tenant-b" &&
|
|
tenant.Status == domain.TenantStatusPending &&
|
|
tenant.Config["adminEmail"] == "admin@tenant-b.com"
|
|
})).Run(func(args mock.Arguments) {
|
|
args.Get(1).(*domain.Tenant).ID = "tenant-2"
|
|
}).Return(nil).Once()
|
|
repo.On("AddDomain", mock.Anything, "tenant-2", "tenant-b.com", false).Return(nil).Once()
|
|
|
|
tenant, err := svc.RequestRegistration(
|
|
context.Background(),
|
|
"Tenant B",
|
|
"tenant-b",
|
|
"desc",
|
|
"tenant-b.com",
|
|
"admin@tenant-b.com",
|
|
)
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, tenant)
|
|
assert.Equal(t, "tenant-2", tenant.ID)
|
|
assert.Equal(t, domain.TenantStatusPending, tenant.Status)
|
|
|
|
repo.AssertExpectations(t)
|
|
}
|
|
|
|
func TestTenantService_RequestRegistration_RejectsDomainMismatch(t *testing.T) {
|
|
repo := new(tenantServiceTenantRepoMock)
|
|
userRepo := new(tenantServiceUserRepoMock)
|
|
svc := NewTenantService(repo, userRepo)
|
|
|
|
tenant, err := svc.RequestRegistration(
|
|
context.Background(),
|
|
"Tenant B",
|
|
"tenant-b",
|
|
"desc",
|
|
"tenant-b.com",
|
|
"admin@other.com",
|
|
)
|
|
assert.Error(t, err)
|
|
assert.ErrorContains(t, err, "admin email domain must match the tenant domain")
|
|
assert.Nil(t, tenant)
|
|
|
|
repo.AssertNotCalled(t, "Create", mock.Anything, mock.Anything)
|
|
repo.AssertNotCalled(t, "AddDomain", mock.Anything, mock.Anything, mock.Anything, mock.Anything)
|
|
}
|
|
|
|
func TestTenantService_ApproveTenant_AssignsAdminRelationWhenUserExists(t *testing.T) {
|
|
repo := new(tenantServiceTenantRepoMock)
|
|
userRepo := new(tenantServiceUserRepoMock)
|
|
keto := new(MockKetoService)
|
|
svc := NewTenantService(repo, userRepo)
|
|
svc.SetKetoService(keto)
|
|
|
|
tenant := &domain.Tenant{
|
|
ID: "tenant-3",
|
|
Slug: "tenant-c",
|
|
Status: domain.TenantStatusPending,
|
|
Config: domain.JSONMap{"adminEmail": "admin@tenant-c.com"},
|
|
}
|
|
|
|
repo.On("FindByID", mock.Anything, "tenant-3").Return(tenant, nil).Once()
|
|
repo.On("Update", mock.Anything, mock.MatchedBy(func(updated *domain.Tenant) bool {
|
|
return updated.ID == "tenant-3" && updated.Status == domain.TenantStatusActive
|
|
})).Return(nil).Once()
|
|
userRepo.On("FindByEmail", mock.Anything, "admin@tenant-c.com").Return(&domain.User{
|
|
ID: "user-1",
|
|
Email: "admin@tenant-c.com",
|
|
}, nil).Once()
|
|
keto.On("CreateRelation", mock.Anything, "Tenant", "tenant-3", "admin", "User:user-1").Return(nil).Once()
|
|
|
|
err := svc.ApproveTenant(context.Background(), "tenant-3")
|
|
assert.NoError(t, err)
|
|
|
|
repo.AssertExpectations(t)
|
|
userRepo.AssertExpectations(t)
|
|
keto.AssertExpectations(t)
|
|
}
|
|
|
|
func TestTenantService_ApproveTenant_DoesNotAssignWhenUserMissing(t *testing.T) {
|
|
repo := new(tenantServiceTenantRepoMock)
|
|
userRepo := new(tenantServiceUserRepoMock)
|
|
keto := new(MockKetoService)
|
|
svc := NewTenantService(repo, userRepo)
|
|
svc.SetKetoService(keto)
|
|
|
|
tenant := &domain.Tenant{
|
|
ID: "tenant-4",
|
|
Slug: "tenant-d",
|
|
Status: domain.TenantStatusPending,
|
|
Config: domain.JSONMap{"adminEmail": "admin@tenant-d.com"},
|
|
}
|
|
|
|
repo.On("FindByID", mock.Anything, "tenant-4").Return(tenant, nil).Once()
|
|
repo.On("Update", mock.Anything, mock.MatchedBy(func(updated *domain.Tenant) bool {
|
|
return updated.ID == "tenant-4" && updated.Status == domain.TenantStatusActive
|
|
})).Return(nil).Once()
|
|
userRepo.On("FindByEmail", mock.Anything, "admin@tenant-d.com").Return(nil, gorm.ErrRecordNotFound).Once()
|
|
|
|
err := svc.ApproveTenant(context.Background(), "tenant-4")
|
|
assert.NoError(t, err)
|
|
|
|
keto.AssertNotCalled(t, "CreateRelation", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything)
|
|
repo.AssertExpectations(t)
|
|
userRepo.AssertExpectations(t)
|
|
}
|