package main

import (
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/gofiber/fiber/v2"
)

// decodeJSONBody decodes the JSON object in resp.Body into a generic map,
// failing the test immediately if the body cannot be decoded.
func decodeJSONBody(t *testing.T, resp *http.Response) map[string]any {
	t.Helper()

	var body map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
		t.Fatalf("failed to decode response body: %v", err)
	}
	return body
}

// doGetWithErrorHandler creates a Fiber app whose ErrorHandler is
// newErrorHandler(appEnv), registers handler for GET path, and sends one GET
// request to path through app.Test. It fails the test if the request itself
// fails. The returned response's body is closed automatically when the test
// (or subtest) finishes, so callers do not need to close it.
func doGetWithErrorHandler(t *testing.T, appEnv, path string, handler fiber.Handler) *http.Response {
	t.Helper()

	app := fiber.New(fiber.Config{ErrorHandler: newErrorHandler(appEnv)})
	app.Get(path, handler)

	resp, err := app.Test(httptest.NewRequest(http.MethodGet, path, nil))
	if err != nil {
		t.Fatalf("request failed: %v", err)
	}
	// The body is an io.ReadCloser and must be closed; doing it in Cleanup keeps
	// it open for decodeJSONBody while guaranteeing it is released afterwards.
	t.Cleanup(func() { _ = resp.Body.Close() })
	return resp
}

// expectStatusCode fails the test if resp does not carry the wanted status.
func expectStatusCode(t *testing.T, resp *http.Response, want int) {
	t.Helper()

	if resp.StatusCode != want {
		t.Fatalf("unexpected status code: got %d, want %d", resp.StatusCode, want)
	}
}

// expectBodyField fails the test unless body[key] holds exactly the string want.
func expectBodyField(t *testing.T, body map[string]any, key, want string) {
	t.Helper()

	if got := body[key]; got != want {
		t.Fatalf("unexpected %s: got %v, want %q", key, got, want)
	}
}

func TestNewErrorHandler_ProductionMasksServerError(t *testing.T) {
	resp := doGetWithErrorHandler(t, "production", "/boom", func(c *fiber.Ctx) error {
		return errors.New("database connection failed")
	})

	expectStatusCode(t, resp, http.StatusInternalServerError)
	body := decodeJSONBody(t, resp)
	// In production the raw error text must not reach the client.
	expectBodyField(t, body, "error", "Internal Server Error")
	expectBodyField(t, body, "code", "internal_error")
}

func TestNewErrorHandler_ProductionPassesClientError(t *testing.T) {
	resp := doGetWithErrorHandler(t, "production", "/bad", func(c *fiber.Ctx) error {
		return fiber.NewError(fiber.StatusBadRequest, "bad request payload")
	})

	expectStatusCode(t, resp, http.StatusBadRequest)
	body := decodeJSONBody(t, resp)
	expectBodyField(t, body, "error", "bad request payload")
	expectBodyField(t, body, "code", "bad_request")
}

func TestNewErrorHandler_DevelopmentReturnsOriginalServerError(t *testing.T) {
	resp := doGetWithErrorHandler(t, "dev", "/boom", func(c *fiber.Ctx) error {
		return errors.New("database connection failed")
	})

	expectStatusCode(t, resp, http.StatusInternalServerError)
	body := decodeJSONBody(t, resp)
	// Outside production the original error text is exposed for debugging.
	expectBodyField(t, body, "error", "database connection failed")
	expectBodyField(t, body, "code", "internal_error")
}

func TestNewErrorHandler_MapsUnauthorizedCode(t *testing.T) {
	resp := doGetWithErrorHandler(t, "production", "/unauthorized", func(c *fiber.Ctx) error {
		return fiber.NewError(fiber.StatusUnauthorized, "missing token")
	})

	expectStatusCode(t, resp, http.StatusUnauthorized)
	body := decodeJSONBody(t, resp)
	expectBodyField(t, body, "code", "invalid_session")
}

func TestShouldEnableDocs_DisabledOnlyInProduction(t *testing.T) {
	testCases := []struct {
		appEnv string
		want   bool
	}{
		{appEnv: "production", want: false},
		{appEnv: "prod", want: false},
		{appEnv: "stage", want: true},
		{appEnv: "staging", want: true},
		{appEnv: "dev", want: true},
		{appEnv: "development", want: true},
	}

	for _, tc := range testCases {
		// Subtests plus Errorf report every failing environment instead of
		// stopping at the first mismatch.
		t.Run(tc.appEnv, func(t *testing.T) {
			if got := shouldEnableDocs(tc.appEnv); got != tc.want {
				t.Errorf("appEnv=%q expected shouldEnableDocs=%v, got %v", tc.appEnv, tc.want, got)
			}
		})
	}
}