1
0
forked from baron/baron-sso

feat: enforce tenant isolation for audit logs and enhance user list filtering for multi-tenant admins

This commit is contained in:
2026-03-04 14:12:39 +09:00
parent 9da97554ce
commit 39b41a4c42
7 changed files with 78 additions and 16 deletions

View File

@@ -10,6 +10,7 @@ type AuditLog struct {
EventID string `json:"event_id"` EventID string `json:"event_id"`
Timestamp time.Time `json:"timestamp"` Timestamp time.Time `json:"timestamp"`
UserID string `json:"user_id"` UserID string `json:"user_id"`
TenantID string `json:"tenant_id,omitempty"`
SessionID string `json:"session_id,omitempty"` SessionID string `json:"session_id,omitempty"`
EventType string `json:"event_type"` // e.g., "login_success", "login_failed", "otp_sent" EventType string `json:"event_type"` // e.g., "login_success", "login_failed", "otp_sent"
Status string `json:"status"` // e.g., "success", "failure" Status string `json:"status"` // e.g., "success", "failure"
@@ -23,7 +24,7 @@ type AuditLog struct {
// AuditRepository defines interface for storing logs // AuditRepository defines interface for storing logs
type AuditRepository interface { type AuditRepository interface {
Create(log *AuditLog) error Create(log *AuditLog) error
FindPage(ctx context.Context, limit int, cursor *AuditCursor) ([]AuditLog, error) FindPage(ctx context.Context, limit int, cursor *AuditCursor, tenantID string) ([]AuditLog, error)
FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]AuditLog, error) FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]AuditLog, error)
CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error) CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error)
CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error) CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error)

View File

@@ -58,6 +58,8 @@ func (h *AuditHandler) CreateLog(c *fiber.Ctx) error {
func (h *AuditHandler) ListLogs(c *fiber.Ctx) error { func (h *AuditHandler) ListLogs(c *fiber.Ctx) error {
limit := c.QueryInt("limit", 50) limit := c.QueryInt("limit", 50)
cursorRaw := c.Query("cursor") cursorRaw := c.Query("cursor")
requestedTenantID := c.Query("tenantId")
cursor, err := parseAuditCursor(cursorRaw) cursor, err := parseAuditCursor(cursorRaw)
if err != nil { if err != nil {
return errorJSON(c, fiber.StatusBadRequest, "Invalid cursor") return errorJSON(c, fiber.StatusBadRequest, "Invalid cursor")
@@ -67,7 +69,41 @@ func (h *AuditHandler) ListLogs(c *fiber.Ctx) error {
return errorJSON(c, fiber.StatusServiceUnavailable, "Audit service unavailable") return errorJSON(c, fiber.StatusServiceUnavailable, "Audit service unavailable")
} }
logs, err := h.repo.FindPage(c.Context(), limit+1, cursor) // [New] Role-based Filtering
profile, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
var filterTenantID string
if profile != nil {
if profile.Role == domain.RoleSuperAdmin {
// Super Admin can see everything or filter by a specific tenant if requested
filterTenantID = requestedTenantID
} else if profile.Role == domain.RoleTenantAdmin {
// Tenant Admin can only see their own tenant logs (or manageable ones)
// For now, lock to their primary tenant or requested one IF it's in their manageable list
if profile.TenantID != nil {
filterTenantID = *profile.TenantID
}
// If they requested a specific tenant, verify they can manage it
if requestedTenantID != "" && requestedTenantID != filterTenantID {
canManage := false
for _, t := range profile.ManageableTenants {
if t.ID == requestedTenantID {
canManage = true
break
}
}
if !canManage {
return errorJSON(c, fiber.StatusForbidden, "forbidden: cannot view logs for this tenant")
}
filterTenantID = requestedTenantID
}
} else {
return errorJSON(c, fiber.StatusForbidden, "forbidden")
}
}
logs, err := h.repo.FindPage(c.Context(), limit+1, cursor, filterTenantID)
if err != nil { if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, "Failed to retrieve audit logs") return errorJSON(c, fiber.StatusInternalServerError, "Failed to retrieve audit logs")
} }

View File

@@ -3065,7 +3065,7 @@ func (h *AuthHandler) GetAuthTimeline(c *fiber.Ctx) error {
currentCursor := cursor currentCursor := cursor
const maxBatches = 10 const maxBatches = 10
for batch := 0; batch < maxBatches && len(authLogs) < fetchLimit; batch++ { for batch := 0; batch < maxBatches && len(authLogs) < fetchLimit; batch++ {
logs, err := h.AuditRepo.FindPage(c.Context(), fetchLimit, currentCursor) logs, err := h.AuditRepo.FindPage(c.Context(), fetchLimit, currentCursor, "")
if err != nil { if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, "Failed to retrieve audit logs") return errorJSON(c, fiber.StatusInternalServerError, "Failed to retrieve audit logs")
} }

View File

@@ -84,7 +84,7 @@ func (m *mockAuditRepo) Create(log *domain.AuditLog) error {
return nil return nil
} }
func (m *mockAuditRepo) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor) ([]domain.AuditLog, error) { func (m *mockAuditRepo) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor, tenantID string) ([]domain.AuditLog, error) {
return m.logs, nil return m.logs, nil
} }

View File

@@ -1253,7 +1253,7 @@ func (h *DevHandler) ListAuditLogs(c *fiber.Ctx) error {
const maxScan = 3000 const maxScan = 3000
for len(collected) < limit+1 && scanned < maxScan { for len(collected) < limit+1 && scanned < maxScan {
page, findErr := h.AuditRepo.FindPage(c.Context(), pageSize, nextCursor) page, findErr := h.AuditRepo.FindPage(c.Context(), pageSize, nextCursor, tenantFilter)
if findErr != nil { if findErr != nil {
return errorJSON(c, fiber.StatusInternalServerError, "Failed to retrieve audit logs") return errorJSON(c, fiber.StatusInternalServerError, "Failed to retrieve audit logs")
} }

View File

@@ -70,10 +70,8 @@ type userListResponse struct {
func (h *UserHandler) ListUsers(c *fiber.Ctx) error { func (h *UserHandler) ListUsers(c *fiber.Ctx) error {
// [New] Get requester profile from middleware // [New] Get requester profile from middleware
var requesterRole string var requesterRole string
var requesterCompany string
if profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse); ok { if profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse); ok {
requesterRole = profile.Role requesterRole = profile.Role
requesterCompany = profile.CompanyCode
} }
limit := c.QueryInt("limit", 50) limit := c.QueryInt("limit", 50)
@@ -88,6 +86,21 @@ func (h *UserHandler) ListUsers(c *fiber.Ctx) error {
offset = 0 offset = 0
} }
// [New] Manageable Tenants Map for efficient lookup
manageableSlugs := make(map[string]bool)
if requesterRole == domain.RoleTenantAdmin {
profile, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
if profile != nil {
for _, t := range profile.ManageableTenants {
manageableSlugs[strings.ToLower(t.Slug)] = true
}
// Include primary tenant slug if not already there
if profile.CompanyCode != "" {
manageableSlugs[strings.ToLower(profile.CompanyCode)] = true
}
}
}
// 1. Try Kratos First // 1. Try Kratos First
identities, err := h.KratosAdmin.ListIdentities(c.Context()) identities, err := h.KratosAdmin.ListIdentities(c.Context())
if err == nil { if err == nil {
@@ -97,11 +110,11 @@ func (h *UserHandler) ListUsers(c *fiber.Ctx) error {
for _, identity := range identities { for _, identity := range identities {
email := strings.ToLower(extractTraitString(identity.Traits, "email")) email := strings.ToLower(extractTraitString(identity.Traits, "email"))
name := strings.ToLower(extractTraitString(identity.Traits, "name")) name := strings.ToLower(extractTraitString(identity.Traits, "name"))
compCode := extractTraitString(identity.Traits, "companyCode") compCode := strings.ToLower(extractTraitString(identity.Traits, "companyCode"))
// Tenant Admin filtering // Tenant Admin filtering
if requesterRole == domain.RoleTenantAdmin { if requesterRole == domain.RoleTenantAdmin {
if requesterCompany == "" || !strings.EqualFold(compCode, requesterCompany) { if !manageableSlugs[compCode] {
continue continue
} }
} }

View File

@@ -54,6 +54,7 @@ func NewClickHouseRepository(host string, port int, user, password, db string) (
event_id String, event_id String,
timestamp DateTime DEFAULT now(), timestamp DateTime DEFAULT now(),
user_id String, user_id String,
tenant_id String,
event_type String, event_type String,
status String, status String,
ip_address String, ip_address String,
@@ -69,6 +70,7 @@ func NewClickHouseRepository(host string, port int, user, password, db string) (
alterQuery := ` alterQuery := `
ALTER TABLE audit_logs ALTER TABLE audit_logs
ADD COLUMN IF NOT EXISTS tenant_id String,
ADD COLUMN IF NOT EXISTS event_id String ADD COLUMN IF NOT EXISTS event_id String
` `
if err := conn.Exec(context.Background(), alterQuery); err != nil { if err := conn.Exec(context.Background(), alterQuery); err != nil {
@@ -87,13 +89,14 @@ func (r *ClickHouseRepository) Create(log *domain.AuditLog) error {
} }
query := ` query := `
INSERT INTO audit_logs (event_id, timestamp, user_id, event_type, status, ip_address, user_agent, device_id, details) INSERT INTO audit_logs (event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
return r.conn.Exec(ctx, query, return r.conn.Exec(ctx, query,
log.EventID, log.EventID,
log.Timestamp, log.Timestamp,
log.UserID, log.UserID,
log.TenantID,
log.EventType, log.EventType,
log.Status, log.Status,
log.IPAddress, log.IPAddress,
@@ -103,18 +106,25 @@ func (r *ClickHouseRepository) Create(log *domain.AuditLog) error {
) )
} }
func (r *ClickHouseRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor) ([]domain.AuditLog, error) { func (r *ClickHouseRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor, tenantID string) ([]domain.AuditLog, error) {
if limit <= 0 { if limit <= 0 {
limit = 50 limit = 50
} }
query := ` query := `
SELECT event_id, timestamp, user_id, event_type, status, ip_address, user_agent, device_id, details SELECT event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details
FROM audit_logs FROM audit_logs
WHERE 1=1
` `
args := make([]any, 0, 4) args := make([]any, 0, 5)
if tenantID != "" {
query += " AND tenant_id = ?"
args = append(args, tenantID)
}
if cursor != nil { if cursor != nil {
query += ` query += `
WHERE (timestamp < ?) OR (timestamp = ? AND event_id < ?) AND ((timestamp < ?) OR (timestamp = ? AND event_id < ?))
` `
args = append(args, cursor.Timestamp, cursor.Timestamp, cursor.EventID) args = append(args, cursor.Timestamp, cursor.Timestamp, cursor.EventID)
} }
@@ -137,6 +147,7 @@ func (r *ClickHouseRepository) FindPage(ctx context.Context, limit int, cursor *
&log.EventID, &log.EventID,
&log.Timestamp, &log.Timestamp,
&log.UserID, &log.UserID,
&log.TenantID,
&log.EventType, &log.EventType,
&log.Status, &log.Status,
&log.IPAddress, &log.IPAddress,
@@ -156,7 +167,7 @@ func (r *ClickHouseRepository) FindByUserAndEvents(ctx context.Context, userID s
limit = 100 limit = 100
} }
query := ` query := `
SELECT event_id, timestamp, user_id, event_type, status, ip_address, user_agent, device_id, details SELECT event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details
FROM audit_logs FROM audit_logs
WHERE user_id = ? AND event_type IN (?) WHERE user_id = ? AND event_type IN (?)
ORDER BY timestamp DESC ORDER BY timestamp DESC
@@ -175,6 +186,7 @@ func (r *ClickHouseRepository) FindByUserAndEvents(ctx context.Context, userID s
&log.EventID, &log.EventID,
&log.Timestamp, &log.Timestamp,
&log.UserID, &log.UserID,
&log.TenantID,
&log.EventType, &log.EventType,
&log.Status, &log.Status,
&log.IPAddress, &log.IPAddress,