diff --git a/backend/internal/handler/dev_handler.go b/backend/internal/handler/dev_handler.go index 8f3ed43d..f02d9f8a 100644 --- a/backend/internal/handler/dev_handler.go +++ b/backend/internal/handler/dev_handler.go @@ -268,14 +268,25 @@ func isDevConsoleViewerRole(role string) bool { } } +func setCurrentProfileContext(c *fiber.Ctx, profile *domain.UserProfileResponse) { + if profile == nil { + return + } + c.Locals("user_profile", profile) + if existingUserID, _ := c.Locals("user_id").(string); existingUserID == "" && profile.ID != "" { + c.Locals("user_id", profile.ID) + } +} + func (h *DevHandler) getCurrentProfile(c *fiber.Ctx) *domain.UserProfileResponse { if profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse); ok && profile != nil { + setCurrentProfileContext(c, profile) return profile } if h.Auth != nil { enriched, err := h.Auth.GetEnrichedProfile(c) if err == nil && enriched != nil { - c.Locals("user_profile", enriched) + setCurrentProfileContext(c, enriched) return enriched } } @@ -909,10 +920,11 @@ func (h *DevHandler) checkAppManagerPermission(c *fiber.Ctx) (bool, error) { if err == nil && enriched != nil { profile = enriched ok = true - c.Locals("user_profile", enriched) + setCurrentProfileContext(c, enriched) } } if ok && profile != nil { + setCurrentProfileContext(c, profile) role := normalizeUserRole(profile.Role) switch role { case domain.RoleSuperAdmin: @@ -3583,7 +3595,7 @@ func (h *DevHandler) RequestDeveloperAccess(c *fiber.Ctx) error { if h.Auth != nil { if enriched, err := h.Auth.GetEnrichedProfile(c); err == nil && enriched != nil { profile = enriched - c.Locals("user_profile", enriched) + setCurrentProfileContext(c, enriched) } } diff --git a/backend/internal/handler/dev_handler_test.go b/backend/internal/handler/dev_handler_test.go index d9b3875c..9f9cc291 100644 --- a/backend/internal/handler/dev_handler_test.go +++ b/backend/internal/handler/dev_handler_test.go @@ -129,6 +129,18 @@ type devMockKetoOutboxRepository struct { mock.Mock } +type devMockAuthProvider struct { + mock.Mock +} + +func (m *devMockAuthProvider) GetEnrichedProfile(c *fiber.Ctx) (*domain.UserProfileResponse, error) { + args := m.Called(c) + if profile, ok := args.Get(0).(*domain.UserProfileResponse); ok { + return profile, args.Error(1) + } + return nil, args.Error(1) +} + func (m *devMockKetoOutboxRepository) Create(ctx context.Context, entry *domain.KetoOutbox) error { return m.Called(ctx, entry).Error(0) } @@ -208,6 +220,66 @@ func devTestJWKSFirstKeyString(t *testing.T, jwks map[string]any, field string) // --- Tests --- +func TestGetCurrentProfile_SetsAuditUserContext(t *testing.T) { + mockAuth := new(devMockAuthProvider) + handler := &DevHandler{Auth: mockAuth} + app := fiber.New() + + mockAuth.On("GetEnrichedProfile", mock.Anything).Return(&domain.UserProfileResponse{ + ID: "0a5b7284-e88a-4fdf-b56f-98d0435b24f5", + Role: domain.RoleUser, + }, nil) + + app.Get("/test", func(c *fiber.Ctx) error { + profile := handler.getCurrentProfile(c) + return c.JSON(fiber.Map{ + "profile_id": profile.ID, + "user_id": c.Locals("user_id"), + }) + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + resp, _ := app.Test(req) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var body map[string]string + assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body)) + assert.NoError(t, resp.Body.Close()) + assert.Equal(t, "0a5b7284-e88a-4fdf-b56f-98d0435b24f5", body["profile_id"]) + assert.Equal(t, "0a5b7284-e88a-4fdf-b56f-98d0435b24f5", body["user_id"]) +} + +func TestGetCurrentProfile_PreservesExistingAuditUserContext(t *testing.T) { + handler := &DevHandler{} + app := fiber.New() + + app.Use(func(c *fiber.Ctx) error { + c.Locals("user_profile", &domain.UserProfileResponse{ + ID: "profile-user", + Role: domain.RoleUser, + }) + c.Locals("user_id", "existing-user") + return c.Next() + }) + app.Get("/test", func(c *fiber.Ctx) error { + profile := handler.getCurrentProfile(c) + return c.JSON(fiber.Map{ + "profile_id": profile.ID, + "user_id": c.Locals("user_id"), + }) + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + resp, _ := app.Test(req) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var body map[string]string + assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body)) + assert.NoError(t, resp.Body.Close()) + assert.Equal(t, "profile-user", body["profile_id"]) + assert.Equal(t, "existing-user", body["user_id"]) +} + func TestListClients_Success(t *testing.T) { transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { if r.URL.Path == "/clients" { diff --git a/backend/internal/middleware/rbac.go b/backend/internal/middleware/rbac.go index 5a15e190..314c5859 100644 --- a/backend/internal/middleware/rbac.go +++ b/backend/internal/middleware/rbac.go @@ -20,6 +20,16 @@ type AuthProfileProvider interface { GetEnrichedProfile(c *fiber.Ctx) (*domain.UserProfileResponse, error) } +func setAuditUserContext(c *fiber.Ctx, profile *domain.UserProfileResponse) { + if profile == nil || profile.ID == "" { + return + } + if existingUserID, _ := c.Locals("user_id").(string); existingUserID != "" { + return + } + c.Locals("user_id", profile.ID) +} + // RequireKetoPermission enforces permissions using Ory Keto (ReBAC) func RequireKetoPermission(config RBACConfig, namespace, relation string) fiber.Handler { return func(c *fiber.Ctx) error { @@ -30,6 +40,7 @@ func RequireKetoPermission(config RBACConfig, namespace, relation string) fiber. // Store profile in locals for further use in handlers c.Locals("user_profile", profile) + setAuditUserContext(c, profile) role := domain.NormalizeRole(profile.Role) @@ -92,6 +103,7 @@ func RequireRole(config RBACConfig) fiber.Handler { // Store profile in locals for further use in handlers c.Locals("user_profile", profile) + setAuditUserContext(c, profile) userRole := domain.NormalizeRole(profile.Role) @@ -139,6 +151,7 @@ func RequireTenantMatch(config RBACConfig) fiber.Handler { // Store profile in locals for further use in handlers c.Locals("user_profile", profile) + setAuditUserContext(c, profile) userRole := domain.NormalizeRole(profile.Role) diff --git a/backend/internal/middleware/rbac_test.go b/backend/internal/middleware/rbac_test.go index 54bd4b9d..20425ea5 100644 --- a/backend/internal/middleware/rbac_test.go +++ b/backend/internal/middleware/rbac_test.go @@ -4,7 +4,9 @@ import ( "baron-sso-backend/internal/domain" "baron-sso-backend/internal/service" "context" + "encoding/json" "errors" + "net/http" "net/http/httptest" "testing" @@ -89,6 +91,68 @@ func TestRequireRole_Success(t *testing.T) { assert.Equal(t, 200, resp.StatusCode) } +func TestRequireRole_SetsUserIDForAuditContext(t *testing.T) { + app := fiber.New() + mockAuth := new(MockAuthProvider) + config := RBACConfig{ + AllowedRoles: []string{"admin"}, + AuthHandler: mockAuth, + } + + mockAuth.On("GetEnrichedProfile", mock.Anything).Return(&domain.UserProfileResponse{ + ID: "user1", + Role: "admin", + }, nil) + + app.Get("/test", RequireRole(config), func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{ + "user_id": c.Locals("user_id"), + }) + }) + + req := httptest.NewRequest("GET", "/test", nil) + resp, _ := app.Test(req) + + assert.Equal(t, 200, resp.StatusCode) + + var body map[string]string + assert.NoError(t, readJSON(resp, &body)) + assert.Equal(t, "user1", body["user_id"]) +} + +func TestRequireRole_PreservesExistingUserID(t *testing.T) { + app := fiber.New() + mockAuth := new(MockAuthProvider) + config := RBACConfig{ + AllowedRoles: []string{"admin"}, + AuthHandler: mockAuth, + } + + mockAuth.On("GetEnrichedProfile", mock.Anything).Return(&domain.UserProfileResponse{ + ID: "profile-user", + Role: "admin", + }, nil) + + app.Use(func(c *fiber.Ctx) error { + c.Locals("user_id", "existing-user") + return c.Next() + }) + app.Get("/test", RequireRole(config), func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{ + "user_id": c.Locals("user_id"), + }) + }) + + req := httptest.NewRequest("GET", "/test", nil) + resp, _ := app.Test(req) + + assert.Equal(t, 200, resp.StatusCode) + + var body map[string]string + assert.NoError(t, readJSON(resp, &body)) + assert.Equal(t, "existing-user", body["user_id"]) +} + func TestRequireRole_Forbidden(t *testing.T) { app := fiber.New() mockAuth := new(MockAuthProvider) @@ -199,3 +263,8 @@ func TestRequireRole_Unauthorized(t *testing.T) { assert.Equal(t, 401, resp.StatusCode) } + +func readJSON(resp *http.Response, target any) error { + defer resp.Body.Close() + return json.NewDecoder(resp.Body).Decode(target) +}