1
0
forked from baron/baron-sso
Files
baron-sso/backend/internal/handler/tenant_handler_test.go
chan a1d508ed69 test(backend): fix tenant_handler_test by adjusting mock call for new ListTenants logic
Adjusts parameter matching on mockSvc.ListTenants to use mock.Anything and mockUserRepo methods, ensuring the test safely covers the newly added allTenants internal sub-query flow without fragile strict args.
2026-04-13 13:23:09 +09:00

263 lines
8.2 KiB
Go

package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"gorm.io/gorm"
)
// MockTenantService is a mock for service.TenantService.
//
// It embeds testify's mock.Mock; its methods forward their arguments to
// m.Called and replay whatever return values the test configured with On(...).
type MockTenantService struct {
mock.Mock
}
// RegisterTenant records the invocation on the mock and replays the
// configured (*domain.Tenant, error) pair. A nil first return value is
// handed back as a nil tenant without a type assertion.
func (m *MockTenantService) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
	ret := m.Called(ctx, name, slug, tenantType, description, domains, parentID, creatorID)
	if tenant := ret.Get(0); tenant != nil {
		return tenant.(*domain.Tenant), ret.Error(1)
	}
	return nil, ret.Error(1)
}
// GetTenantByDomain records the call and returns the tenant and error
// configured for it; a nil first return value yields a nil tenant.
func (m *MockTenantService) GetTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
	ret := m.Called(ctx, domainName)
	if tenant := ret.Get(0); tenant != nil {
		return tenant.(*domain.Tenant), ret.Error(1)
	}
	return nil, ret.Error(1)
}
// ApproveTenant records the call and returns the error configured as the
// first (and only) return value.
func (m *MockTenantService) ApproveTenant(ctx context.Context, tenantID string) error {
	return m.Called(ctx, tenantID).Error(0)
}
// RequestRegistration records the call and replays the configured
// (*domain.Tenant, error) pair; a nil first return value yields a nil tenant.
func (m *MockTenantService) RequestRegistration(ctx context.Context, name, slug, description, domainName, adminEmail string) (*domain.Tenant, error) {
	ret := m.Called(ctx, name, slug, description, domainName, adminEmail)
	if tenant := ret.Get(0); tenant != nil {
		return tenant.(*domain.Tenant), ret.Error(1)
	}
	return nil, ret.Error(1)
}
// GetTenantBySlug records the call and returns the tenant and error
// configured for it; a nil first return value yields a nil tenant.
func (m *MockTenantService) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
	ret := m.Called(ctx, slug)
	if tenant := ret.Get(0); tenant != nil {
		return tenant.(*domain.Tenant), ret.Error(1)
	}
	return nil, ret.Error(1)
}
// GetTenant records the call and returns the tenant and error configured
// for it; a nil first return value yields a nil tenant.
func (m *MockTenantService) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
	ret := m.Called(ctx, id)
	if tenant := ret.Get(0); tenant != nil {
		return tenant.(*domain.Tenant), ret.Error(1)
	}
	return nil, ret.Error(1)
}
// ListTenants records the call and replays the configured
// ([]domain.Tenant, int64, error) triple.
//
// A nil first return value is passed through as a nil slice instead of being
// type-asserted, so error-path expectations such as
// Return(nil, int64(0), err) do not panic. This matches the nil handling of
// the other slice- and pointer-returning methods on this mock. The total is
// still asserted strictly as int64.
func (m *MockTenantService) ListTenants(ctx context.Context, limit, offset int, parentID string) ([]domain.Tenant, int64, error) {
	args := m.Called(ctx, limit, offset, parentID)

	var tenants []domain.Tenant
	if v := args.Get(0); v != nil {
		tenants = v.([]domain.Tenant)
	}
	return tenants, args.Get(1).(int64), args.Error(2)
}
// ListManageableTenants records the call and returns the tenant slice and
// error configured for it; a nil first return value yields a nil slice.
func (m *MockTenantService) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
	ret := m.Called(ctx, userID)
	if list := ret.Get(0); list != nil {
		return list.([]domain.Tenant), ret.Error(1)
	}
	return nil, ret.Error(1)
}
// IsDomainAllowed records the call and returns the configured bool and error.
func (m *MockTenantService) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
	ret := m.Called(ctx, domainName)
	allowed, err := ret.Bool(0), ret.Error(1)
	return allowed, err
}
// SetKetoService only records that it was called with keto; the mock does
// not store the service and nothing is returned.
func (m *MockTenantService) SetKetoService(keto service.KetoService) {
	_ = m.Called(keto)
}
// ProvisionTenantByDomain records the call and replays the configured
// (*domain.Tenant, error) pair; a nil first return value yields a nil tenant.
func (m *MockTenantService) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
	ret := m.Called(ctx, domainName)
	if tenant := ret.Get(0); tenant != nil {
		return tenant.(*domain.Tenant), ret.Error(1)
	}
	return nil, ret.Error(1)
}
// MockUserRepoForHandler is the user-repository test double assigned to
// TenantHandler.UserRepo in the handler tests.
//
// Of the methods defined in this file, only CountByCompanyCodes is routed
// through the embedded mock.Mock; the others are fixed stubs that return
// zero values and never consult expectations registered with On(...).
type MockUserRepoForHandler struct {
mock.Mock
}
// Create is a no-op stub: it ignores its arguments and always returns nil.
func (m *MockUserRepoForHandler) Create(ctx context.Context, user *domain.User) error {
	return nil
}
// Update is a no-op stub: it ignores its arguments and always returns nil.
func (m *MockUserRepoForHandler) Update(ctx context.Context, user *domain.User) error {
	return nil
}
// Delete is a no-op stub: it ignores its arguments and always returns nil.
func (m *MockUserRepoForHandler) Delete(ctx context.Context, id string) error {
	return nil
}
// FindByEmail is a stub: it ignores its arguments and returns (nil, nil).
func (m *MockUserRepoForHandler) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
	var none *domain.User
	return none, nil
}
// FindByID is a stub: it ignores its arguments and returns (nil, nil).
func (m *MockUserRepoForHandler) FindByID(ctx context.Context, id string) (*domain.User, error) {
	var none *domain.User
	return none, nil
}
// FindByIDs is a stub: it ignores its arguments and returns a nil slice and
// a nil error.
func (m *MockUserRepoForHandler) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
	var users []domain.User
	return users, nil
}
// ListByTenant is a stub: it ignores its arguments and returns a nil slice
// and a nil error.
func (m *MockUserRepoForHandler) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
	var users []domain.User
	return users, nil
}
// List is a stub: it ignores its arguments and returns a nil slice, a total
// of 0 and a nil error.
func (m *MockUserRepoForHandler) List(ctx context.Context, offset, limit int, search string, companyCode string) ([]domain.User, int64, error) {
	var (
		users []domain.User
		total int64
	)
	return users, total, nil
}
// CountByTenant is a stub: it ignores its arguments and returns a count of 0
// and a nil error.
func (m *MockUserRepoForHandler) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
	var count int64
	return count, nil
}
// CountByTenantIDs is a stub: it ignores its arguments and returns a nil map
// and a nil error.
//
// It does not call m.Called, so expectations registered with
// On("CountByTenantIDs", ...) are never consulted.
func (m *MockUserRepoForHandler) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
	var counts map[string]int64
	return counts, nil
}
// CountByCompanyCodes records the call on the mock and replays the
// configured (map[string]int64, error) pair; a nil first return value yields
// a nil map.
func (m *MockUserRepoForHandler) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
	ret := m.Called(ctx, codes)
	if counts := ret.Get(0); counts != nil {
		return counts.(map[string]int64), ret.Error(1)
	}
	return nil, ret.Error(1)
}
// UpdateUserLoginIDs is a no-op stub: it ignores its arguments and always
// returns nil.
func (m *MockUserRepoForHandler) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
	var err error
	return err
}
// GetUserLoginIDs is a stub: it ignores its arguments and returns a nil
// slice and a nil error.
func (m *MockUserRepoForHandler) GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error) {
	var loginIDs []domain.UserLoginID
	return loginIDs, nil
}
// IsLoginIDTaken is a stub: it ignores its arguments and always reports
// false with a nil error.
func (m *MockUserRepoForHandler) IsLoginIDTaken(ctx context.Context, loginID string) (bool, error) {
	var taken bool
	return taken, nil
}
// FindTenantIDByLoginID is a stub: it ignores its arguments and returns an
// empty tenant ID with a nil error.
func (m *MockUserRepoForHandler) FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error) {
	var tenantID string
	return tenantID, nil
}
// TestTenantHandler_CreateTenant posts a JSON tenant payload to /tenants and
// checks that the handler calls RegisterTenant with the expected arguments,
// responds with 201 Created, and echoes the created tenant's ID.
func TestTenantHandler_CreateTenant(t *testing.T) {
	app := fiber.New()
	mockSvc := new(MockTenantService)
	// CreateTenant checks h.DB != nil
	h := &TenantHandler{Service: mockSvc, DB: &gorm.DB{}}
	app.Post("/tenants", h.CreateTenant)

	input := map[string]interface{}{
		"name":    "Test Tenant",
		"slug":    "test-tenant",
		"domains": []string{"test.com"},
	}
	body, err := json.Marshal(input)
	if err != nil {
		t.Fatalf("marshal request body: %v", err)
	}

	mockSvc.On("RegisterTenant", mock.Anything, "Test Tenant", "test-tenant", domain.TenantTypeCompany, "", []string{"test.com"}, (*string)(nil), "").
		Return(&domain.Tenant{ID: "t1", Name: "Test Tenant", Slug: "test-tenant"}, nil)

	req := httptest.NewRequest(http.MethodPost, "/tenants", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")

	// A failed app.Test returns a nil response; stop before dereferencing it.
	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("app.Test: %v", err)
	}
	defer resp.Body.Close()

	assert.Equal(t, http.StatusCreated, resp.StatusCode)

	var got map[string]interface{}
	if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
		t.Fatalf("decode response body: %v", err)
	}
	assert.Equal(t, "t1", got["id"])

	// Fail if the handler never reached RegisterTenant.
	mockSvc.AssertExpectations(t)
}
// TestTenantHandler_ListTenants issues GET /tenants with a super_admin
// profile stored in the request locals and checks that the handler responds
// with 200, both mocked tenants, and the per-slug member counts supplied by
// CountByCompanyCodes.
func TestTenantHandler_ListTenants(t *testing.T) {
	app := fiber.New()
	mockSvc := new(MockTenantService)
	mockUserRepo := new(MockUserRepoForHandler)
	h := &TenantHandler{
		Service:  mockSvc,
		UserRepo: mockUserRepo,
	}

	// Store the caller's profile in the locals before the handler runs.
	app.Use(func(c *fiber.Ctx) error {
		c.Locals("user_profile", &domain.UserProfileResponse{
			Role: "super_admin",
		})
		return c.Next()
	})
	app.Get("/tenants", h.ListTenants)

	tenants := []domain.Tenant{
		{ID: "t1", Name: "Tenant A", Slug: "slug-a"},
		{ID: "t2", Name: "Tenant B", Slug: "slug-b"},
	}

	// Mocking for the new allTenants check in ListTenants. Arguments are
	// matched loosely and every expectation is optional (Maybe), so the test
	// does not depend on how many times the handler calls each method.
	mockSvc.On("ListTenants", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tenants, int64(2), nil).Maybe()
	mockUserRepo.On("CountByCompanyCodes", mock.Anything, mock.Anything).
		Return(map[string]int64{"slug-a": 5, "slug-b": 10}, nil).Maybe()
	mockUserRepo.On("CountByTenantIDs", mock.Anything, mock.Anything).
		Return(map[string]int64{}, nil).Maybe()

	req := httptest.NewRequest(http.MethodGet, "/tenants?limit=10&offset=0", nil)

	// A failed app.Test returns a nil response; stop before dereferencing it.
	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("app.Test: %v", err)
	}
	defer resp.Body.Close()

	assert.Equal(t, http.StatusOK, resp.StatusCode)

	var res tenantListResponse
	if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
		t.Fatalf("decode response body: %v", err)
	}
	assert.Equal(t, int64(2), res.Total)
	assert.Len(t, res.Items, 2)

	// Check if counts are mapped correctly. An unknown slug is reported so
	// the loop cannot pass without checking anything.
	for _, item := range res.Items {
		switch item.Slug {
		case "slug-a":
			assert.Equal(t, int64(5), item.MemberCount)
		case "slug-b":
			assert.Equal(t, int64(10), item.MemberCount)
		default:
			t.Errorf("unexpected tenant slug %q in response", item.Slug)
		}
	}
}
// TestTenantHandler_ApproveTenant posts to /tenants/t1/approve and checks
// that the handler calls ApproveTenant with the path ID and responds 200.
func TestTenantHandler_ApproveTenant(t *testing.T) {
	app := fiber.New()
	mockSvc := new(MockTenantService)
	h := &TenantHandler{Service: mockSvc}
	app.Post("/tenants/:id/approve", h.ApproveTenant)

	mockSvc.On("ApproveTenant", mock.Anything, "t1").Return(nil)

	req := httptest.NewRequest(http.MethodPost, "/tenants/t1/approve", nil)

	// A failed app.Test returns a nil response; stop before dereferencing it.
	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("app.Test: %v", err)
	}
	defer resp.Body.Close()

	assert.Equal(t, http.StatusOK, resp.StatusCode)

	// Fail if the handler never called ApproveTenant with "t1".
	mockSvc.AssertExpectations(t)
}
// DeleteTenantsBulk records the call and returns the error configured as the
// first (and only) return value.
func (m *MockTenantService) DeleteTenantsBulk(ctx context.Context, tenantIDs []string) error {
	return m.Called(ctx, tenantIDs).Error(0)
}
// ListJoinedTenants records the call and returns the tenant slice and error
// configured for it; a nil first return value yields a nil slice.
func (m *MockTenantService) ListJoinedTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
	ret := m.Called(ctx, userID)
	joined := ret.Get(0)
	if joined == nil {
		return nil, ret.Error(1)
	}
	return joined.([]domain.Tenant), ret.Error(1)
}