diff --git a/backend/internal/domain/models.go b/backend/internal/domain/models.go index ad86ca7a..248618ea 100644 --- a/backend/internal/domain/models.go +++ b/backend/internal/domain/models.go @@ -10,6 +10,7 @@ type AuditLog struct { EventID string `json:"event_id"` Timestamp time.Time `json:"timestamp"` UserID string `json:"user_id"` + TenantID string `json:"tenant_id,omitempty"` SessionID string `json:"session_id,omitempty"` EventType string `json:"event_type"` // e.g., "login_success", "login_failed", "otp_sent" Status string `json:"status"` // e.g., "success", "failure" @@ -23,7 +24,7 @@ type AuditLog struct { // AuditRepository defines interface for storing logs type AuditRepository interface { Create(log *AuditLog) error - FindPage(ctx context.Context, limit int, cursor *AuditCursor) ([]AuditLog, error) + FindPage(ctx context.Context, limit int, cursor *AuditCursor, tenantID string) ([]AuditLog, error) FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]AuditLog, error) CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error) CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error) diff --git a/backend/internal/handler/audit_handler.go b/backend/internal/handler/audit_handler.go index be968edd..e2cf8b35 100644 --- a/backend/internal/handler/audit_handler.go +++ b/backend/internal/handler/audit_handler.go @@ -58,6 +58,8 @@ func (h *AuditHandler) CreateLog(c *fiber.Ctx) error { func (h *AuditHandler) ListLogs(c *fiber.Ctx) error { limit := c.QueryInt("limit", 50) cursorRaw := c.Query("cursor") + requestedTenantID := c.Query("tenantId") + cursor, err := parseAuditCursor(cursorRaw) if err != nil { return errorJSON(c, fiber.StatusBadRequest, "Invalid cursor") @@ -67,7 +69,41 @@ func (h *AuditHandler) ListLogs(c *fiber.Ctx) error { return errorJSON(c, fiber.StatusServiceUnavailable, "Audit service unavailable") } - logs, err := h.repo.FindPage(c.Context(), limit+1, cursor) + // [New] Role-based Filtering + profile, _ := c.Locals("user_profile").(*domain.UserProfileResponse) + var filterTenantID string + + if profile != nil { + if profile.Role == domain.RoleSuperAdmin { + // Super Admin can see everything or filter by a specific tenant if requested + filterTenantID = requestedTenantID + } else if profile.Role == domain.RoleTenantAdmin { + // Tenant Admin can only see their own tenant logs (or manageable ones) + // For now, lock to their primary tenant or requested one IF it's in their manageable list + if profile.TenantID != nil { + filterTenantID = *profile.TenantID + } + + // If they requested a specific tenant, verify they can manage it + if requestedTenantID != "" && requestedTenantID != filterTenantID { + canManage := false + for _, t := range profile.ManageableTenants { + if t.ID == requestedTenantID { + canManage = true + break + } + } + if !canManage { + return errorJSON(c, fiber.StatusForbidden, "forbidden: cannot view logs for this tenant") + } + filterTenantID = requestedTenantID + } + } else { + return errorJSON(c, fiber.StatusForbidden, "forbidden") + } + } + + logs, err := h.repo.FindPage(c.Context(), limit+1, cursor, filterTenantID) if err != nil { return errorJSON(c, fiber.StatusInternalServerError, "Failed to retrieve audit logs") } diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 47354d0f..c24a9cf7 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -3065,7 +3065,7 @@ func (h *AuthHandler) GetAuthTimeline(c *fiber.Ctx) error { currentCursor := cursor const maxBatches = 10 for batch := 0; batch < maxBatches && len(authLogs) < fetchLimit; batch++ { - logs, err := h.AuditRepo.FindPage(c.Context(), fetchLimit, currentCursor) + logs, err := h.AuditRepo.FindPage(c.Context(), fetchLimit, currentCursor, "") if err != nil { return errorJSON(c, fiber.StatusInternalServerError, "Failed to retrieve audit logs") } diff --git a/backend/internal/handler/common_test.go b/backend/internal/handler/common_test.go index 656749d5..73cbdc73 100644 --- a/backend/internal/handler/common_test.go +++ b/backend/internal/handler/common_test.go @@ -84,7 +84,7 @@ func (m *mockAuditRepo) Create(log *domain.AuditLog) error { return nil } -func (m *mockAuditRepo) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor) ([]domain.AuditLog, error) { +func (m *mockAuditRepo) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor, tenantID string) ([]domain.AuditLog, error) { return m.logs, nil } diff --git a/backend/internal/handler/dev_handler.go b/backend/internal/handler/dev_handler.go index 91f80b10..bbd5f2b0 100644 --- a/backend/internal/handler/dev_handler.go +++ b/backend/internal/handler/dev_handler.go @@ -1253,7 +1253,7 @@ func (h *DevHandler) ListAuditLogs(c *fiber.Ctx) error { const maxScan = 3000 for len(collected) < limit+1 && scanned < maxScan { - page, findErr := h.AuditRepo.FindPage(c.Context(), pageSize, nextCursor) + page, findErr := h.AuditRepo.FindPage(c.Context(), pageSize, nextCursor, tenantFilter) if findErr != nil { return errorJSON(c, fiber.StatusInternalServerError, "Failed to retrieve audit logs") } diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index e230d101..98fac67f 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -70,10 +70,8 @@ type userListResponse struct { func (h *UserHandler) ListUsers(c *fiber.Ctx) error { // [New] Get requester profile from middleware var requesterRole string - var requesterCompany string if profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse); ok { requesterRole = profile.Role - requesterCompany = profile.CompanyCode } limit := c.QueryInt("limit", 50) @@ -88,6 +86,21 @@ func (h *UserHandler) ListUsers(c *fiber.Ctx) error { offset = 0 } + // [New] Manageable Tenants Map for efficient lookup + manageableSlugs := make(map[string]bool) + if requesterRole == domain.RoleTenantAdmin { + profile, _ := c.Locals("user_profile").(*domain.UserProfileResponse) + if profile != nil { + for _, t := range profile.ManageableTenants { + manageableSlugs[strings.ToLower(t.Slug)] = true + } + // Include primary tenant slug if not already there + if profile.CompanyCode != "" { + manageableSlugs[strings.ToLower(profile.CompanyCode)] = true + } + } + } + // 1. Try Kratos First identities, err := h.KratosAdmin.ListIdentities(c.Context()) if err == nil { @@ -97,11 +110,11 @@ func (h *UserHandler) ListUsers(c *fiber.Ctx) error { for _, identity := range identities { email := strings.ToLower(extractTraitString(identity.Traits, "email")) name := strings.ToLower(extractTraitString(identity.Traits, "name")) - compCode := extractTraitString(identity.Traits, "companyCode") + compCode := strings.ToLower(extractTraitString(identity.Traits, "companyCode")) // Tenant Admin filtering if requesterRole == domain.RoleTenantAdmin { - if requesterCompany == "" || !strings.EqualFold(compCode, requesterCompany) { + if !manageableSlugs[compCode] { continue } } diff --git a/backend/internal/repository/clickhouse_repo.go b/backend/internal/repository/clickhouse_repo.go index 532cfa55..c9b0d0e4 100644 --- a/backend/internal/repository/clickhouse_repo.go +++ b/backend/internal/repository/clickhouse_repo.go @@ -54,6 +54,7 @@ func NewClickHouseRepository(host string, port int, user, password, db string) ( event_id String, timestamp DateTime DEFAULT now(), user_id String, + tenant_id String, event_type String, status String, ip_address String, @@ -69,6 +70,7 @@ func NewClickHouseRepository(host string, port int, user, password, db string) ( alterQuery := ` ALTER TABLE audit_logs + ADD COLUMN IF NOT EXISTS tenant_id String, ADD COLUMN IF NOT EXISTS event_id String ` if err := conn.Exec(context.Background(), alterQuery); err != nil { @@ -87,13 +89,14 @@ func (r *ClickHouseRepository) Create(log *domain.AuditLog) error { } query := ` - INSERT INTO audit_logs (event_id, timestamp, user_id, event_type, status, ip_address, user_agent, device_id, details) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO audit_logs (event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` return r.conn.Exec(ctx, query, log.EventID, log.Timestamp, log.UserID, + log.TenantID, log.EventType, log.Status, log.IPAddress, @@ -103,18 +106,25 @@ func (r *ClickHouseRepository) Create(log *domain.AuditLog) error { ) } -func (r *ClickHouseRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor) ([]domain.AuditLog, error) { +func (r *ClickHouseRepository) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor, tenantID string) ([]domain.AuditLog, error) { if limit <= 0 { limit = 50 } query := ` - SELECT event_id, timestamp, user_id, event_type, status, ip_address, user_agent, device_id, details + SELECT event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details FROM audit_logs + WHERE 1=1 ` - args := make([]any, 0, 4) + args := make([]any, 0, 5) + + if tenantID != "" { + query += " AND tenant_id = ?" + args = append(args, tenantID) + } + if cursor != nil { query += ` - WHERE (timestamp < ?) OR (timestamp = ? AND event_id < ?) + AND ((timestamp < ?) OR (timestamp = ? AND event_id < ?)) ` args = append(args, cursor.Timestamp, cursor.Timestamp, cursor.EventID) } @@ -137,6 +147,7 @@ func (r *ClickHouseRepository) FindPage(ctx context.Context, limit int, cursor * &log.EventID, &log.Timestamp, &log.UserID, + &log.TenantID, &log.EventType, &log.Status, &log.IPAddress, @@ -156,7 +167,7 @@ func (r *ClickHouseRepository) FindByUserAndEvents(ctx context.Context, userID s limit = 100 } query := ` - SELECT event_id, timestamp, user_id, event_type, status, ip_address, user_agent, device_id, details + SELECT event_id, timestamp, user_id, tenant_id, event_type, status, ip_address, user_agent, device_id, details FROM audit_logs WHERE user_id = ? AND event_type IN (?) ORDER BY timestamp DESC @@ -175,6 +186,7 @@ func (r *ClickHouseRepository) FindByUserAndEvents(ctx context.Context, userID s &log.EventID, &log.Timestamp, &log.UserID, + &log.TenantID, &log.EventType, &log.Status, &log.IPAddress,