forked from baron/baron-sso
feat: enforce tenant isolation for audit logs and enhance user list filtering for multi-tenant admins
This commit is contained in:
@@ -10,6 +10,7 @@ type AuditLog struct {
|
|||||||
EventID string `json:"event_id"`
|
EventID string `json:"event_id"`
|
||||||
Timestamp time.Time `json:"timestamp"`
|
Timestamp time.Time `json:"timestamp"`
|
||||||
UserID string `json:"user_id"`
|
UserID string `json:"user_id"`
|
||||||
|
TenantID string `json:"tenant_id,omitempty"`
|
||||||
SessionID string `json:"session_id,omitempty"`
|
SessionID string `json:"session_id,omitempty"`
|
||||||
EventType string `json:"event_type"` // e.g., "login_success", "login_failed", "otp_sent"
|
EventType string `json:"event_type"` // e.g., "login_success", "login_failed", "otp_sent"
|
||||||
Status string `json:"status"` // e.g., "success", "failure"
|
Status string `json:"status"` // e.g., "success", "failure"
|
||||||
@@ -23,7 +24,7 @@ type AuditLog struct {
|
|||||||
// AuditRepository defines interface for storing logs
|
// AuditRepository defines interface for storing logs
|
||||||
type AuditRepository interface {
|
type AuditRepository interface {
|
||||||
Create(log *AuditLog) error
|
Create(log *AuditLog) error
|
||||||
FindPage(ctx context.Context, limit int, cursor *AuditCursor) ([]AuditLog, error)
|
FindPage(ctx context.Context, limit int, cursor *AuditCursor, tenantID string) ([]AuditLog, error)
|
||||||
FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]AuditLog, error)
|
FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]AuditLog, error)
|
||||||
CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error)
|
CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error)
|
||||||
CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error)
|
CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error)
|
||||||
|
|||||||
@@ -58,6 +58,8 @@ func (h *AuditHandler) CreateLog(c *fiber.Ctx) error {
|
|||||||
func (h *AuditHandler) ListLogs(c *fiber.Ctx) error {
|
func (h *AuditHandler) ListLogs(c *fiber.Ctx) error {
|
||||||
limit := c.QueryInt("limit", 50)
|
limit := c.QueryInt("limit", 50)
|
||||||
cursorRaw := c.Query("cursor")
|
cursorRaw := c.Query("cursor")
|
||||||
|
requestedTenantID := c.Query("tenantId")
|
||||||
|
|
||||||
cursor, err := parseAuditCursor(cursorRaw)
|
cursor, err := parseAuditCursor(cursorRaw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorJSON(c, fiber.StatusBadRequest, "Invalid cursor")
|
return errorJSON(c, fiber.StatusBadRequest, "Invalid cursor")
|
||||||
@@ -67,7 +69,41 @@ func (h *AuditHandler) ListLogs(c *fiber.Ctx) error {
|
|||||||
return errorJSON(c, fiber.StatusServiceUnavailable, "Audit service unavailable")
|
return errorJSON(c, fiber.StatusServiceUnavailable, "Audit service unavailable")
|
||||||
}
|
}
|
||||||
|
|
||||||
logs, err := h.repo.FindPage(c.Context(), limit+1, cursor)
|
// [New] Role-based Filtering
|
||||||
|
profile, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
|
||||||
|
var filterTenantID string
|
||||||
|
|
||||||
|
if profile != nil {
|
||||||
|
if profile.Role == domain.RoleSuperAdmin {
|
||||||
|
// Super Admin can see everything or filter by a specific tenant if requested
|
||||||
|
filterTenantID = requestedTenantID
|
||||||
|
} else if profile.Role == domain.RoleTenantAdmin {
|
||||||
|
// Tenant Admin can only see their own tenant logs (or manageable ones)
|
||||||
|
// For now, lock to their primary tenant or requested one IF it's in their manageable list
|
||||||
|
if profile.TenantID != nil {
|
||||||
|
filterTenantID = *profile.TenantID
|
||||||
|
}
|
||||||
|
|
||||||
|
// If they requested a specific tenant, verify they can manage it
|
||||||
|
if requestedTenantID != "" && requestedTenantID != filterTenantID {
|
||||||
|
canManage := false
|
||||||
|
for _, t := range profile.ManageableTenants {
|
||||||
|
if t.ID == requestedTenantID {
|
||||||
|
canManage = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !canManage {
|
||||||
|
return errorJSON(c, fiber.StatusForbidden, "forbidden: cannot view logs for this tenant")
|
||||||
|
}
|
||||||
|
filterTenantID = requestedTenantID
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return errorJSON(c, fiber.StatusForbidden, "forbidden")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logs, err := h.repo.FindPage(c.Context(), limit+1, cursor, filterTenantID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorJSON(c, fiber.StatusInternalServerError, "Failed to retrieve audit logs")
|
return errorJSON(c, fiber.StatusInternalServerError, "Failed to retrieve audit logs")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3065,7 +3065,7 @@ func (h *AuthHandler) GetAuthTimeline(c *fiber.Ctx) error {
|
|||||||
currentCursor := cursor
|
currentCursor := cursor
|
||||||
const maxBatches = 10
|
const maxBatches = 10
|
||||||
for batch := 0; batch < maxBatches && len(authLogs) < fetchLimit; batch++ {
|
for batch := 0; batch < maxBatches && len(authLogs) < fetchLimit; batch++ {
|
||||||
logs, err := h.AuditRepo.FindPage(c.Context(), fetchLimit, currentCursor)
|
logs, err := h.AuditRepo.FindPage(c.Context(), fetchLimit, currentCursor, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorJSON(c, fiber.StatusInternalServerError, "Failed to retrieve audit logs")
|
return errorJSON(c, fiber.StatusInternalServerError, "Failed to retrieve audit logs")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ func (m *mockAuditRepo) Create(log *domain.AuditLog) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuditRepo) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor) ([]domain.AuditLog, error) {
|
func (m *mockAuditRepo) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor, tenantID string) ([]domain.AuditLog, error) {
|
||||||
return m.logs, nil
|
return m.logs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1253,7 +1253,7 @@ func (h *DevHandler) ListAuditLogs(c *fiber.Ctx) error {
|
|||||||
const maxScan = 3000
|
const maxScan = 3000
|
||||||
|
|
||||||
for len(collected) < limit+1 && scanned < maxScan {
|
for len(collected) < limit+1 && scanned < maxScan {
|
||||||
page, findErr := h.AuditRepo.FindPage(c.Context(), pageSize, nextCursor)
|
page, findErr := h.AuditRepo.FindPage(c.Context(), pageSize, nextCursor, tenantFilter)
|
||||||
if findErr != nil {
|
if findErr != nil {
|
||||||
return errorJSON(c, fiber.StatusInternalServerError, "Failed to retrieve audit logs")
|
return errorJSON(c, fiber.StatusInternalServerError, "Failed to retrieve audit logs")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -70,10 +70,8 @@ type userListResponse struct {
|
|||||||
func (h *UserHandler) ListUsers(c *fiber.Ctx) error {
|
func (h *UserHandler) ListUsers(c *fiber.Ctx) error {
|
||||||
// [New] Get requester profile from middleware
|
// [New] Get requester profile from middleware
|
||||||
var requesterRole string
|
var requesterRole string
|
||||||
var requesterCompany string
|
|
||||||
if profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse); ok {
|
if profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse); ok {
|
||||||
requesterRole = profile.Role
|
requesterRole = profile.Role
|
||||||
requesterCompany = profile.CompanyCode
|
|
||||||
}
|
}
|
||||||
|
|
||||||
limit := c.QueryInt("limit", 50)
|
limit := c.QueryInt("limit", 50)
|
||||||
@@ -88,6 +86,21 @@ func (h *UserHandler) ListUsers(c *fiber.Ctx) error {
|
|||||||
offset = 0
|
offset = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// [New] Manageable Tenants Map for efficient lookup
|
||||||
|
manageableSlugs := make(map[string]bool)
|
||||||
|
if requesterRole == domain.RoleTenantAdmin {
|
||||||
|
profile, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
|
||||||
|
if profile != nil {
|
||||||
|
for _, t := range profile.ManageableTenants {
|
||||||
|
manageableSlugs[strings.ToLower(t.Slug)] = true
|
||||||
|
}
|
||||||
|
// Include primary tenant slug if not already there
|
||||||
|
if profile.CompanyCode != "" {
|
||||||
|
manageableSlugs[strings.ToLower(profile.CompanyCode)] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 1. Try Kratos First
|
// 1. Try Kratos First
|
||||||
identities, err := h.KratosAdmin.ListIdentities(c.Context())
|
identities, err := h.KratosAdmin.ListIdentities(c.Context())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -97,11 +110,11 @@ func (h *UserHandler) ListUsers(c *fiber.Ctx) error {
|
|||||||
for _, identity := range identities {
|
for _, identity := range identities {
|
||||||
email := strings.ToLower(extractTraitString(identity.Traits, "email"))
|
email := strings.ToLower(extractTraitString(identity.Traits, "email"))
|
||||||
name := strings.ToLower(extractTraitString(identity.Traits, "name"))
|
name := strings.ToLower(extractTraitString(identity.Traits, "name"))
|
||||||
compCode := extractTraitString(identity.Traits, "companyCode")
|
compCode := strings.ToLower(extractTraitString(identity.Traits, "companyCode"))
|
||||||
|
|
||||||
// Tenant Admin filtering
|
// Tenant Admin filtering
|
||||||
if requesterRole == domain.RoleTenantAdmin {
|
if requesterRole == domain.RoleTenantAdmin {
|
||||||
if requesterCompany == "" || !strings.EqualFold(compCode, requesterCompany) {
|
if !manageableSlugs[compCode] {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ func NewClickHouseRepository(host string, port int, user, password, db string) (
|
|||||||
event_id String,
|
event_id String,
|
||||||
timestamp DateTime DEFAULT now(),
|
timestamp DateTime DEFAULT now(),
|
||||||
user_id String,
|
user_id String,
|
||||||
|
tenant_id String,
|
||||||
event_type String,
|
event_type String,
|
||||||
status String,
|
status String,
|
||||||
ip_address String,
|
ip_address String,
|
||||||
@@ -69,6 +70,7 @@ func NewClickHouseRepository(host string, port int, user, password, db string) (
|
|||||||
|
|
||||||
alterQuery := `
|
alterQuery := `
|
||||||
ALTER TABLE audit_logs
|
ALTER TABLE audit_logs
|
||||||
|
ADD COLUMN IF NOT EXISTS tenant_id String,
|
||||||
ADD COLUMN IF NOT EXISTS event_id String
|
ADD COLUMN IF NOT EXISTS event_id String
|
||||||
`
|
`
|
||||||
if err := conn.Exec(context.Background(), alterQuery); err != nil {
|
if err := conn.Exec(context.Background(), alterQuery); err != nil {
|
||||||
@@ -87,13 +89,14 @@ func (r *ClickHouseRepository) Create(log *domain.AuditLog) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO audit_logs (event_id, timestamp, user_id, event_type, status, ip_address, user_agent, device_id, details)
|
INSERT INTO audit_logs (event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
`
|
`
|
||||||
return r.conn.Exec(ctx, query,
|
return r.conn.Exec(ctx, query,
|
||||||
log.EventID,
|
log.EventID,
|
||||||
log.Timestamp,
|
log.Timestamp,
|
||||||
log.UserID,
|
log.UserID,
|
||||||
|
log.TenantID,
|
||||||
log.EventType,
|
log.EventType,
|
||||||
log.Status,
|
log.Status,
|
||||||
log.IPAddress,
|
log.IPAddress,
|
||||||
@@ -103,18 +106,25 @@ func (r *ClickHouseRepository) Create(log *domain.AuditLog) error {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ClickHouseRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor) ([]domain.AuditLog, error) {
|
func (r *ClickHouseRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor, tenantID string) ([]domain.AuditLog, error) {
|
||||||
if limit <= 0 {
|
if limit <= 0 {
|
||||||
limit = 50
|
limit = 50
|
||||||
}
|
}
|
||||||
query := `
|
query := `
|
||||||
SELECT event_id, timestamp, user_id, event_type, status, ip_address, user_agent, device_id, details
|
SELECT event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details
|
||||||
FROM audit_logs
|
FROM audit_logs
|
||||||
|
WHERE 1=1
|
||||||
`
|
`
|
||||||
args := make([]any, 0, 4)
|
args := make([]any, 0, 5)
|
||||||
|
|
||||||
|
if tenantID != "" {
|
||||||
|
query += " AND tenant_id = ?"
|
||||||
|
args = append(args, tenantID)
|
||||||
|
}
|
||||||
|
|
||||||
if cursor != nil {
|
if cursor != nil {
|
||||||
query += `
|
query += `
|
||||||
WHERE (timestamp < ?) OR (timestamp = ? AND event_id < ?)
|
AND ((timestamp < ?) OR (timestamp = ? AND event_id < ?))
|
||||||
`
|
`
|
||||||
args = append(args, cursor.Timestamp, cursor.Timestamp, cursor.EventID)
|
args = append(args, cursor.Timestamp, cursor.Timestamp, cursor.EventID)
|
||||||
}
|
}
|
||||||
@@ -137,6 +147,7 @@ func (r *ClickHouseRepository) FindPage(ctx context.Context, limit int, cursor *
|
|||||||
&log.EventID,
|
&log.EventID,
|
||||||
&log.Timestamp,
|
&log.Timestamp,
|
||||||
&log.UserID,
|
&log.UserID,
|
||||||
|
&log.TenantID,
|
||||||
&log.EventType,
|
&log.EventType,
|
||||||
&log.Status,
|
&log.Status,
|
||||||
&log.IPAddress,
|
&log.IPAddress,
|
||||||
@@ -156,7 +167,7 @@ func (r *ClickHouseRepository) FindByUserAndEvents(ctx context.Context, userID s
|
|||||||
limit = 100
|
limit = 100
|
||||||
}
|
}
|
||||||
query := `
|
query := `
|
||||||
SELECT event_id, timestamp, user_id, event_type, status, ip_address, user_agent, device_id, details
|
SELECT event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details
|
||||||
FROM audit_logs
|
FROM audit_logs
|
||||||
WHERE user_id = ? AND event_type IN (?)
|
WHERE user_id = ? AND event_type IN (?)
|
||||||
ORDER BY timestamp DESC
|
ORDER BY timestamp DESC
|
||||||
@@ -175,6 +186,7 @@ func (r *ClickHouseRepository) FindByUserAndEvents(ctx context.Context, userID s
|
|||||||
&log.EventID,
|
&log.EventID,
|
||||||
&log.Timestamp,
|
&log.Timestamp,
|
||||||
&log.UserID,
|
&log.UserID,
|
||||||
|
&log.TenantID,
|
||||||
&log.EventType,
|
&log.EventType,
|
||||||
&log.Status,
|
&log.Status,
|
||||||
&log.IPAddress,
|
&log.IPAddress,
|
||||||
|
|||||||
Reference in New Issue
Block a user