첫 커밋: 로컬 프로젝트 업로드
This commit is contained in:
123
baron-sso/backend/internal/middleware/tenant_middleware_test.go
Normal file
123
baron-sso/backend/internal/middleware/tenant_middleware_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type MockTenantServiceForMiddleware struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error) {
|
||||
args := m.Called(mock.Anything, emailDomain)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
|
||||
args := m.Called(mock.Anything, slug)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) ListTenants(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) ListJoinedTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) ApproveTenant(ctx context.Context, id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) SetKetoService(keto service.KetoService) {}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) DeleteTenantsBulk(ctx context.Context, ids []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestTenantContextMiddleware(t *testing.T) {
|
||||
os.Setenv("USERFRONT_URL", "https://sso.hmac.kr")
|
||||
defer os.Unsetenv("USERFRONT_URL")
|
||||
|
||||
mockSvc := new(MockTenantServiceForMiddleware)
|
||||
app := fiber.New()
|
||||
app.Use(TenantContextMiddleware(TenantContextConfig{TenantService: mockSvc}))
|
||||
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return c.JSON(fiber.Map{
|
||||
"tenant_id": c.Locals("tenant_id"),
|
||||
"tenant_slug": c.Locals("tenant_slug"),
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Base Domain - No Tenant Context", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://sso.hmac.kr/test", nil)
|
||||
req.Host = "sso.hmac.kr"
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
mockSvc.AssertNotCalled(t, "GetTenantByDomain", mock.Anything, "sso.hmac.kr")
|
||||
})
|
||||
|
||||
t.Run("Subdomain - Identify by Slug", func(t *testing.T) {
|
||||
mockSvc.On("GetTenantByDomain", mock.Anything, "tenant1.sso.hmac.kr").Return(nil, nil).Once()
|
||||
mockSvc.On("GetTenantBySlug", mock.Anything, "tenant1").Return(&domain.Tenant{ID: "t1", Slug: "tenant1"}, nil).Once()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Host = "tenant1.sso.hmac.kr"
|
||||
req.Header.Set("Host", "tenant1.sso.hmac.kr")
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Custom Domain - Identify by Domain", func(t *testing.T) {
|
||||
mockSvc.On("GetTenantByDomain", mock.Anything, "company.com").Return(&domain.Tenant{ID: "t2", Slug: "company"}, nil).Once()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Host = "company.com"
|
||||
req.Header.Set("Host", "company.com")
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user