forked from baron/baron-sso
527 lines
16 KiB
Go
527 lines
16 KiB
Go
package handler
|
|
|
|
import (
|
|
"baron-sso-backend/internal/domain"
|
|
"baron-sso-backend/internal/service"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// MockTenantService is a mock for service.TenantService
|
|
type MockTenantService struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockTenantService) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
|
|
args := m.Called(ctx, name, slug, tenantType, description, domains, parentID, creatorID)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*domain.Tenant), args.Error(1)
|
|
}
|
|
|
|
func (m *MockTenantService) GetTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
|
|
args := m.Called(ctx, domainName)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*domain.Tenant), args.Error(1)
|
|
}
|
|
|
|
// ApproveTenant records the approval request and returns the configured error.
func (m *MockTenantService) ApproveTenant(ctx context.Context, tenantID string) error {
	return m.Called(ctx, tenantID).Error(0)
}
|
|
|
|
func (m *MockTenantService) RequestRegistration(ctx context.Context, name, slug, description, domainName, adminEmail string) (*domain.Tenant, error) {
|
|
args := m.Called(ctx, name, slug, description, domainName, adminEmail)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*domain.Tenant), args.Error(1)
|
|
}
|
|
|
|
func (m *MockTenantService) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
|
|
args := m.Called(ctx, slug)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*domain.Tenant), args.Error(1)
|
|
}
|
|
|
|
func (m *MockTenantService) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
|
|
args := m.Called(ctx, id)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*domain.Tenant), args.Error(1)
|
|
}
|
|
|
|
func (m *MockTenantService) ListTenants(ctx context.Context, limit, offset int, parentID string) ([]domain.Tenant, int64, error) {
|
|
args := m.Called(ctx, limit, offset, parentID)
|
|
return args.Get(0).([]domain.Tenant), args.Get(1).(int64), args.Error(2)
|
|
}
|
|
|
|
func (m *MockTenantService) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
|
|
args := m.Called(ctx, userID)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).([]domain.Tenant), args.Error(1)
|
|
}
|
|
|
|
// IsDomainAllowed records the check and returns the configured verdict and
// error.
func (m *MockTenantService) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
	args := m.Called(ctx, domainName)
	allowed, err := args.Bool(0), args.Error(1)
	return allowed, err
}
|
|
|
|
func (m *MockTenantService) SetKetoService(keto service.KetoService) {
|
|
m.Called(keto)
|
|
}
|
|
|
|
func (m *MockTenantService) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
|
|
args := m.Called(ctx, domainName)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*domain.Tenant), args.Error(1)
|
|
}
|
|
|
|
type MockUserRepoForHandler struct {
|
|
mock.Mock
|
|
}
|
|
|
|
// Create is a no-op stub that always succeeds.
func (m *MockUserRepoForHandler) Create(ctx context.Context, user *domain.User) error {
	return nil
}
|
|
// Update is a no-op stub that always succeeds.
func (m *MockUserRepoForHandler) Update(ctx context.Context, user *domain.User) error {
	return nil
}
|
|
// Delete is a no-op stub that always succeeds.
func (m *MockUserRepoForHandler) Delete(ctx context.Context, id string) error {
	return nil
}
|
|
func (m *MockUserRepoForHandler) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *MockUserRepoForHandler) FindByID(ctx context.Context, id string) (*domain.User, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *MockUserRepoForHandler) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *MockUserRepoForHandler) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *MockUserRepoForHandler) List(ctx context.Context, offset, limit int, search string, companyCode string) ([]domain.User, int64, error) {
|
|
for _, call := range m.ExpectedCalls {
|
|
if call.Method == "List" {
|
|
args := m.Called(ctx, offset, limit, search, companyCode)
|
|
return args.Get(0).([]domain.User), args.Get(1).(int64), args.Error(2)
|
|
}
|
|
}
|
|
return nil, 0, nil
|
|
}
|
|
|
|
func (m *MockUserRepoForHandler) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
|
|
return 0, nil
|
|
}
|
|
|
|
func (m *MockUserRepoForHandler) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
|
|
args := m.Called(ctx, tenantIDs)
|
|
return args.Get(0).([]domain.User), args.Error(1)
|
|
}
|
|
|
|
func (m *MockUserRepoForHandler) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *MockUserRepoForHandler) FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error) {
|
|
args := m.Called(ctx, codes)
|
|
return args.Get(0).([]domain.User), args.Error(1)
|
|
}
|
|
|
|
func (m *MockUserRepoForHandler) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
|
|
args := m.Called(ctx, codes)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(map[string]int64), args.Error(1)
|
|
}
|
|
|
|
func (m *MockUserRepoForHandler) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
|
|
return nil
|
|
}
|
|
|
|
func (m *MockUserRepoForHandler) GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
// IsLoginIDTaken is a stub that always reports the login ID as free.
func (m *MockUserRepoForHandler) IsLoginIDTaken(ctx context.Context, loginID string) (bool, error) {
	taken := false
	return taken, nil
}
|
|
|
|
func (m *MockUserRepoForHandler) FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error) {
|
|
return "", nil
|
|
}
|
|
|
|
// TestTenantHandler_CreateTenant verifies that POST /tenants forwards the name
// and slug to TenantService.RegisterTenant and responds 201 with the created
// tenant's ID.
func TestTenantHandler_CreateTenant(t *testing.T) {
	app := fiber.New()
	mockSvc := new(MockTenantService)
	// CreateTenant checks h.DB != nil, so a zero-value *gorm.DB is supplied
	// only to pass that guard.
	h := &TenantHandler{Service: mockSvc, DB: &gorm.DB{}}

	app.Post("/tenants", h.CreateTenant)

	input := map[string]interface{}{
		"name":    "Test Tenant",
		"slug":    "test-tenant",
		"domains": []string{"test.com"},
	}
	body, err := json.Marshal(input)
	if err != nil {
		t.Fatalf("marshal request body: %v", err)
	}

	// NOTE(review): the expectation uses nil domains although the request sends
	// a "domains" key; presumably the handler does not read that key — confirm.
	mockSvc.On("RegisterTenant", mock.Anything, "Test Tenant", "test-tenant", domain.TenantTypeCompany, "", []string(nil), (*string)(nil), "").
		Return(&domain.Tenant{ID: "t1", Name: "Test Tenant", Slug: "test-tenant"}, nil)

	req := httptest.NewRequest("POST", "/tenants", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("app.Test: %v", err)
	}
	defer resp.Body.Close()

	assert.Equal(t, http.StatusCreated, resp.StatusCode)
	var got map[string]interface{}
	assert.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
	assert.Equal(t, "t1", got["id"])
	mockSvc.AssertExpectations(t)
}
|
|
|
|
// TestTenantHandler_ListTenants verifies that GET /tenants, for a request
// carrying a super_admin user_profile, returns the tenants from
// TenantService.ListTenants with member counts matching the slug-keyed values
// returned by UserRepo.CountByCompanyCodes.
func TestTenantHandler_ListTenants(t *testing.T) {
	app := fiber.New()
	mockSvc := new(MockTenantService)
	mockUserRepo := new(MockUserRepoForHandler)

	h := &TenantHandler{
		Service:  mockSvc,
		UserRepo: mockUserRepo,
	}

	// Inject a super_admin profile into the request locals; presumably the
	// handler reads "user_profile" for its authorization check.
	app.Use(func(c *fiber.Ctx) error {
		c.Locals("user_profile", &domain.UserProfileResponse{
			Role: "super_admin",
		})
		return c.Next()
	})
	app.Get("/tenants", h.ListTenants)

	tenants := []domain.Tenant{
		{ID: "t1", Name: "Tenant A", Slug: "slug-a"},
		{ID: "t2", Name: "Tenant B", Slug: "slug-b"},
	}

	// Mocking for the new allTenants check in ListTenants.
	mockSvc.On("ListTenants", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tenants, int64(2), nil).Maybe()

	mockUserRepo.On("CountByCompanyCodes", mock.Anything, mock.Anything).
		Return(map[string]int64{"slug-a": 5, "slug-b": 10}, nil).Maybe()
	mockUserRepo.On("CountByTenantIDs", mock.Anything, mock.Anything).
		Return(map[string]int64{}, nil).Maybe()

	req := httptest.NewRequest("GET", "/tenants?limit=10&offset=0", nil)
	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("app.Test: %v", err)
	}
	defer resp.Body.Close()

	assert.Equal(t, http.StatusOK, resp.StatusCode)

	var res tenantListResponse
	if !assert.NoError(t, json.NewDecoder(resp.Body).Decode(&res)) {
		return
	}

	assert.Equal(t, int64(2), res.Total)
	assert.Len(t, res.Items, 2)

	// Check that the counts are mapped onto the tenant with the matching slug.
	wantCounts := map[string]int64{"slug-a": 5, "slug-b": 10}
	for _, item := range res.Items {
		if want, ok := wantCounts[item.Slug]; ok {
			assert.Equal(t, want, item.MemberCount)
		}
	}
}
|
|
|
|
// TestTenantHandler_ExportTenantsCSV verifies that GET /tenants/export with
// includeIds=true returns a tenants.csv attachment whose header includes the
// ID columns and whose row joins multiple email domains with ';'.
func TestTenantHandler_ExportTenantsCSV(t *testing.T) {
	app := fiber.New()
	mockSvc := new(MockTenantService)
	h := &TenantHandler{Service: mockSvc}

	app.Get("/tenants/export", h.ExportTenantsCSV)

	parentID := "parent-1"
	tenants := []domain.Tenant{
		{
			ID:          "t1",
			Name:        "Tenant A",
			Type:        domain.TenantTypeCompany,
			ParentID:    &parentID,
			Slug:        "tenant-a",
			Description: "Primary tenant",
			Domains: []domain.TenantDomain{
				{Domain: "tenant-a.example.com"},
				{Domain: "login.tenant-a.example.com"},
			},
		},
	}

	// The export is expected to request a single page: limit 10000, offset 0,
	// no parent filter.
	mockSvc.On("ListTenants", mock.Anything, 10000, 0, "").Return(tenants, int64(1), nil)

	req := httptest.NewRequest("GET", "/tenants/export?includeIds=true", nil)
	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("app.Test: %v", err)
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("read response body: %v", err)
	}

	assert.Equal(t, http.StatusOK, resp.StatusCode)
	assert.Contains(t, resp.Header.Get("Content-Disposition"), "tenants.csv")
	assert.Equal(t, "text/csv", strings.Split(resp.Header.Get("Content-Type"), ";")[0])
	assert.Contains(t, string(body), "tenant_id,name,type,parent_tenant_id,parent_tenant_slug,slug,memo,email_domain")
	// "parent-1" is not among the listed tenants, so parent_tenant_slug is
	// expected to be empty.
	assert.Contains(t, string(body), "t1,Tenant A,COMPANY,parent-1,,tenant-a,Primary tenant,tenant-a.example.com;login.tenant-a.example.com")
	mockSvc.AssertExpectations(t)
}
|
|
|
|
// TestTenantHandler_ExportTenantsCSV_OmitsIDsAndUsesParentSlug verifies that
// with includeIds=false the export drops every ID column and refers to a
// child's parent by the parent's slug instead of its ID.
func TestTenantHandler_ExportTenantsCSV_OmitsIDsAndUsesParentSlug(t *testing.T) {
	app := fiber.New()
	mockSvc := new(MockTenantService)
	h := &TenantHandler{Service: mockSvc}

	app.Get("/tenants/export", h.ExportTenantsCSV)

	parentID := "parent-1"
	tenants := []domain.Tenant{
		{
			ID:   parentID,
			Name: "Parent Tenant",
			Type: domain.TenantTypeCompanyGroup,
			Slug: "parent-tenant",
		},
		{
			ID:       "child-1",
			Name:     "Child Tenant",
			Type:     domain.TenantTypeUserGroup,
			ParentID: &parentID,
			Slug:     "child-tenant",
		},
	}

	mockSvc.On("ListTenants", mock.Anything, 10000, 0, "").Return(tenants, int64(2), nil)

	req := httptest.NewRequest("GET", "/tenants/export?includeIds=false", nil)
	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("app.Test: %v", err)
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("read response body: %v", err)
	}
	text := string(body)

	assert.Equal(t, http.StatusOK, resp.StatusCode)
	assert.Contains(t, text, "name,type,parent_tenant_slug,slug,memo,email_domain")
	assert.Contains(t, text, "Child Tenant,USER_GROUP,parent-tenant,child-tenant,,")
	assert.NotContains(t, text, "tenant_id")
	assert.NotContains(t, text, "parent_tenant_id")
	assert.NotContains(t, text, "child-1")
	mockSvc.AssertExpectations(t)
}
|
|
|
|
// TestTenantHandler_ImportTenantsCSVCreatesTenant verifies that uploading a
// one-row CSV (with a parent_tenant_id column and ';'-separated domains)
// registers a single new tenant and reports created=1, updated=0, failed=0.
func TestTenantHandler_ImportTenantsCSVCreatesTenant(t *testing.T) {
	app := fiber.New()
	mockSvc := new(MockTenantService)
	h := &TenantHandler{Service: mockSvc}

	app.Post("/tenants/import", h.ImportTenantsCSV)

	var body bytes.Buffer
	writer := multipart.NewWriter(&body)
	part, err := writer.CreateFormFile("file", "tenants.csv")
	assert.NoError(t, err)
	_, err = part.Write([]byte("tenant_id,name,type,parent_tenant_id,slug,memo,email_domain\n,Imported Tenant,COMPANY,parent-1,imported-tenant,Imported memo,imported.example.com;login.imported.example.com\n"))
	assert.NoError(t, err)
	assert.NoError(t, writer.Close())

	// No existing tenants, so the row must be created rather than updated.
	mockSvc.On("ListTenants", mock.Anything, 10000, 0, "").Return([]domain.Tenant{}, int64(0), nil).Once()
	mockSvc.On(
		"RegisterTenant",
		mock.Anything,
		"Imported Tenant",
		"imported-tenant",
		domain.TenantTypeCompany,
		"Imported memo",
		[]string{"imported.example.com", "login.imported.example.com"},
		mock.MatchedBy(func(parentID *string) bool {
			return parentID != nil && *parentID == "parent-1"
		}),
		"",
	).Return(&domain.Tenant{ID: "imported-1", Name: "Imported Tenant", Slug: "imported-tenant"}, nil)

	req := httptest.NewRequest("POST", "/tenants/import", &body)
	req.Header.Set("Content-Type", writer.FormDataContentType())
	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("app.Test: %v", err)
	}
	defer resp.Body.Close()

	assert.Equal(t, http.StatusOK, resp.StatusCode)
	var got map[string]interface{}
	assert.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
	// JSON numbers decode into float64 when the target is interface{}.
	assert.Equal(t, float64(1), got["created"])
	assert.Equal(t, float64(0), got["updated"])
	assert.Equal(t, float64(0), got["failed"])
	mockSvc.AssertExpectations(t)
}
|
|
|
|
// TestTenantHandler_ImportTenantsCSVResolvesParentSlugToID verifies that when
// a CSV row names its parent by parent_tenant_slug, the importer passes the ID
// of the tenant created earlier in the same file as the child's parent ID.
func TestTenantHandler_ImportTenantsCSVResolvesParentSlugToID(t *testing.T) {
	app := fiber.New()
	mockSvc := new(MockTenantService)
	h := &TenantHandler{Service: mockSvc}

	app.Post("/tenants/import", h.ImportTenantsCSV)

	var body bytes.Buffer
	writer := multipart.NewWriter(&body)
	part, err := writer.CreateFormFile("file", "tenants.csv")
	assert.NoError(t, err)
	_, err = part.Write([]byte("name,type,parent_tenant_slug,slug,memo,email_domain\nParent Tenant,COMPANY,,parent-slug,,\nChild Tenant,ORGANIZATION,parent-slug,child-slug,,\n"))
	assert.NoError(t, err)
	assert.NoError(t, writer.Close())

	parentID := "parent-id"
	mockSvc.On("ListTenants", mock.Anything, 10000, 0, "").Return([]domain.Tenant{}, int64(0), nil).Once()
	mockSvc.On(
		"RegisterTenant",
		mock.Anything,
		"Parent Tenant",
		"parent-slug",
		domain.TenantTypeCompany,
		"",
		[]string{},
		(*string)(nil),
		"",
	).Return(&domain.Tenant{ID: parentID, Name: "Parent Tenant", Slug: "parent-slug"}, nil).Once()
	// The child must be registered with the ID returned for the parent above.
	mockSvc.On(
		"RegisterTenant",
		mock.Anything,
		"Child Tenant",
		"child-slug",
		domain.TenantTypeOrganization,
		"",
		[]string{},
		mock.MatchedBy(func(got *string) bool {
			return got != nil && *got == parentID
		}),
		"",
	).Return(&domain.Tenant{ID: "child-id", Name: "Child Tenant", Slug: "child-slug"}, nil).Once()

	req := httptest.NewRequest("POST", "/tenants/import", &body)
	req.Header.Set("Content-Type", writer.FormDataContentType())
	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("app.Test: %v", err)
	}
	defer resp.Body.Close()

	assert.Equal(t, http.StatusOK, resp.StatusCode)
	var got map[string]interface{}
	assert.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
	assert.Equal(t, float64(2), got["created"])
	assert.Equal(t, float64(0), got["failed"])
	mockSvc.AssertExpectations(t)
}
|
|
|
|
// TestNormalizeTenantTypeAllowsOrganization checks that the lowercase input
// "organization" normalizes to domain.TenantTypeOrganization.
func TestNormalizeTenantTypeAllowsOrganization(t *testing.T) {
	got := normalizeTenantType("organization")
	assert.Equal(t, domain.TenantTypeOrganization, got)
}
|
|
|
|
// TestTenantCSVAllowedDomainsRoundTrip checks that a quoted email_domain cell
// mixing ',' and ';' separators is parsed into individual domains, in order.
func TestTenantCSVAllowedDomainsRoundTrip(t *testing.T) {
	records, err := parseTenantCSVRecords(strings.NewReader(
		"name,type,parent_tenant_slug,slug,memo,email_domain\n" +
			"Hanmac,COMPANY,,hanmac,,\"samaneng.com, hanmaceng.co.kr;login.hmac.kr\"\n",
	))

	assert.NoError(t, err)
	// Stop here if parsing produced no records: indexing records[0] would panic.
	if !assert.Len(t, records, 1) {
		return
	}
	assert.Equal(t, []string{"samaneng.com", "hanmaceng.co.kr", "login.hmac.kr"}, records[0].Domains)
}
|
|
|
|
// TestNormalizeTenantDomainInputsSplitsCommaAndWhitespace checks that domain
// inputs are split on commas and whitespace, trimmed, lowercased, and
// de-duplicated while keeping first-seen order.
func TestNormalizeTenantDomainInputsSplitsCommaAndWhitespace(t *testing.T) {
	raw := []string{
		"samaneng.com, hanmaceng.co.kr",
		" LOGIN.HMAC.KR\nportal.hmac.kr ",
		"samaneng.com",
	}
	want := []string{
		"samaneng.com",
		"hanmaceng.co.kr",
		"login.hmac.kr",
		"portal.hmac.kr",
	}

	assert.Equal(t, want, normalizeTenantDomainInputs(raw))
}
|
|
|
|
// TestNormalizeTenantConfigForcesIndexedForLoginIDFields checks that a
// userSchema field marked isLoginId is forced to indexed=true and has its
// maxLength key removed by normalizeTenantConfig.
func TestNormalizeTenantConfigForcesIndexedForLoginIDFields(t *testing.T) {
	config, err := normalizeTenantConfig(map[string]any{
		"userSchema": []any{
			map[string]any{
				"key":       "emp_no",
				"label":     "사번",
				"type":      "text",
				"indexed":   false,
				"isLoginId": true,
				"maxLength": 20,
			},
		},
	})

	assert.NoError(t, err)
	// The assertions below are non-fatal, so guard before indexing fields[0]
	// and reading field to avoid a panic that would hide the real failure.
	fields, ok := config["userSchema"].([]any)
	if !assert.True(t, ok) || !assert.Len(t, fields, 1) {
		return
	}

	field, ok := fields[0].(map[string]any)
	if !assert.True(t, ok) {
		return
	}
	assert.Equal(t, true, field["indexed"])
	assert.Equal(t, true, field["isLoginId"])
	assert.NotContains(t, field, "maxLength")
}
|
|
|
|
// TestNormalizeTenantConfigRejectsNonTextLoginIDFields checks that a
// userSchema field marked isLoginId with a non-text type is rejected.
func TestNormalizeTenantConfigRejectsNonTextLoginIDFields(t *testing.T) {
	_, err := normalizeTenantConfig(map[string]any{
		"userSchema": []any{
			map[string]any{
				"key":       "emp_no",
				"type":      "number",
				"isLoginId": true,
			},
		},
	})

	// assert.Error does not stop the test, so only inspect the message when an
	// error exists; otherwise err.Error() would dereference nil and panic.
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "login ID fields must be text")
	}
}
|
|
|
|
// TestTenantHandler_ApproveTenant verifies that POST /tenants/:id/approve
// passes the path ID to TenantService.ApproveTenant and responds 200.
func TestTenantHandler_ApproveTenant(t *testing.T) {
	app := fiber.New()
	mockSvc := new(MockTenantService)
	h := &TenantHandler{Service: mockSvc}

	app.Post("/tenants/:id/approve", h.ApproveTenant)

	mockSvc.On("ApproveTenant", mock.Anything, "t1").Return(nil)

	req := httptest.NewRequest("POST", "/tenants/t1/approve", nil)
	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("app.Test: %v", err)
	}
	defer resp.Body.Close()

	assert.Equal(t, http.StatusOK, resp.StatusCode)
	mockSvc.AssertExpectations(t)
}
|
|
|
|
// DeleteTenantsBulk records the bulk-delete request and returns the configured
// error.
func (m *MockTenantService) DeleteTenantsBulk(ctx context.Context, tenantIDs []string) error {
	return m.Called(ctx, tenantIDs).Error(0)
}
|
|
|
|
func (m *MockTenantService) ListJoinedTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
|
|
args := m.Called(ctx, userID)
|
|
if args.Get(0) != nil {
|
|
return args.Get(0).([]domain.Tenant), args.Error(1)
|
|
}
|
|
return nil, args.Error(1)
|
|
}
|