package handler import ( "baron-sso-backend/internal/domain" "baron-sso-backend/internal/service" "bytes" "context" "encoding/json" "io" "mime/multipart" "net/http" "net/http/httptest" "strings" "testing" "github.com/gofiber/fiber/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "gorm.io/gorm" ) // MockTenantService is a mock for service.TenantService type MockTenantService struct { mock.Mock } func (m *MockTenantService) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) { args := m.Called(ctx, name, slug, tenantType, description, domains, parentID, creatorID) if args.Get(0) == nil { return nil, args.Error(1) } return args.Get(0).(*domain.Tenant), args.Error(1) } func (m *MockTenantService) GetTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) { args := m.Called(ctx, domainName) if args.Get(0) == nil { return nil, args.Error(1) } return args.Get(0).(*domain.Tenant), args.Error(1) } func (m *MockTenantService) ApproveTenant(ctx context.Context, tenantID string) error { args := m.Called(ctx, tenantID) return args.Error(0) } func (m *MockTenantService) RequestRegistration(ctx context.Context, name, slug, description, domainName, adminEmail string) (*domain.Tenant, error) { args := m.Called(ctx, name, slug, description, domainName, adminEmail) if args.Get(0) == nil { return nil, args.Error(1) } return args.Get(0).(*domain.Tenant), args.Error(1) } func (m *MockTenantService) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) { args := m.Called(ctx, slug) if args.Get(0) == nil { return nil, args.Error(1) } return args.Get(0).(*domain.Tenant), args.Error(1) } func (m *MockTenantService) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) { args := m.Called(ctx, id) if args.Get(0) == nil { return nil, args.Error(1) } return args.Get(0).(*domain.Tenant), args.Error(1) } func (m *MockTenantService) ListTenants(ctx context.Context, limit, offset int, parentID string) ([]domain.Tenant, int64, error) { args := m.Called(ctx, limit, offset, parentID) return args.Get(0).([]domain.Tenant), args.Get(1).(int64), args.Error(2) } func (m *MockTenantService) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) { args := m.Called(ctx, userID) if args.Get(0) == nil { return nil, args.Error(1) } return args.Get(0).([]domain.Tenant), args.Error(1) } func (m *MockTenantService) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) { args := m.Called(ctx, domainName) return args.Bool(0), args.Error(1) } func (m *MockTenantService) SetKetoService(keto service.KetoService) { m.Called(keto) } func (m *MockTenantService) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) { args := m.Called(ctx, domainName) if args.Get(0) == nil { return nil, args.Error(1) } return args.Get(0).(*domain.Tenant), args.Error(1) } type MockUserRepoForHandler struct { mock.Mock } func (m *MockUserRepoForHandler) Create(ctx context.Context, user *domain.User) error { return nil } func (m *MockUserRepoForHandler) Update(ctx context.Context, user *domain.User) error { return nil } func (m *MockUserRepoForHandler) Delete(ctx context.Context, id string) error { return nil } func (m *MockUserRepoForHandler) FindByEmail(ctx context.Context, email string) (*domain.User, error) { return nil, nil } func (m *MockUserRepoForHandler) FindByID(ctx context.Context, id string) (*domain.User, error) { return nil, nil } func (m *MockUserRepoForHandler) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) { return nil, nil } func (m *MockUserRepoForHandler) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) { return nil, nil } func (m *MockUserRepoForHandler) List(ctx context.Context, offset, limit int, search string, companyCode string) ([]domain.User, int64, error) { for _, call := range m.ExpectedCalls { if call.Method == "List" { args := m.Called(ctx, offset, limit, search, companyCode) return args.Get(0).([]domain.User), args.Get(1).(int64), args.Error(2) } } return nil, 0, nil } func (m *MockUserRepoForHandler) CountByTenant(ctx context.Context, tenantID string) (int64, error) { return 0, nil } func (m *MockUserRepoForHandler) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) { args := m.Called(ctx, tenantIDs) return args.Get(0).([]domain.User), args.Error(1) } func (m *MockUserRepoForHandler) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) { return nil, nil } func (m *MockUserRepoForHandler) FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error) { args := m.Called(ctx, codes) return args.Get(0).([]domain.User), args.Error(1) } func (m *MockUserRepoForHandler) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) { args := m.Called(ctx, codes) if args.Get(0) == nil { return nil, args.Error(1) } return args.Get(0).(map[string]int64), args.Error(1) } func (m *MockUserRepoForHandler) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error { return nil } func (m *MockUserRepoForHandler) GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error) { return nil, nil } func (m *MockUserRepoForHandler) IsLoginIDTaken(ctx context.Context, loginID string) (bool, error) { return false, nil } func (m *MockUserRepoForHandler) FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error) { return "", nil } func TestTenantHandler_CreateTenant(t *testing.T) { app := fiber.New() mockSvc := new(MockTenantService) // CreateTenant checks h.DB != nil h := &TenantHandler{Service: mockSvc, DB: &gorm.DB{}} app.Post("/tenants", h.CreateTenant) input := map[string]interface{}{ "name": "Test Tenant", "slug": "test-tenant", "domains": []string{"test.com"}, } body, _ := json.Marshal(input) mockSvc.On("RegisterTenant", mock.Anything, "Test Tenant", "test-tenant", domain.TenantTypeCompany, "", []string(nil), (*string)(nil), ""). Return(&domain.Tenant{ID: "t1", Name: "Test Tenant", Slug: "test-tenant"}, nil) req := httptest.NewRequest("POST", "/tenants", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") resp, _ := app.Test(req) assert.Equal(t, http.StatusCreated, resp.StatusCode) var got map[string]interface{} json.NewDecoder(resp.Body).Decode(&got) assert.Equal(t, "t1", got["id"]) } func TestTenantHandler_ListTenants(t *testing.T) { app := fiber.New() mockSvc := new(MockTenantService) mockUserRepo := new(MockUserRepoForHandler) h := &TenantHandler{ Service: mockSvc, UserRepo: mockUserRepo, } app.Use(func(c *fiber.Ctx) error { c.Locals("user_profile", &domain.UserProfileResponse{ Role: "super_admin", }) return c.Next() }) app.Get("/tenants", h.ListTenants) tenants := []domain.Tenant{ {ID: "t1", Name: "Tenant A", Slug: "slug-a"}, {ID: "t2", Name: "Tenant B", Slug: "slug-b"}, } // Mocking for the new allTenants check in ListTenants mockSvc.On("ListTenants", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tenants, int64(2), nil).Maybe() mockUserRepo.On("CountByCompanyCodes", mock.Anything, mock.Anything). Return(map[string]int64{"slug-a": 5, "slug-b": 10}, nil).Maybe() mockUserRepo.On("CountByTenantIDs", mock.Anything, mock.Anything). Return(map[string]int64{}, nil).Maybe() req := httptest.NewRequest("GET", "/tenants?limit=10&offset=0", nil) resp, _ := app.Test(req) assert.Equal(t, http.StatusOK, resp.StatusCode) var res tenantListResponse json.NewDecoder(resp.Body).Decode(&res) assert.Equal(t, int64(2), res.Total) assert.Len(t, res.Items, 2) // Check if counts are mapped correctly for _, item := range res.Items { if item.Slug == "slug-a" { assert.Equal(t, int64(5), item.MemberCount) } else if item.Slug == "slug-b" { assert.Equal(t, int64(10), item.MemberCount) } } } func TestTenantHandler_ExportTenantsCSV(t *testing.T) { app := fiber.New() mockSvc := new(MockTenantService) h := &TenantHandler{Service: mockSvc} app.Get("/tenants/export", h.ExportTenantsCSV) parentID := "parent-1" tenants := []domain.Tenant{ { ID: "t1", Name: "Tenant A", Type: domain.TenantTypeCompany, ParentID: &parentID, Slug: "tenant-a", Description: "Primary tenant", Domains: []domain.TenantDomain{ {Domain: "tenant-a.example.com"}, {Domain: "login.tenant-a.example.com"}, }, }, } mockSvc.On("ListTenants", mock.Anything, 10000, 0, "").Return(tenants, int64(1), nil) req := httptest.NewRequest("GET", "/tenants/export?includeIds=true", nil) resp, _ := app.Test(req) body, _ := io.ReadAll(resp.Body) assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Contains(t, resp.Header.Get("Content-Disposition"), "tenants.csv") assert.Equal(t, "text/csv", strings.Split(resp.Header.Get("Content-Type"), ";")[0]) assert.Contains(t, string(body), "tenant_id,name,type,parent_tenant_id,parent_tenant_slug,slug,memo,email_domain") assert.Contains(t, string(body), "t1,Tenant A,COMPANY,parent-1,,tenant-a,Primary tenant,tenant-a.example.com;login.tenant-a.example.com") } func TestTenantHandler_ExportTenantsCSV_OmitsIDsAndUsesParentSlug(t *testing.T) { app := fiber.New() mockSvc := new(MockTenantService) h := &TenantHandler{Service: mockSvc} app.Get("/tenants/export", h.ExportTenantsCSV) parentID := "parent-1" tenants := []domain.Tenant{ { ID: parentID, Name: "Parent Tenant", Type: domain.TenantTypeCompanyGroup, Slug: "parent-tenant", }, { ID: "child-1", Name: "Child Tenant", Type: domain.TenantTypeUserGroup, ParentID: &parentID, Slug: "child-tenant", }, } mockSvc.On("ListTenants", mock.Anything, 10000, 0, "").Return(tenants, int64(2), nil) req := httptest.NewRequest("GET", "/tenants/export?includeIds=false", nil) resp, _ := app.Test(req) body, _ := io.ReadAll(resp.Body) text := string(body) assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Contains(t, text, "name,type,parent_tenant_slug,slug,memo,email_domain") assert.Contains(t, text, "Child Tenant,USER_GROUP,parent-tenant,child-tenant,,") assert.NotContains(t, text, "tenant_id") assert.NotContains(t, text, "parent_tenant_id") assert.NotContains(t, text, "child-1") mockSvc.AssertExpectations(t) } func TestTenantHandler_ImportTenantsCSVCreatesTenant(t *testing.T) { app := fiber.New() mockSvc := new(MockTenantService) h := &TenantHandler{Service: mockSvc} app.Post("/tenants/import", h.ImportTenantsCSV) var body bytes.Buffer writer := multipart.NewWriter(&body) part, err := writer.CreateFormFile("file", "tenants.csv") assert.NoError(t, err) _, err = part.Write([]byte("tenant_id,name,type,parent_tenant_id,slug,memo,email_domain\n,Imported Tenant,COMPANY,parent-1,imported-tenant,Imported memo,imported.example.com;login.imported.example.com\n")) assert.NoError(t, err) assert.NoError(t, writer.Close()) mockSvc.On("ListTenants", mock.Anything, 10000, 0, "").Return([]domain.Tenant{}, int64(0), nil).Once() mockSvc.On( "RegisterTenant", mock.Anything, "Imported Tenant", "imported-tenant", domain.TenantTypeCompany, "Imported memo", []string{"imported.example.com", "login.imported.example.com"}, mock.MatchedBy(func(parentID *string) bool { return parentID != nil && *parentID == "parent-1" }), "", ).Return(&domain.Tenant{ID: "imported-1", Name: "Imported Tenant", Slug: "imported-tenant"}, nil) req := httptest.NewRequest("POST", "/tenants/import", &body) req.Header.Set("Content-Type", writer.FormDataContentType()) resp, _ := app.Test(req) assert.Equal(t, http.StatusOK, resp.StatusCode) var got map[string]interface{} json.NewDecoder(resp.Body).Decode(&got) assert.Equal(t, float64(1), got["created"]) assert.Equal(t, float64(0), got["updated"]) assert.Equal(t, float64(0), got["failed"]) mockSvc.AssertExpectations(t) } func TestTenantHandler_ImportTenantsCSVResolvesParentSlugToID(t *testing.T) { app := fiber.New() mockSvc := new(MockTenantService) h := &TenantHandler{Service: mockSvc} app.Post("/tenants/import", h.ImportTenantsCSV) var body bytes.Buffer writer := multipart.NewWriter(&body) part, err := writer.CreateFormFile("file", "tenants.csv") assert.NoError(t, err) _, err = part.Write([]byte("name,type,parent_tenant_slug,slug,memo,email_domain\nParent Tenant,COMPANY,,parent-slug,,\nChild Tenant,ORGANIZATION,parent-slug,child-slug,,\n")) assert.NoError(t, err) assert.NoError(t, writer.Close()) parentID := "parent-id" mockSvc.On("ListTenants", mock.Anything, 10000, 0, "").Return([]domain.Tenant{}, int64(0), nil).Once() mockSvc.On( "RegisterTenant", mock.Anything, "Parent Tenant", "parent-slug", domain.TenantTypeCompany, "", []string{}, (*string)(nil), "", ).Return(&domain.Tenant{ID: parentID, Name: "Parent Tenant", Slug: "parent-slug"}, nil).Once() mockSvc.On( "RegisterTenant", mock.Anything, "Child Tenant", "child-slug", domain.TenantTypeOrganization, "", []string{}, mock.MatchedBy(func(got *string) bool { return got != nil && *got == parentID }), "", ).Return(&domain.Tenant{ID: "child-id", Name: "Child Tenant", Slug: "child-slug"}, nil).Once() req := httptest.NewRequest("POST", "/tenants/import", &body) req.Header.Set("Content-Type", writer.FormDataContentType()) resp, _ := app.Test(req) assert.Equal(t, http.StatusOK, resp.StatusCode) var got map[string]interface{} json.NewDecoder(resp.Body).Decode(&got) assert.Equal(t, float64(2), got["created"]) assert.Equal(t, float64(0), got["failed"]) mockSvc.AssertExpectations(t) } func TestNormalizeTenantTypeAllowsOrganization(t *testing.T) { assert.Equal(t, domain.TenantTypeOrganization, normalizeTenantType("organization")) } func TestTenantCSVAllowedDomainsRoundTrip(t *testing.T) { records, err := parseTenantCSVRecords(strings.NewReader( "name,type,parent_tenant_slug,slug,memo,email_domain\n" + "Hanmac,COMPANY,,hanmac,,\"samaneng.com, hanmaceng.co.kr;login.hmac.kr\"\n", )) assert.NoError(t, err) assert.Len(t, records, 1) assert.Equal(t, []string{"samaneng.com", "hanmaceng.co.kr", "login.hmac.kr"}, records[0].Domains) } func TestNormalizeTenantDomainInputsSplitsCommaAndWhitespace(t *testing.T) { got := normalizeTenantDomainInputs([]string{ "samaneng.com, hanmaceng.co.kr", " LOGIN.HMAC.KR\nportal.hmac.kr ", "samaneng.com", }) assert.Equal(t, []string{ "samaneng.com", "hanmaceng.co.kr", "login.hmac.kr", "portal.hmac.kr", }, got) } func TestNormalizeTenantConfigForcesIndexedForLoginIDFields(t *testing.T) { config, err := normalizeTenantConfig(map[string]any{ "userSchema": []any{ map[string]any{ "key": "emp_no", "label": "사번", "type": "text", "indexed": false, "isLoginId": true, "maxLength": 20, }, }, }) assert.NoError(t, err) fields, ok := config["userSchema"].([]any) assert.True(t, ok) assert.Len(t, fields, 1) field, ok := fields[0].(map[string]any) assert.True(t, ok) assert.Equal(t, true, field["indexed"]) assert.Equal(t, true, field["isLoginId"]) assert.NotContains(t, field, "maxLength") } func TestNormalizeTenantConfigRejectsNonTextLoginIDFields(t *testing.T) { _, err := normalizeTenantConfig(map[string]any{ "userSchema": []any{ map[string]any{ "key": "emp_no", "type": "number", "isLoginId": true, }, }, }) assert.Error(t, err) assert.Contains(t, err.Error(), "login ID fields must be text") } func TestTenantHandler_ApproveTenant(t *testing.T) { app := fiber.New() mockSvc := new(MockTenantService) h := &TenantHandler{Service: mockSvc} app.Post("/tenants/:id/approve", h.ApproveTenant) mockSvc.On("ApproveTenant", mock.Anything, "t1").Return(nil) req := httptest.NewRequest("POST", "/tenants/t1/approve", nil) resp, _ := app.Test(req) assert.Equal(t, http.StatusOK, resp.StatusCode) } func (m *MockTenantService) DeleteTenantsBulk(ctx context.Context, tenantIDs []string) error { args := m.Called(ctx, tenantIDs) return args.Error(0) } func (m *MockTenantService) ListJoinedTenants(ctx context.Context, userID string) ([]domain.Tenant, error) { args := m.Called(ctx, userID) if args.Get(0) != nil { return args.Get(0).([]domain.Tenant), args.Error(1) } return nil, args.Error(1) }