package service

import (
	"context"
	"testing"

	"baron-sso-backend/internal/domain"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
)

// --- Local mocks, named distinctly to avoid collisions with other mocks in this package ---

// MockTenantRepoForSvc is a testify mock of the tenant repository consumed by
// TenantService. FindByName and FindByIDs are unrecorded stubs that always
// return (nil, nil); every other method replays expectations registered with On.
type MockTenantRepoForSvc struct {
	mock.Mock
}

// tenantFromArgsForSvc converts recorded mock arguments into a (*domain.Tenant, error)
// pair. A nil first return value yields a nil tenant instead of panicking on the
// type assertion; any other non-*domain.Tenant value still panics, surfacing a
// mis-configured expectation.
func tenantFromArgsForSvc(args mock.Arguments) (*domain.Tenant, error) {
	if v := args.Get(0); v != nil {
		return v.(*domain.Tenant), args.Error(1)
	}
	return nil, args.Error(1)
}

func (m *MockTenantRepoForSvc) Create(ctx context.Context, tenant *domain.Tenant) error {
	return m.Called(ctx, tenant).Error(0)
}

func (m *MockTenantRepoForSvc) Update(ctx context.Context, tenant *domain.Tenant) error {
	return m.Called(ctx, tenant).Error(0)
}

func (m *MockTenantRepoForSvc) FindByID(ctx context.Context, id string) (*domain.Tenant, error) {
	return tenantFromArgsForSvc(m.Called(ctx, id))
}

func (m *MockTenantRepoForSvc) FindBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
	return tenantFromArgsForSvc(m.Called(ctx, slug))
}

// FindByName is an unrecorded stub: it never consults expectations.
func (m *MockTenantRepoForSvc) FindByName(ctx context.Context, name string) (*domain.Tenant, error) {
	return nil, nil
}

func (m *MockTenantRepoForSvc) FindByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
	return tenantFromArgsForSvc(m.Called(ctx, domainName))
}

// FindByIDs is an unrecorded stub: it never consults expectations.
func (m *MockTenantRepoForSvc) FindByIDs(ctx context.Context, ids []string) ([]domain.Tenant, error) {
	return nil, nil
}

func (m *MockTenantRepoForSvc) AddDomain(ctx context.Context, tenantID string, domainName string, verified bool) error {
	return m.Called(ctx, tenantID, domainName, verified).Error(0)
}

// MockKetoSvcForTenant is a testify mock of the Keto relation/permission
// service that TenantService receives via SetKetoService.
type MockKetoSvcForTenant struct {
	mock.Mock
}

func (m *MockKetoSvcForTenant) CreateRelation(ctx context.Context, namespace, object, relation, subject string) error {
	return m.Called(ctx, namespace, object, relation, subject).Error(0)
}

func (m *MockKetoSvcForTenant) DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error {
	return m.Called(ctx, namespace, object, relation, subject).Error(0)
}

// ListRelations returns the configured tuples. A nil first return value
// (e.g. .Return(nil, err)) yields a nil slice rather than panicking.
func (m *MockKetoSvcForTenant) ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]RelationTuple, error) {
	args := m.Called(ctx, namespace, object, relation, subject)
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).([]RelationTuple), args.Error(1)
}

// ListObjects returns the configured object IDs. A nil first return value
// yields a nil slice rather than panicking.
func (m *MockKetoSvcForTenant) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
	args := m.Called(ctx, namespace, relation, subject)
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).([]string), args.Error(1)
}

func (m *MockKetoSvcForTenant) CheckPermission(ctx context.Context, namespace, object, relation, subject string) (bool, error) {
	args := m.Called(ctx, namespace, object, relation, subject)
	return args.Bool(0), args.Error(1)
}

// MockUserRepoForTenant is a testify mock of the user repository. Only
// FindByEmail is expectation-driven; the remaining methods are no-op stubs.
type MockUserRepoForTenant struct {
	mock.Mock
}

func (m *MockUserRepoForTenant) Create(ctx context.Context, user *domain.User) error {
	return nil
}

func (m *MockUserRepoForTenant) Update(ctx context.Context, user *domain.User) error {
	return nil
}

// FindByEmail records only the email (ctx is not passed to Called), so
// expectations are registered as On("FindByEmail", email).
func (m *MockUserRepoForTenant) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
	args := m.Called(email)
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).(*domain.User), args.Error(1)
}

func (m *MockUserRepoForTenant) FindByID(ctx context.Context, id string) (*domain.User, error) {
	return nil, nil
}

func (m *MockUserRepoForTenant) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
	return nil, nil
}

func (m *MockUserRepoForTenant) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
	return nil, nil
}

func (m *MockUserRepoForTenant) List(ctx context.Context, offset, limit int, search string) ([]domain.User, int64, error) {
	return nil, 0, nil
}

// --- Tests ---

// TestTenantService_RegisterTenant_AutoVerify checks that RegisterTenant
// creates the tenant, attaches each supplied domain with verified=true, and
// returns the tenant obtained from the second FindBySlug lookup.
func TestTenantService_RegisterTenant_AutoVerify(t *testing.T) {
	mockRepo := new(MockTenantRepoForSvc)
	svc := NewTenantService(mockRepo, nil)
	ctx := context.Background()

	name := "New Tenant"
	slug := "new-tenant"
	domains := []string{"example.com"}

	// FindBySlug is expected twice: the first call finds nothing (presumably
	// the slug-availability check), the second returns the stored tenant.
	// .Once() makes testify consume the two expectations in that order.
	mockRepo.On("FindBySlug", ctx, slug).Return(nil, nil).Once()
	mockRepo.On("Create", ctx, mock.Anything).Return(nil)
	mockRepo.On("AddDomain", ctx, mock.Anything, "example.com", true).Return(nil)
	mockRepo.On("FindBySlug", ctx, slug).Return(&domain.Tenant{ID: "t1", Slug: slug}, nil).Once()

	tenant, err := svc.RegisterTenant(ctx, name, slug, "", domains)

	assert.NoError(t, err)
	// assert does not stop the test; guard the dereference so a nil tenant
	// reports a failure instead of panicking.
	if assert.NotNil(t, tenant) {
		assert.Equal(t, "t1", tenant.ID)
	}
	mockRepo.AssertExpectations(t)
}

// TestTenantService_RequestRegistration_NoVerify checks that a registration
// request creates the tenant with pending status and attaches its domain
// with verified=false.
func TestTenantService_RequestRegistration_NoVerify(t *testing.T) {
	mockRepo := new(MockTenantRepoForSvc)
	svc := NewTenantService(mockRepo, nil)
	ctx := context.Background()

	name := "Public Tenant"
	slug := "public-tenant"
	domainName := "public.com"
	adminEmail := "admin@public.com"

	mockRepo.On("Create", ctx, mock.MatchedBy(func(created *domain.Tenant) bool {
		return created.Status == domain.TenantStatusPending
	})).Return(nil)
	mockRepo.On("AddDomain", ctx, mock.Anything, domainName, false).Return(nil)

	tenant, err := svc.RequestRegistration(ctx, name, slug, "", domainName, adminEmail)

	assert.NoError(t, err)
	assert.NotNil(t, tenant)
	mockRepo.AssertExpectations(t)
}

// TestTenantService_ApproveTenant_SyncAdmin checks that approving a tenant
// updates it, resolves the "adminEmail" stored in the tenant config to a user,
// and creates the Keto relation Tenant:<id>#admin@User:<userID>.
func TestTenantService_ApproveTenant_SyncAdmin(t *testing.T) {
	mockRepo := new(MockTenantRepoForSvc)
	mockUserRepo := new(MockUserRepoForTenant)
	mockKeto := new(MockKetoSvcForTenant)
	svc := NewTenantService(mockRepo, mockUserRepo)
	svc.SetKetoService(mockKeto)
	ctx := context.Background()

	tenantID := "t1"
	adminEmail := "admin@tenant.com"
	userID := "user-uuid"

	tenant := &domain.Tenant{
		ID:     tenantID,
		Slug:   "tenant-slug",
		Config: domain.JSONMap{"adminEmail": adminEmail},
	}

	mockRepo.On("FindByID", ctx, tenantID).Return(tenant, nil)
	mockRepo.On("Update", ctx, mock.Anything).Return(nil)
	mockUserRepo.On("FindByEmail", adminEmail).Return(&domain.User{ID: userID, Email: adminEmail}, nil)
	mockKeto.On("CreateRelation", ctx, "Tenant", tenantID, "admin", "User:"+userID).Return(nil)

	err := svc.ApproveTenant(ctx, tenantID)

	assert.NoError(t, err)
	mockRepo.AssertExpectations(t)
	mockUserRepo.AssertExpectations(t)
	mockKeto.AssertExpectations(t)
}