1
0
forked from baron/baron-sso

orgfront 버그 픽스

This commit is contained in:
2026-06-10 09:36:57 +09:00
parent 28478309fa
commit c880b3c333
33 changed files with 853 additions and 130 deletions

View File

@@ -39,7 +39,7 @@ type DevHandler struct {
KetoOutbox repository.KetoOutboxRepository
RPSvc service.RelyingPartyService
TenantSvc service.TenantService
DeveloperSvc *service.DeveloperService
DeveloperSvc developerRequestService
RPUserMetadataRepo repository.RPUserMetadataRepository
RPUsageQueries domain.RPUsageQueryRepository
Auth interface {
@@ -47,6 +47,16 @@ type DevHandler struct {
}
}
type developerRequestService interface {
RequestAccess(ctx context.Context, req domain.DeveloperRequest) error
GetRequestStatus(ctx context.Context, userID, tenantID string) (*domain.DeveloperRequest, error)
GetRequestByID(ctx context.Context, id uint) (*domain.DeveloperRequest, error)
ListRequests(ctx context.Context, userID, status string) ([]domain.DeveloperRequest, error)
ApproveRequest(ctx context.Context, id uint, adminNotes string) error
RejectRequest(ctx context.Context, id uint, adminNotes string) error
CancelApprovedRequest(ctx context.Context, id uint, adminNotes string) error
}
func NewDevHandler(
redis domain.RedisRepository,
secretRepo domain.ClientSecretRepository,
@@ -426,7 +436,28 @@ func (h *DevHandler) canManageTenantClientsByPermit(c *fiber.Ctx, profile *domai
return false
}
allowed, err := h.checkProfileKetoPermission(c, profile, "Tenant", tenantID, "grant_dev_permissions")
return err == nil && allowed
if err == nil && allowed {
return true
}
return h.hasApprovedDeveloperRequest(c, profile, tenantID)
}
func (h *DevHandler) hasApprovedDeveloperRequest(c *fiber.Ctx, profile *domain.UserProfileResponse, tenantID string) bool {
if h.DeveloperSvc == nil || profile == nil {
return false
}
userID := strings.TrimSpace(profile.ID)
tenantID = strings.TrimSpace(tenantID)
if userID == "" || tenantID == "" {
return false
}
status, err := h.DeveloperSvc.GetRequestStatus(c.Context(), userID, tenantID)
if err != nil || status == nil {
return false
}
return status.Status == domain.DeveloperRequestStatusApproved &&
strings.TrimSpace(status.UserID) == userID &&
strings.TrimSpace(status.TenantID) == tenantID
}
func (h *DevHandler) canOperateClientByPermit(c *fiber.Ctx, profile *domain.UserProfileResponse, summary clientSummary, relation string) bool {

View File

@@ -62,6 +62,54 @@ func (m *devMockKetoService) ListObjects(ctx context.Context, ns, rel, sub strin
return args.Get(0).([]string), args.Error(1)
}
type devMockDeveloperService struct {
mock.Mock
}
func (m *devMockDeveloperService) RequestAccess(ctx context.Context, req domain.DeveloperRequest) error {
args := m.Called(ctx, req)
return args.Error(0)
}
func (m *devMockDeveloperService) GetRequestStatus(ctx context.Context, userID, tenantID string) (*domain.DeveloperRequest, error) {
args := m.Called(ctx, userID, tenantID)
if req, ok := args.Get(0).(*domain.DeveloperRequest); ok {
return req, args.Error(1)
}
return nil, args.Error(1)
}
func (m *devMockDeveloperService) GetRequestByID(ctx context.Context, id uint) (*domain.DeveloperRequest, error) {
args := m.Called(ctx, id)
if req, ok := args.Get(0).(*domain.DeveloperRequest); ok {
return req, args.Error(1)
}
return nil, args.Error(1)
}
func (m *devMockDeveloperService) ListRequests(ctx context.Context, userID, status string) ([]domain.DeveloperRequest, error) {
args := m.Called(ctx, userID, status)
if requests, ok := args.Get(0).([]domain.DeveloperRequest); ok {
return requests, args.Error(1)
}
return nil, args.Error(1)
}
func (m *devMockDeveloperService) ApproveRequest(ctx context.Context, id uint, adminNotes string) error {
args := m.Called(ctx, id, adminNotes)
return args.Error(0)
}
func (m *devMockDeveloperService) RejectRequest(ctx context.Context, id uint, adminNotes string) error {
args := m.Called(ctx, id, adminNotes)
return args.Error(0)
}
func (m *devMockDeveloperService) CancelApprovedRequest(ctx context.Context, id uint, adminNotes string) error {
args := m.Called(ctx, id, adminNotes)
return args.Error(0)
}
type devMockRedisRepo struct {
data map[string]string
}
@@ -1521,6 +1569,66 @@ func TestCreateClient_ApprovedDeveloperCanCreatePrivateClient(t *testing.T) {
mockKeto.AssertExpectations(t)
}
func TestCreateClient_ApprovedDeveloperRequestAllowsCreateWhenTenantGrantNotVisible(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.Method == http.MethodPost && r.URL.Path == "/clients" {
var body map[string]any
_ = json.NewDecoder(r.Body).Decode(&body)
body["client_secret"] = "generated-secret"
return httpJSONAny(r, http.StatusCreated, body), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
mockKeto := new(devMockKetoService)
mockKeto.On("CheckPermission", mock.Anything, "User:user-1", "Tenant", "tenant-a", "grant_dev_permissions").Return(false, nil).Maybe()
mockKeto.On("CheckPermission", mock.Anything, "User:user-1", "System", "global", "manage_all").Return(false, nil).Maybe()
developerSvc := new(devMockDeveloperService)
developerSvc.On("GetRequestStatus", mock.Anything, "user-1", "tenant-a").Return(&domain.DeveloperRequest{
UserID: "user-1",
TenantID: "tenant-a",
Status: domain.DeveloperRequestStatusApproved,
}, nil).Maybe()
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
},
SecretRepo: &mockSecretRepo{secrets: make(map[string]string)},
Redis: &devMockRedisRepo{data: make(map[string]string)},
Keto: mockKeto,
DeveloperSvc: developerSvc,
}
app := fiber.New()
tenantID := "tenant-a"
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
ID: "user-1",
Role: domain.RoleUser,
TenantID: &tenantID,
})
return c.Next()
})
app.Post("/api/v1/dev/clients", h.CreateClient)
body, _ := json.Marshal(map[string]any{
"id": "client-1",
"name": "App One",
"type": "private",
"redirectUris": []string{"http://localhost/cb"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusCreated, resp.StatusCode)
mockKeto.AssertExpectations(t)
developerSvc.AssertExpectations(t)
}
func TestGrantCreatorAdminRelation_FallsBackToOutboxOnImmediateFailure(t *testing.T) {
mockKeto := new(devMockKetoService)
mockKeto.On("CheckPermission", mock.Anything, mock.Anything, "System", "global", "manage_all").Return(false, nil).Maybe()

View File

@@ -14,6 +14,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"maps"
"os"
"reflect"
@@ -22,6 +23,7 @@ import (
"strings"
"time"
"github.com/go-redis/redis/v8"
"github.com/gofiber/fiber/v2"
"gorm.io/gorm"
)
@@ -315,26 +317,14 @@ func (h *TenantHandler) ListTenants(c *fiber.Ctx) error {
}
}
findRoot := func(id string) string {
curr := id
for {
p, exists := parentMap[curr]
if !exists || p == "" {
break
}
curr = p
}
return curr
}
roots := make(map[string]bool)
for _, id := range baseTenantIDs {
roots[findRoot(id)] = true
roots[findTenantRootID(parentMap, id)] = true
}
// Filter tenants that belong to the same tree family
for _, t := range allTenants {
if roots[findRoot(t.ID)] {
if roots[findTenantRootID(parentMap, t.ID)] {
tenants = append(tenants, t)
}
}
@@ -2774,6 +2764,14 @@ func (h *TenantHandler) GetOrgChartSnapshot(c *fiber.Ctx) error {
cacheMode := strings.ToLower(strings.TrimSpace(c.Query("cache")))
cacheKey := orgChartSnapshotCacheKey(profile, c.Get("X-Tenant-ID"))
ttl := orgChartSnapshotCacheTTL()
role, userID, profileTenantID := orgChartProfileLogValues(profile)
slog.Info("orgchart snapshot request started",
"user_id", userID,
"role", role,
"profile_tenant_id", profileTenantID,
"tenant_header", c.Get("X-Tenant-ID"),
"cache_mode", cacheMode,
)
if cacheMode == "redis" && h.OrgChartCache != nil {
if raw, err := h.OrgChartCache.Get(cacheKey); err == nil && strings.TrimSpace(raw) != "" {
@@ -2785,13 +2783,43 @@ func (h *TenantHandler) GetOrgChartSnapshot(c *fiber.Ctx) error {
TTLSeconds: int(ttl.Seconds()),
}
c.Set("X-Orgfront-Cache", "HIT")
slog.Info("orgchart snapshot cache hit",
"user_id", userID,
"role", role,
"profile_tenant_id", profileTenantID,
"tenant_header", c.Get("X-Tenant-ID"),
"tenant_count", len(cached.Tenants),
"user_count", len(cached.Users),
)
return c.JSON(cached)
}
slog.Warn("orgchart snapshot cache payload ignored",
"user_id", userID,
"role", role,
"profile_tenant_id", profileTenantID,
"tenant_header", c.Get("X-Tenant-ID"),
"error", err,
)
} else if err != nil && err != redis.Nil {
slog.Warn("orgchart snapshot cache read failed",
"user_id", userID,
"role", role,
"profile_tenant_id", profileTenantID,
"tenant_header", c.Get("X-Tenant-ID"),
"error", err,
)
}
}
snapshot, err := h.buildOrgChartSnapshot(c.Context(), profile)
if err != nil {
slog.Error("orgchart snapshot build failed",
"user_id", userID,
"role", role,
"profile_tenant_id", profileTenantID,
"tenant_header", c.Get("X-Tenant-ID"),
"error", err,
)
return errorJSON(c, fiber.StatusServiceUnavailable, err.Error())
}
snapshot.Cache = orgChartSnapshotCacheInfo{
@@ -2802,13 +2830,31 @@ func (h *TenantHandler) GetOrgChartSnapshot(c *fiber.Ctx) error {
if cacheMode == "redis" && h.OrgChartCache != nil {
if raw, err := json.Marshal(snapshot); err == nil {
_ = h.OrgChartCache.Set(cacheKey, string(raw), ttl)
if err := h.OrgChartCache.Set(cacheKey, string(raw), ttl); err != nil {
slog.Warn("orgchart snapshot cache write failed",
"user_id", userID,
"role", role,
"profile_tenant_id", profileTenantID,
"tenant_header", c.Get("X-Tenant-ID"),
"error", err,
)
}
}
c.Set("X-Orgfront-Cache", "MISS")
} else {
c.Set("X-Orgfront-Cache", "BYPASS")
}
slog.Info("orgchart snapshot request completed",
"user_id", userID,
"role", role,
"profile_tenant_id", profileTenantID,
"tenant_header", c.Get("X-Tenant-ID"),
"cache_mode", cacheMode,
"cache_result", c.GetRespHeader("X-Orgfront-Cache"),
"tenant_count", len(snapshot.Tenants),
"user_count", len(snapshot.Users),
)
return c.JSON(snapshot)
}
@@ -2880,27 +2926,16 @@ func (h *TenantHandler) listOrgChartTenantsForProfile(ctx context.Context, profi
parentMap[tenant.ID] = *tenant.ParentID
}
}
findRoot := func(id string) string {
curr := id
for {
parentID, exists := parentMap[curr]
if !exists || parentID == "" {
return curr
}
curr = parentID
}
}
roots := make(map[string]bool)
for _, id := range baseTenantIDs {
if strings.TrimSpace(id) != "" {
roots[findRoot(id)] = true
roots[findTenantRootID(parentMap, id)] = true
}
}
tenants := make([]domain.Tenant, 0, len(allTenants))
for _, tenant := range allTenants {
if roots[findRoot(tenant.ID)] {
if roots[findTenantRootID(parentMap, tenant.ID)] {
tenants = append(tenants, tenant)
}
}
@@ -2980,6 +3015,36 @@ func orgChartSnapshotCacheKey(profile *domain.UserProfileResponse, tenantHeader
return fmt.Sprintf("orgchart:snapshot:v1:%s:%s:%s", role, userID, tenantID)
}
func orgChartProfileLogValues(profile *domain.UserProfileResponse) (string, string, string) {
if profile == nil {
return "anonymous", "anonymous", ""
}
tenantID := ""
if profile.TenantID != nil {
tenantID = strings.TrimSpace(*profile.TenantID)
}
return domain.NormalizeRole(profile.Role), strings.TrimSpace(profile.ID), tenantID
}
func findTenantRootID(parentMap map[string]string, tenantID string) string {
curr := strings.TrimSpace(tenantID)
if curr == "" {
return ""
}
visited := map[string]struct{}{}
for {
parentID := strings.TrimSpace(parentMap[curr])
if parentID == "" || parentID == curr {
return curr
}
if _, exists := visited[parentID]; exists {
return parentID
}
visited[curr] = struct{}{}
curr = parentID
}
}
func orgChartSnapshotCacheTTL() time.Duration {
const defaultTTL = 5 * time.Minute
raw := strings.TrimSpace(os.Getenv("ORGFRONT_ORGCHART_CACHE_TTL_SECONDS"))
@@ -2996,16 +3061,26 @@ func orgChartSnapshotCacheTTL() time.Duration {
func (h *TenantHandler) GetPublicOrgChart(c *fiber.Ctx) error {
token := c.Query("token")
if token == "" {
slog.Warn("public orgchart rejected missing token")
return errorJSON(c, fiber.StatusUnauthorized, "share token is required")
}
link, err := h.SharedLink.ValidateToken(c.Context(), token)
if err != nil {
slog.Warn("public orgchart token validation failed",
"token_length", len(token),
"error", err,
)
return errorJSON(c, fiber.StatusUnauthorized, err.Error())
}
allTenants, _, err := h.Service.ListTenants(c.Context(), 10000, 0, "", "")
if err != nil {
slog.Error("public orgchart tenant list failed",
"link_id", link.ID,
"tenant_id", link.TenantID,
"error", err,
)
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
@@ -3016,24 +3091,12 @@ func (h *TenantHandler) GetPublicOrgChart(c *fiber.Ctx) error {
}
}
findRoot := func(id string) string {
curr := id
for {
p, exists := parentMap[curr]
if !exists || p == "" {
break
}
curr = p
}
return curr
}
sharedRootID := findRoot(link.TenantID)
sharedRootID := findTenantRootID(parentMap, link.TenantID)
var filteredTenants []domain.Tenant
var tenantIDs []string
for _, t := range allTenants {
if findRoot(t.ID) == sharedRootID {
if findTenantRootID(parentMap, t.ID) == sharedRootID {
filteredTenants = append(filteredTenants, t)
}
}
@@ -3076,6 +3139,13 @@ func (h *TenantHandler) GetPublicOrgChart(c *fiber.Ctx) error {
tenantSummaries = append(tenantSummaries, mapTenantSummary(t))
}
slog.Info("public orgchart request completed",
"link_id", link.ID,
"tenant_id", link.TenantID,
"shared_root_id", sharedRootID,
"tenant_count", len(tenantSummaries),
"user_count", len(publicUsers),
)
return c.JSON(fiber.Map{
"tenants": tenantSummaries,
"users": publicUsers,

View File

@@ -405,6 +405,68 @@ func TestTenantHandler_GetOrgChartSnapshotCachesMissResult(t *testing.T) {
mockUsers.AssertExpectations(t)
}
func TestTenantHandler_GetOrgChartSnapshotHandlesSelfParentHanmacFamily(t *testing.T) {
app := fiber.New()
mockSvc := new(MockTenantService)
mockProjection := new(MockUserProjectionRepoForHandler)
mockUsers := new(MockUserRepoForHandler)
now := time.Date(2026, 6, 10, 0, 0, 0, 0, time.UTC)
parent := func(id string) *string { return &id }
familyID := "hanmac-family-id"
samanID := "saman-id"
teamID := "saman-platform-id"
tenants := []domain.Tenant{
{ID: familyID, Type: domain.TenantTypeCompanyGroup, ParentID: parent(familyID), Name: "한맥가족", Slug: "hanmac-family", Status: domain.TenantStatusActive, CreatedAt: now, UpdatedAt: now},
{ID: samanID, Type: domain.TenantTypeCompany, ParentID: parent(familyID), Name: "삼안", Slug: "saman", Status: domain.TenantStatusActive, CreatedAt: now, UpdatedAt: now},
{ID: teamID, Type: domain.TenantTypeUserGroup, ParentID: parent(samanID), Name: "플랫폼팀", Slug: "saman-platform", Status: domain.TenantStatusActive, CreatedAt: now, UpdatedAt: now},
}
users := []domain.User{
{ID: "user-1", Email: "user@samaneng.com", Name: "Saman User", Role: domain.RoleUser, Status: domain.UserStatusActive, TenantID: &samanID, Tenant: &tenants[1], CreatedAt: now, UpdatedAt: now},
}
h := &TenantHandler{Service: mockSvc, UserRepo: mockUsers, UserProjectionRepo: mockProjection}
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
ID: "user-1",
Role: domain.RoleUser,
TenantID: &samanID,
JoinedTenants: []domain.Tenant{
tenants[1],
},
})
return c.Next()
})
app.Get("/admin/orgchart/snapshot", h.GetOrgChartSnapshot)
mockSvc.On("ListTenants", mock.Anything, 10000, 0, "", "").Return(tenants, int64(len(tenants)), nil).Once()
mockProjection.On("IsReady", mock.Anything).Return(true, nil).Once()
mockProjection.On("CountTenantMembers", mock.Anything, mock.MatchedBy(func(got []domain.Tenant) bool {
return tenantSlugsMatch(got, "hanmac-family", "saman", "saman-platform")
})).Return(map[string]int64{familyID: 0, samanID: 1, teamID: 0}, nil).Once()
mockProjection.On("CountTenantMembersRecursive", mock.Anything, mock.MatchedBy(func(got []domain.Tenant) bool {
return tenantSlugsMatch(got, "hanmac-family", "saman", "saman-platform")
})).Return(map[string]int64{familyID: 1, samanID: 1, teamID: 0}, nil).Once()
mockUsers.On("List", mock.Anything, 0, 10000, "", []string{familyID, samanID, teamID}, "").Return(users, int64(1), "", nil).Once()
mockSvc.On("ListJoinedTenants", mock.Anything, "user-1").Return([]domain.Tenant{tenants[1], tenants[2]}, nil).Once()
req := httptest.NewRequest(http.MethodGet, "/admin/orgchart/snapshot", nil)
resp, err := app.Test(req, 1000)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
var body struct {
Tenants []tenantSummary `json:"tenants"`
Users []userSummary `json:"users"`
}
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Len(t, body.Tenants, 3)
require.True(t, tenantSummarySlugsMatch(body.Tenants, "hanmac-family", "saman", "saman-platform"))
require.Len(t, body.Users, 1)
mockSvc.AssertExpectations(t)
mockProjection.AssertExpectations(t)
mockUsers.AssertExpectations(t)
}
func TestTenantHandler_ListTenantsReturnsTotalMemberCountForDescendants(t *testing.T) {
app := fiber.New()
mockSvc := new(MockTenantService)
@@ -740,6 +802,25 @@ func tenantSlugsMatch(got []domain.Tenant, want ...string) bool {
return true
}
func tenantSummarySlugsMatch(got []tenantSummary, want ...string) bool {
if len(got) != len(want) {
return false
}
counts := make(map[string]int, len(want))
for _, slug := range want {
counts[slug]++
}
for _, tenant := range got {
counts[tenant.Slug]--
}
for _, count := range counts {
if count != 0 {
return false
}
}
return true
}
func TestTenantHandler_GetOrgContextJSONDefaultsToHanmacFamilyForApiKey(t *testing.T) {
app := fiber.New()
mockSvc := new(MockTenantService)