package middleware

import (
	"context"
	"net/http/httptest"
	"os"
	"testing"

	"baron-sso-backend/internal/domain"
	"baron-sso-backend/internal/service"
	"github.com/gofiber/fiber/v2"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
)

// MockTenantServiceForMiddleware is a testify-based test double passed as
// TenantContextConfig.TenantService in the middleware tests.
//
// Only GetTenantByDomain and GetTenantBySlug are backed by mock expectations.
// Every other method is a no-op stub that returns zero values; presumably they
// exist only so the type satisfies the tenant service interface.
type MockTenantServiceForMiddleware struct {
	mock.Mock
}

// RegisterTenant is a no-op stub; it always returns (nil, nil).
func (m *MockTenantServiceForMiddleware) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
	return nil, nil
}

// RequestRegistration is a no-op stub; it always returns (nil, nil).
func (m *MockTenantServiceForMiddleware) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) {
	return nil, nil
}

// GetTenantByDomain returns whatever the matching mock expectation was
// configured to return for emailDomain.
//
// The call is recorded with mock.Anything in place of ctx, so the real context
// value is never stored or compared by the mock.
func (m *MockTenantServiceForMiddleware) GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error) {
	args := m.Called(mock.Anything, emailDomain)
	// A nil first return value cannot be type-asserted to *domain.Tenant.
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).(*domain.Tenant), args.Error(1)
}

// GetTenantBySlug returns whatever the matching mock expectation was
// configured to return for slug.
//
// The call is recorded with mock.Anything in place of ctx, so the real context
// value is never stored or compared by the mock.
func (m *MockTenantServiceForMiddleware) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
	args := m.Called(mock.Anything, slug)
	// A nil first return value cannot be type-asserted to *domain.Tenant.
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).(*domain.Tenant), args.Error(1)
}

// GetTenant is a no-op stub; it always returns (nil, nil).
func (m *MockTenantServiceForMiddleware) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
	return nil, nil
}

// ListTenants is a no-op stub; it always returns (nil, 0, nil).
func (m *MockTenantServiceForMiddleware) ListTenants(ctx context.Context, limit, offset int, parentID string) ([]domain.Tenant, int64, error) {
	return nil, 0, nil
}

// ListManageableTenants is a no-op stub; it always returns (nil, nil).
func (m *MockTenantServiceForMiddleware) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
	return nil, nil
}

// IsDomainAllowed is a no-op stub; it always returns (false, nil).
func (m *MockTenantServiceForMiddleware) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
	return false, nil
}

// ApproveTenant is a no-op stub; it always returns nil.
func (m *MockTenantServiceForMiddleware) ApproveTenant(ctx context.Context, id string) error {
	return nil
}

// SetKetoService is a no-op stub; the supplied service is ignored.
func (m *MockTenantServiceForMiddleware) SetKetoService(keto service.KetoService) {}

// DeleteTenantsBulk is a no-op stub; it always returns nil.
func (m *MockTenantServiceForMiddleware) DeleteTenantsBulk(ctx context.Context, tenantIDs []string) error {
	return nil
}

// ListJoinedTenants is a no-op stub; it always returns (nil, nil).
func (m *MockTenantServiceForMiddleware) ListJoinedTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
	return nil, nil
}

// ProvisionTenantByDomain is a no-op stub; it always returns (nil, nil).
func (m *MockTenantServiceForMiddleware) ProvisionTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error) {
	return nil, nil
}

// TestTenantContextMiddleware drives TenantContextMiddleware through a Fiber
// app with three different Host values and checks which tenant lookups the
// middleware performs against the mocked tenant service.
//
// The subtests share one mock and run sequentially; each expectation is
// registered with Once(), so AssertExpectations stays meaningful as the
// expectations accumulate.
func TestTenantContextMiddleware(t *testing.T) {
	const envKey = "USERFRONT_URL"

	// Remember any pre-existing value so the test leaves the process
	// environment exactly as it found it.
	prevValue, hadPrev := os.LookupEnv(envKey)
	if err := os.Setenv(envKey, "https://sso.hmac.kr"); err != nil {
		t.Fatalf("setting %s: %v", envKey, err)
	}
	t.Cleanup(func() {
		if hadPrev {
			_ = os.Setenv(envKey, prevValue) // best-effort restore
			return
		}
		_ = os.Unsetenv(envKey) // best-effort restore
	})

	mockSvc := new(MockTenantServiceForMiddleware)

	app := fiber.New()
	app.Use(TenantContextMiddleware(TenantContextConfig{TenantService: mockSvc}))
	// The handler echoes the request locals named tenant_id and tenant_slug.
	app.Get("/test", func(c *fiber.Ctx) error {
		return c.JSON(fiber.Map{
			"tenant_id":   c.Locals("tenant_id"),
			"tenant_slug": c.Locals("tenant_slug"),
		})
	})

	t.Run("Base Domain - No Tenant Context", func(t *testing.T) {
		req := httptest.NewRequest("GET", "http://sso.hmac.kr/test", nil)
		req.Host = "sso.hmac.kr"

		resp, err := app.Test(req)
		if err != nil {
			// resp is not usable when app.Test fails; stop before dereferencing it.
			t.Fatalf("app.Test failed: %v", err)
		}
		defer resp.Body.Close()

		assert.Equal(t, 200, resp.StatusCode)
		// The host equals the USERFRONT_URL host, so no domain lookup is expected.
		mockSvc.AssertNotCalled(t, "GetTenantByDomain", mock.Anything, "sso.hmac.kr")
	})

	t.Run("Subdomain - Identify by Slug", func(t *testing.T) {
		// The domain lookup yields no tenant; the slug lookup is expected next.
		mockSvc.On("GetTenantByDomain", mock.Anything, "tenant1.sso.hmac.kr").Return(nil, nil).Once()
		mockSvc.On("GetTenantBySlug", mock.Anything, "tenant1").Return(&domain.Tenant{ID: "t1", Slug: "tenant1"}, nil).Once()

		req := httptest.NewRequest("GET", "/test", nil)
		req.Host = "tenant1.sso.hmac.kr"
		req.Header.Set("Host", "tenant1.sso.hmac.kr")

		resp, err := app.Test(req)
		if err != nil {
			t.Fatalf("app.Test failed: %v", err)
		}
		defer resp.Body.Close()

		assert.Equal(t, 200, resp.StatusCode)
		mockSvc.AssertExpectations(t)
	})

	t.Run("Custom Domain - Identify by Domain", func(t *testing.T) {
		// The domain lookup returns a tenant directly.
		mockSvc.On("GetTenantByDomain", mock.Anything, "company.com").Return(&domain.Tenant{ID: "t2", Slug: "company"}, nil).Once()

		req := httptest.NewRequest("GET", "/test", nil)
		req.Host = "company.com"
		req.Header.Set("Host", "company.com")

		resp, err := app.Test(req)
		if err != nil {
			t.Fatalf("app.Test failed: %v", err)
		}
		defer resp.Body.Close()

		assert.Equal(t, 200, resp.StatusCode)
		mockSvc.AssertExpectations(t)
	})
}