diff --git a/backend/internal/handler/dev_handler.go b/backend/internal/handler/dev_handler.go index 94baa6b1..c96e6fda 100644 --- a/backend/internal/handler/dev_handler.go +++ b/backend/internal/handler/dev_handler.go @@ -1085,6 +1085,15 @@ func (h *DevHandler) UpdateClient(c *fiber.Ctx) error { return errorJSON(c, fiber.StatusInternalServerError, err.Error()) } + if updatedClient.ClientSecret != "" { + if h.SecretRepo != nil { + _ = h.SecretRepo.Upsert(c.Context(), updatedClient.ClientID, updatedClient.ClientSecret) + } + if h.Redis != nil { + _ = h.Redis.Set("client_secret:"+updatedClient.ClientID, updatedClient.ClientSecret, 0) + } + } + summary := h.mapClientSummary(*updatedClient) return c.JSON(clientDetailResponse{ Client: summary, diff --git a/backend/internal/handlerregression/dev_handler_trusted_secret_test.go b/backend/internal/handlerregression/dev_handler_trusted_secret_test.go new file mode 100644 index 00000000..3cb683fe --- /dev/null +++ b/backend/internal/handlerregression/dev_handler_trusted_secret_test.go @@ -0,0 +1,221 @@ +package handlerregression + +import ( + "baron-sso-backend/internal/domain" + "baron-sso-backend/internal/handler" + "baron-sso-backend/internal/service" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gofiber/fiber/v2" +) + +type roundTripFunc func(req *http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +type mockSecretRepo struct { + secrets map[string]string +} + +func (m *mockSecretRepo) Upsert(ctx context.Context, clientID, secret string) error { + if m.secrets == nil { + m.secrets = make(map[string]string) + } + m.secrets[clientID] = secret + return nil +} + +func (m *mockSecretRepo) GetByID(ctx context.Context, clientID string) (string, error) { + return m.secrets[clientID], nil +} + +func (m *mockSecretRepo) Delete(ctx context.Context, clientID string) error { + delete(m.secrets, clientID) + return nil +} + +type mockRedisRepo struct { + data map[string]string +} + +func (m *mockRedisRepo) Set(key, value string, exp time.Duration) error { + if m.data == nil { + m.data = make(map[string]string) + } + m.data[key] = value + return nil +} + +func (m *mockRedisRepo) Get(key string) (string, error) { + v, ok := m.data[key] + if !ok { + return "", fmt.Errorf("not found") + } + return v, nil +} + +func (m *mockRedisRepo) Delete(key string) error { + delete(m.data, key) + return nil +} + +func (m *mockRedisRepo) StoreVerificationCode(p, c string) error { return nil } +func (m *mockRedisRepo) GetVerificationCode(p string) (string, error) { return "", nil } +func (m *mockRedisRepo) DeleteVerificationCode(p string) error { return nil } + +func httpJSONAny(r *http.Request, code int, payload any) *http.Response { + body, _ := json.Marshal(payload) + return &http.Response{ + StatusCode: code, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(bytes.NewReader(body)), + Request: r, + } +} + +func TestUpdateClient_TrustedRPSecretPersistsForLaterDetailFetch(t *testing.T) { + getCount := 0 + + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + if r.Method == http.MethodGet && r.URL.Path == "/clients/client-trusted" { + getCount++ + if getCount == 1 { + return httpJSONAny(r, http.StatusOK, map[string]any{ + "client_id": "client-trusted", + "client_name": "Trusted Before", + "redirect_uris": []string{"https://before.example.com/callback"}, + "grant_types": []string{"authorization_code", "refresh_token"}, + "response_types": []string{"code"}, + "scope": "openid profile", + "token_endpoint_auth_method": "none", + "metadata": map[string]any{ + "status": "active", + }, + }), nil + } + + return httpJSONAny(r, http.StatusOK, map[string]any{ + "client_id": "client-trusted", + "client_name": "Trusted After", + "redirect_uris": []string{"https://trusted.example.com/callback"}, + "grant_types": []string{"authorization_code", "refresh_token"}, + "response_types": []string{"code"}, + "scope": "openid profile", + "token_endpoint_auth_method": "private_key_jwt", + "jwks_uri": "https://trusted.example.com/jwks.json", + "metadata": map[string]any{ + "status": "active", + "headless_login_enabled": true, + "request_object_signing_alg": "RS256", + }, + }), nil + } + + if r.Method == http.MethodPut && r.URL.Path == "/clients/client-trusted" { + return httpJSONAny(r, http.StatusOK, map[string]any{ + "client_id": "client-trusted", + "client_name": "Trusted After", + "client_secret": "trusted-secret", + "redirect_uris": []string{"https://trusted.example.com/callback"}, + "grant_types": []string{"authorization_code", "refresh_token"}, + "response_types": []string{"code"}, + "scope": "openid profile", + "token_endpoint_auth_method": "private_key_jwt", + "jwks_uri": "https://trusted.example.com/jwks.json", + "metadata": map[string]any{ + "status": "active", + "headless_login_enabled": true, + "request_object_signing_alg": "RS256", + }, + }), nil + } + + return httpJSONAny(r, http.StatusNotFound, nil), nil + }) + + secretRepo := &mockSecretRepo{secrets: make(map[string]string)} + redisRepo := &mockRedisRepo{data: make(map[string]string)} + + h := &handler.DevHandler{ + Hydra: &service.HydraAdminService{ + AdminURL: "http://hydra.test", + PublicURL: "http://hydra.public", + HTTPClient: &http.Client{Transport: transport}, + }, + SecretRepo: secretRepo, + Redis: redisRepo, + } + + app := fiber.New() + app.Use(func(c *fiber.Ctx) error { + c.Locals("user_profile", &domain.UserProfileResponse{ID: "test-user", Role: domain.RoleSuperAdmin}) + return c.Next() + }) + app.Put("/api/v1/dev/clients/:id", h.UpdateClient) + app.Get("/api/v1/dev/clients/:id", h.GetClient) + + updateBody, _ := json.Marshal(map[string]any{ + "name": "Trusted After", + "redirectUris": []string{"https://trusted.example.com/callback"}, + "tokenEndpointAuthMethod": "private_key_jwt", + "jwksUri": "https://trusted.example.com/jwks.json", + "metadata": map[string]any{ + "headless_login_enabled": true, + "request_object_signing_alg": "RS256", + }, + }) + updateReq := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-trusted", bytes.NewReader(updateBody)) + updateReq.Header.Set("Content-Type", "application/json") + + updateResp, err := app.Test(updateReq, -1) + if err != nil { + t.Fatalf("update request failed: %v", err) + } + if updateResp.StatusCode != http.StatusOK { + t.Fatalf("expected update 200, got %d", updateResp.StatusCode) + } + + storedSecret, _ := secretRepo.GetByID(context.Background(), "client-trusted") + if storedSecret != "trusted-secret" { + t.Fatalf("expected postgres secret trusted-secret, got %q", storedSecret) + } + + redisSecret, err := redisRepo.Get("client_secret:client-trusted") + if err != nil { + t.Fatalf("expected redis secret, got error: %v", err) + } + if redisSecret != "trusted-secret" { + t.Fatalf("expected redis secret trusted-secret, got %q", redisSecret) + } + + getReq := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-trusted", nil) + getResp, err := app.Test(getReq, -1) + if err != nil { + t.Fatalf("get request failed: %v", err) + } + if getResp.StatusCode != http.StatusOK { + t.Fatalf("expected get 200, got %d", getResp.StatusCode) + } + + var payload struct { + Client struct { + ClientSecret string `json:"clientSecret"` + } `json:"client"` + } + if err := json.NewDecoder(getResp.Body).Decode(&payload); err != nil { + t.Fatalf("decode response: %v", err) + } + if payload.Client.ClientSecret != "trusted-secret" { + t.Fatalf("expected detail secret trusted-secret, got %q", payload.Client.ClientSecret) + } +}