diff --git a/Cargo.lock b/Cargo.lock index 61e5625..b338c23 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1191,7 +1191,7 @@ dependencies = [ [[package]] name = "kcserver" -version = "0.1.62" +version = "0.1.63" dependencies = [ "anyhow", "async-channel", diff --git a/crates/kcserver/Cargo.toml b/crates/kcserver/Cargo.toml index ec7e25f..3ee274a 100644 --- a/crates/kcserver/Cargo.toml +++ b/crates/kcserver/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "kcserver" -version = "0.1.62" +version = "0.1.63" rust-version.workspace = true edition.workspace = true license.workspace = true @@ -76,5 +76,6 @@ features = [ "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Threading", + "Win32_System_Diagnostics_ToolHelp", "Win32_UI_WindowsAndMessaging" ] diff --git a/crates/kcserver/src/lib.rs b/crates/kcserver/src/lib.rs index a676a77..3b660c0 100644 --- a/crates/kcserver/src/lib.rs +++ b/crates/kcserver/src/lib.rs @@ -18,6 +18,9 @@ pub mod jupyter_messages; pub mod kernel_connection; pub mod kernel_session; pub mod kernel_state; +#[cfg(target_os = "linux")] +pub mod proc_stat; +pub mod process_tree; pub mod registration_file; pub mod registration_socket; pub mod resource_monitor; diff --git a/crates/kcserver/src/main.rs b/crates/kcserver/src/main.rs index e02bac7..bd03e0b 100644 --- a/crates/kcserver/src/main.rs +++ b/crates/kcserver/src/main.rs @@ -26,6 +26,9 @@ mod jupyter_messages; mod kernel_connection; mod kernel_session; mod kernel_state; +#[cfg(target_os = "linux")] +mod proc_stat; +mod process_tree; mod registration_file; mod registration_socket; mod resource_monitor; @@ -378,7 +381,7 @@ async fn main() { | \ / | |/ |/ | / |/ \ / \_/ | |/ | \_/\_/|_/|__/|__/|_/\___/| |_/\__/ |_/|__/ A Jupyter Kernel supervisor. Version {}. - Copyright (c) 2025, Posit Software PBC. All rights reserved. + Copyright (c) 2026, Posit Software PBC. All rights reserved. "#, env!("CARGO_PKG_VERSION") ); diff --git a/crates/kcserver/src/proc_stat.rs b/crates/kcserver/src/proc_stat.rs new file mode 100644 index 0000000..2dacc69 --- /dev/null +++ b/crates/kcserver/src/proc_stat.rs @@ -0,0 +1,145 @@ +// +// proc_stat.rs +// +// Copyright (C) 2026 Posit Software, PBC. All rights reserved. +// Licensed under the Elastic License 2.0. See LICENSE.txt for license information. +// + +//! Linux-specific utilities for parsing /proc filesystem. +//! +//! This module provides shared utilities for parsing /proc/[pid]/stat files, +//! used by both the process tree enumeration and CPU tracking code. + +#![cfg(target_os = "linux")] + +use std::fs; + +/// Parsed fields from /proc/[pid]/stat +#[derive(Debug, Clone)] +pub struct ProcStat { + /// Parent process ID (field 4, index 1 after comm) + pub ppid: u32, + /// Process group ID (field 5, index 2 after comm) + pub pgid: u32, + /// User mode CPU time in jiffies (field 14, index 11 after comm) + pub utime: u64, + /// Kernel mode CPU time in jiffies (field 15, index 12 after comm) + pub stime: u64, +} + +impl ProcStat { + /// Total CPU time (utime + stime) + pub fn cpu_time(&self) -> u64 { + self.utime + self.stime + } +} + +/// Parse /proc/[pid]/stat to extract process information. +/// +/// The stat file format is: `pid (comm) state ppid pgrp session tty_nr tpgid flags +/// minflt cminflt majflt cmajflt utime stime cutime cstime ...` +/// +/// Note: `comm` can contain spaces and parentheses, so we find the last ')' to +/// reliably parse the remaining fields. +pub fn parse_proc_stat(pid: u32) -> Option { + let stat_path = format!("/proc/{}/stat", pid); + let stat_content = fs::read_to_string(stat_path).ok()?; + + // comm can contain spaces and parens, so find the last ')' + let last_paren = stat_content.rfind(')')?; + let fields_after_comm = stat_content.get(last_paren + 2..)?; // Skip ") " + let fields: Vec<&str> = fields_after_comm.split_whitespace().collect(); + + // fields[0] = state, fields[1] = ppid, fields[2] = pgrp, ... + // fields[11] = utime, fields[12] = stime + if fields.len() < 13 { + return None; + } + + Some(ProcStat { + ppid: fields[1].parse().ok()?, + pgid: fields[2].parse().ok()?, + utime: fields[11].parse().ok()?, + stime: fields[12].parse().ok()?, + }) +} + +/// Read total CPU time from /proc/stat (sum of all jiffies across all CPUs). +/// +/// The first line of /proc/stat is: +/// `cpu user nice system idle iowait irq softirq steal guest guest_nice` +/// +/// We sum all these values to get total CPU time. +pub fn read_total_cpu_time() -> u64 { + let Ok(content) = fs::read_to_string("/proc/stat") else { + return 0; + }; + + let Some(cpu_line) = content.lines().next() else { + return 0; + }; + + if !cpu_line.starts_with("cpu ") { + return 0; + } + + // Sum all the values (skip "cpu" label) + cpu_line + .split_whitespace() + .skip(1) + .filter_map(|s| s.parse::().ok()) + .sum() +} + +/// Count the number of CPUs by counting cpu[N] lines in /proc/stat. +pub fn count_cpus() -> usize { + let Ok(content) = fs::read_to_string("/proc/stat") else { + return 1; + }; + + // Count lines starting with "cpu" followed by a digit (cpu0, cpu1, etc.) + content + .lines() + .filter(|line| { + line.starts_with("cpu") + && line + .chars() + .nth(3) + .map(|c| c.is_ascii_digit()) + .unwrap_or(false) + }) + .count() + .max(1) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_read_total_cpu_time() { + // Should return a non-zero value on a real Linux system + let cpu_time = read_total_cpu_time(); + // Just verify it doesn't panic and returns something reasonable + assert!(cpu_time > 0 || cfg!(not(target_os = "linux"))); + } + + #[test] + fn test_count_cpus() { + let cpus = count_cpus(); + assert!(cpus >= 1); + } + + #[test] + fn test_parse_proc_stat_current_process() { + // Parse our own process's stat + let pid = std::process::id(); + if let Some(stat) = parse_proc_stat(pid) { + // Our parent PID should be non-zero + assert!(stat.ppid > 0); + // PGID should be set + assert!(stat.pgid > 0); + } + // It's OK if this fails on non-Linux or in restricted environments + } +} diff --git a/crates/kcserver/src/process_tree.rs b/crates/kcserver/src/process_tree.rs new file mode 100644 index 0000000..52d305a --- /dev/null +++ b/crates/kcserver/src/process_tree.rs @@ -0,0 +1,439 @@ +// +// process_tree.rs +// +// Copyright (C) 2026 Posit Software, PBC. All rights reserved. +// Licensed under the Elastic License 2.0. See LICENSE.txt for license information. +// + +//! OS-specific efficient process tree enumeration. +//! +//! This module provides efficient ways to enumerate child processes of a given PID +//! without scanning the entire process table on the system. + +use std::collections::HashSet; + +/// How often to refresh the cache (in ticks) +/// This is used on platforms that cache process trees (Linux, Windows) +/// to limit the frequency of full process table scans. +#[cfg(any(target_os = "linux", target_os = "windows"))] +const CACHE_REFRESH_INTERVAL: u32 = 5; + +/// Get all descendant PIDs of the given root PID. +/// +/// This function returns a set containing the root PID and all its descendants. +/// The implementation is OS-specific for efficiency: +/// - macOS: Uses `proc_listchildpids()` to directly query child processes +/// - Linux: Enumerates only processes with the same PGID as the root +/// - Windows: Uses cached process tree with periodic full scans +pub fn get_process_tree(root_pid: u32) -> HashSet { + #[cfg(target_os = "macos")] + { + macos::get_process_tree(root_pid) + } + + #[cfg(target_os = "linux")] + { + linux::get_process_tree(root_pid) + } + + #[cfg(target_os = "windows")] + { + windows::get_process_tree(root_pid) + } +} + +/// Notify the process tree cache that a tick has occurred. +/// This triggers periodic cache refresh on platforms that use caching. +#[allow(unused_variables)] +pub fn tick_process_cache(root_pid: u32) { + #[cfg(target_os = "linux")] + { + linux::tick_process_cache(root_pid); + } + + #[cfg(target_os = "windows")] + { + windows::tick_process_cache(root_pid); + } +} + +/// Clear the process tree cache for a given root PID. +/// Called when a kernel session is terminated. +#[allow(unused_variables)] +pub fn clear_process_cache(root_pid: u32) { + #[cfg(target_os = "linux")] + { + linux::clear_process_cache(root_pid); + } + + #[cfg(target_os = "windows")] + { + windows::clear_process_cache(root_pid); + } +} + +// ============================================================================= +// macOS implementation using proc_listchildpids() +// ============================================================================= + +#[cfg(target_os = "macos")] +#[allow(unsafe_code)] +mod macos { + use std::collections::HashSet; + + // FFI bindings for libproc + #[link(name = "proc", kind = "dylib")] + extern "C" { + fn proc_listchildpids( + ppid: libc::c_int, + buffer: *mut libc::c_int, + buffersize: libc::c_int, + ) -> libc::c_int; + } + + /// Get child PIDs of a process using proc_listchildpids() + fn get_child_pids(pid: u32) -> Vec { + // First call with null buffer to get the required buffer size in bytes + // SAFETY: proc_listchildpids is a well-documented macOS API that safely handles + // null buffer pointers by returning the required buffer size in bytes. + let bytes_needed = + unsafe { proc_listchildpids(pid as libc::c_int, std::ptr::null_mut(), 0) }; + + if bytes_needed <= 0 { + return Vec::new(); + } + + // Allocate buffer for PIDs (convert bytes to element count) + let count = bytes_needed as usize / size_of::(); + let mut buffer: Vec = vec![0; count]; + + // SAFETY: We've allocated a buffer of sufficient size (as returned by the first call). + // proc_listchildpids writes at most buffersize bytes to the buffer. + let result = + unsafe { proc_listchildpids(pid as libc::c_int, buffer.as_mut_ptr(), bytes_needed) }; + + if result <= 0 { + return Vec::new(); + } + + // Convert to u32 and filter out any zeros + let num_pids = result as usize / size_of::(); + buffer + .into_iter() + .take(num_pids) + .filter(|&pid| pid > 0) + .map(|pid| pid as u32) + .collect() + } + + pub fn get_process_tree(root_pid: u32) -> HashSet { + let mut visited = HashSet::new(); + let mut to_visit = vec![root_pid]; + + while let Some(pid) = to_visit.pop() { + if !visited.insert(pid) { + continue; + } + + // Get children of this process + let children = get_child_pids(pid); + for child in children { + if !visited.contains(&child) { + to_visit.push(child); + } + } + } + + visited + } +} + +// ============================================================================= +// Linux implementation using PGID filtering with caching +// ============================================================================= + +#[cfg(target_os = "linux")] +mod linux { + use std::collections::{HashMap, HashSet}; + use std::fs; + use std::sync::Mutex; + + use once_cell::sync::Lazy; + + use super::CACHE_REFRESH_INTERVAL; + use crate::proc_stat; + + /// Global cache shared across all kernels + struct GlobalCache { + /// Parent map from the last /proc scan (pid -> ppid) + parent_map: HashMap, + /// PGID map from the last /proc scan (pid -> pgid) + pgid_map: HashMap, + /// Per-kernel cached process trees + kernel_caches: HashMap>, + /// Global tick counter (all kernels share the same clock) + tick_count: u32, + } + + static GLOBAL_CACHE: Lazy> = Lazy::new(|| { + Mutex::new(GlobalCache { + parent_map: HashMap::new(), + pgid_map: HashMap::new(), + kernel_caches: HashMap::new(), + tick_count: 0, + }) + }); + + /// Scan /proc once and build parent_map and pgid_map for all processes + fn scan_proc() -> (HashMap, HashMap) { + let mut parent_map = HashMap::new(); + let mut pgid_map = HashMap::new(); + + let proc_dir = match fs::read_dir("/proc") { + Ok(dir) => dir, + Err(_) => return (parent_map, pgid_map), + }; + + for entry in proc_dir.flatten() { + let name = entry.file_name(); + let name_str = name.to_string_lossy(); + + // Only look at numeric directories (PIDs) + if let Ok(pid) = name_str.parse::() { + if let Some(stat) = proc_stat::parse_proc_stat(pid) { + parent_map.insert(pid, stat.ppid); + pgid_map.insert(pid, stat.pgid); + } + } + } + + (parent_map, pgid_map) + } + + /// Build a process tree for a root PID using the cached parent/pgid maps + fn build_tree_from_cache( + root_pid: u32, + parent_map: &HashMap, + pgid_map: &HashMap, + ) -> HashSet { + let mut tree = HashSet::new(); + tree.insert(root_pid); + + // Get the PGID of the root process + let root_pgid = match pgid_map.get(&root_pid) { + Some(&pgid) => pgid, + None => return tree, // Process doesn't exist + }; + + // Collect PIDs with the same PGID + let same_pgid_pids: Vec = pgid_map + .iter() + .filter(|(_, &pgid)| pgid == root_pgid) + .map(|(&pid, _)| pid) + .collect(); + + // Check if each process is a descendant of root_pid + for pid in same_pgid_pids { + if is_descendant_of(pid, root_pid, parent_map) { + tree.insert(pid); + } + } + + tree + } + + /// Check if `pid` is a descendant of `ancestor` using the parent map + fn is_descendant_of(pid: u32, ancestor: u32, parent_map: &HashMap) -> bool { + if pid == ancestor { + return true; + } + + let mut current = pid; + let mut visited = HashSet::new(); + + while let Some(&parent) = parent_map.get(¤t) { + if parent == ancestor { + return true; + } + if parent == 0 || parent == 1 || !visited.insert(current) { + // Reached init or a cycle, not a descendant + return false; + } + current = parent; + } + + false + } + + pub fn get_process_tree(root_pid: u32) -> HashSet { + let mut cache = GLOBAL_CACHE.lock().unwrap(); + + // Check if we have a cached result + if let Some(pids) = cache.kernel_caches.get(&root_pid) { + return pids.clone(); + } + + // No cache entry - need to build one + // If we have no /proc data yet, do an initial scan + if cache.parent_map.is_empty() { + let (parent_map, pgid_map) = scan_proc(); + cache.parent_map = parent_map; + cache.pgid_map = pgid_map; + } + + let pids = build_tree_from_cache(root_pid, &cache.parent_map, &cache.pgid_map); + cache.kernel_caches.insert(root_pid, pids.clone()); + pids + } + + pub fn tick_process_cache(_root_pid: u32) { + let mut cache = GLOBAL_CACHE.lock().unwrap(); + + cache.tick_count += 1; + + if cache.tick_count >= CACHE_REFRESH_INTERVAL { + // Do ONE /proc scan for all kernels + let (parent_map, pgid_map) = scan_proc(); + cache.parent_map = parent_map; + cache.pgid_map = pgid_map; + + // Rebuild all kernel caches + let root_pids: Vec = cache.kernel_caches.keys().cloned().collect(); + for root in root_pids { + let pids = build_tree_from_cache(root, &cache.parent_map, &cache.pgid_map); + cache.kernel_caches.insert(root, pids); + } + + cache.tick_count = 0; + } + } + + pub fn clear_process_cache(root_pid: u32) { + let mut cache = GLOBAL_CACHE.lock().unwrap(); + cache.kernel_caches.remove(&root_pid); + } +} + +// ============================================================================= +// Windows implementation with cached process tree +// ============================================================================= + +#[cfg(target_os = "windows")] +mod windows { + use std::collections::{HashMap, HashSet}; + use std::mem::size_of; + use std::sync::Mutex; + + use once_cell::sync::Lazy; + use windows::Win32::Foundation::CloseHandle; + + use super::CACHE_REFRESH_INTERVAL; + use windows::Win32::System::Diagnostics::ToolHelp::{ + CreateToolhelp32Snapshot, Process32First, Process32Next, PROCESSENTRY32, TH32CS_SNAPPROCESS, + }; + + /// Cache entry for a process tree + struct CacheEntry { + pids: HashSet, + tick_count: u32, + } + + /// Global cache for process trees, keyed by root PID + static PROCESS_CACHE: Lazy>> = + Lazy::new(|| Mutex::new(HashMap::new())); + + /// Build the complete process tree by scanning all processes + #[allow(unsafe_code)] + fn scan_process_tree(root_pid: u32) -> HashSet { + let mut tree = HashSet::new(); + tree.insert(root_pid); + + // Create a snapshot of all processes + // SAFETY: CreateToolhelp32Snapshot is a well-documented Windows API. + // TH32CS_SNAPPROCESS with 0 requests a snapshot of all processes. + let snapshot = unsafe { CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0) }; + let snapshot = match snapshot { + Ok(handle) => handle, + Err(_) => return tree, + }; + + // Build a parent-child map + let mut parent_map: HashMap = HashMap::new(); + + let mut entry = PROCESSENTRY32 { + dwSize: size_of::() as u32, + ..Default::default() + }; + + // SAFETY: We have a valid snapshot handle and properly initialized PROCESSENTRY32 + // with dwSize set. Process32First/Next read process info into the entry struct. + // CloseHandle releases the snapshot handle when done. + unsafe { + if Process32First(snapshot, &mut entry).is_ok() { + loop { + let pid = entry.th32ProcessID; + let ppid = entry.th32ParentProcessID; + parent_map.insert(pid, ppid); + + if Process32Next(snapshot, &mut entry).is_err() { + break; + } + } + } + let _ = CloseHandle(snapshot); + } + + // Find all descendants using BFS + let mut to_visit = vec![root_pid]; + while let Some(pid) = to_visit.pop() { + // Find all children of this PID + for (&child_pid, &parent_pid) in &parent_map { + if parent_pid == pid && !tree.contains(&child_pid) { + tree.insert(child_pid); + to_visit.push(child_pid); + } + } + } + + tree + } + + pub fn get_process_tree(root_pid: u32) -> HashSet { + let mut cache = PROCESS_CACHE.lock().unwrap(); + + if let Some(entry) = cache.get(&root_pid) { + // Return cached result + return entry.pids.clone(); + } + + // No cache entry, do a full scan + let pids = scan_process_tree(root_pid); + cache.insert( + root_pid, + CacheEntry { + pids: pids.clone(), + tick_count: 0, + }, + ); + pids + } + + pub fn tick_process_cache(root_pid: u32) { + let mut cache = PROCESS_CACHE.lock().unwrap(); + + if let Some(entry) = cache.get_mut(&root_pid) { + entry.tick_count += 1; + + if entry.tick_count >= CACHE_REFRESH_INTERVAL { + // Time to refresh the cache + entry.pids = scan_process_tree(root_pid); + entry.tick_count = 0; + } + } + } + + pub fn clear_process_cache(root_pid: u32) { + let mut cache = PROCESS_CACHE.lock().unwrap(); + cache.remove(&root_pid); + } +} diff --git a/crates/kcserver/src/resource_monitor.rs b/crates/kcserver/src/resource_monitor.rs index e5034bb..222334a 100644 --- a/crates/kcserver/src/resource_monitor.rs +++ b/crates/kcserver/src/resource_monitor.rs @@ -8,23 +8,111 @@ //! Global resource usage monitor for all kernel sessions. -use std::collections::HashSet; +#[cfg(target_os = "linux")] +use std::collections::HashMap; use std::sync::{Arc, RwLock}; use std::time::Duration; use kcshared::kernel_message::{KernelMessage, ResourceUpdate}; use kcshared::websocket_message::WebsocketMessage; -use sysinfo::{Pid, ProcessesToUpdate, System}; +use sysinfo::{Pid, ProcessRefreshKind, ProcessesToUpdate, System}; use tokio::sync::mpsc; use tokio::time::MissedTickBehavior; use crate::kernel_session::KernelSession; +use crate::process_tree; -/// Metrics collected for a process tree +/// CPU usage collected for a process tree (used on non-Linux platforms) +#[cfg(not(target_os = "linux"))] struct ProcessMetrics { cpu_percent: u64, - memory_bytes: u64, - thread_count: u64, +} + +/// Tracks CPU times for computing CPU usage percentage on Linux. +/// +/// sysinfo doesn't compute CPU usage when using ProcessesToUpdate::Some(), +/// so we track CPU times ourselves and compute the percentage manually. +#[cfg(target_os = "linux")] +struct CpuTracker { + /// Previous CPU times per process: pid -> (utime + stime) + prev_times: HashMap, + /// Previous total system CPU time (sum of all CPU jiffies) + prev_total_cpu: u64, +} + +#[cfg(target_os = "linux")] +impl CpuTracker { + fn new() -> Self { + Self { + prev_times: HashMap::new(), + prev_total_cpu: 0, + } + } + + /// Compute CPU usage percentage for a set of processes. + /// Returns the total CPU percentage across all PIDs in the set. + /// + /// # Arguments + /// * `pids` - The set of process IDs to compute CPU usage for + /// * `current_total_cpu` - The current total system CPU time (should be read once per monitoring tick) + fn compute_cpu_usage( + &mut self, + pids: &std::collections::HashSet, + current_total_cpu: u64, + ) -> f32 { + use crate::proc_stat; + + let total_cpu_delta = current_total_cpu.saturating_sub(self.prev_total_cpu); + + // If no time has passed (or first call), we can't compute usage + if total_cpu_delta == 0 || self.prev_total_cpu == 0 { + // Still update the tracking for next time + for &pid in pids { + if let Some(stat) = proc_stat::parse_proc_stat(pid) { + self.prev_times.insert(pid, stat.cpu_time()); + } + } + return 0.0; + } + + let mut total_process_cpu_delta: u64 = 0; + let mut new_times = HashMap::new(); + + for &pid in pids { + if let Some(stat) = proc_stat::parse_proc_stat(pid) { + let current_time = stat.cpu_time(); + new_times.insert(pid, current_time); + + if let Some(&prev_time) = self.prev_times.get(&pid) { + total_process_cpu_delta += current_time.saturating_sub(prev_time); + } + // If no previous time, this is a new process - contributes 0 to delta + } + } + + // Update prev_times with new values (don't replace entirely, as there may be + // entries for other sessions that we need to preserve) + for (pid, time) in new_times { + self.prev_times.insert(pid, time); + } + + // Compute percentage: (process_delta / total_delta) * 100 * num_cpus + // The result is scaled to 100% per CPU core (like sysinfo does) + let num_cpus = proc_stat::count_cpus() as f32; + (total_process_cpu_delta as f32 / total_cpu_delta as f32) * 100.0 * num_cpus + } + + /// Update the previous total CPU time after all sessions have been processed. + /// This should be called once per monitoring tick, after all calls to compute_cpu_usage. + fn update_prev_total_cpu(&mut self, current_total_cpu: u64) { + self.prev_total_cpu = current_total_cpu; + } + + /// Remove stale entries from prev_times that are no longer tracked. + /// Call this periodically with the set of all currently tracked PIDs across all sessions. + fn cleanup_stale_entries(&mut self, active_pids: &std::collections::HashSet) { + self.prev_times.retain(|pid, _| active_pids.contains(pid)); + } } /// Start the global resource monitor. @@ -59,6 +147,11 @@ pub fn start_global_resource_monitor( // Create a System instance and keep it alive for accurate CPU measurements let mut system = System::new(); + // On Linux, use our own CPU tracker since sysinfo doesn't compute CPU + // usage when using ProcessesToUpdate::Some() + #[cfg(target_os = "linux")] + let mut cpu_tracker = CpuTracker::new(); + // Track current interval let mut current_sample_interval_ms = sample_interval_ms; @@ -71,6 +164,8 @@ pub fn start_global_resource_monitor( let mut interval = tokio::time::interval(effective_interval); interval.set_missed_tick_behavior(MissedTickBehavior::Delay); + // Prime CPU usage statistics (sysinfo needs an initial refresh) + system.refresh_cpu_usage(); // Consume the first tick immediately interval.tick().await; @@ -82,15 +177,6 @@ pub fn start_global_resource_monitor( continue; } - // Refresh all process data once - system.refresh_processes(ProcessesToUpdate::All); - - // Get the current timestamp - let timestamp = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|d| d.as_millis() as u64) - .unwrap_or(0); - // Clone session data we need while holding the lock briefly // This avoids holding the std::sync::RwLock across await points let session_data: Vec<_> = { @@ -115,6 +201,36 @@ pub fn start_global_resource_monitor( }; // Lock is now released + // Check if any clients are connected before doing any work + let mut has_connected_clients = false; + for (_, state, _) in &session_data { + let state_guard = state.read().await; + if state_guard.connected { + has_connected_clients = true; + break; + } + } + + // Skip all work if no clients are connected + if !has_connected_clients { + continue; + } + + // Get the current timestamp + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_millis() as u64) + .unwrap_or(0); + + // On Linux, read the system CPU time ONCE for all sessions in this tick + // This prevents artificial spikes in later sessions due to near-zero time deltas + #[cfg(target_os = "linux")] + let current_total_cpu = crate::proc_stat::read_total_cpu_time(); + + // Track all PIDs across all sessions for cleanup (Linux only) + #[cfg(target_os = "linux")] + let mut all_tracked_pids = std::collections::HashSet::new(); + for (session_id, state, ws_json_tx) in session_data { // Read the kernel state (tokio::sync::RwLock) let state_guard = state.read().await; @@ -135,14 +251,70 @@ pub fn start_global_resource_monitor( // Release the state lock before collecting metrics drop(state_guard); + // Get the process tree using OS-specific efficient enumeration + let tree_pids = process_tree::get_process_tree(pid); + + // Track all PIDs for cleanup (Linux only) + #[cfg(target_os = "linux")] + all_tracked_pids.extend(&tree_pids); + + // Log trace info about the process tree, including a list of the + // PIDs being monitored + log::trace!( + "[session {}] Monitoring resource usage for process tree with root PID {}: {} processes; {:?}", + session_id, + pid, + tree_pids.len(), + tree_pids + ); + + // Refresh only the processes we need for memory info + let pids_to_refresh: Vec = + tree_pids.iter().map(|&p| Pid::from_u32(p)).collect(); + + // For non-Linux platforms, refresh CPU and memory + #[cfg(not(target_os = "linux"))] + let refresh_kind = ProcessRefreshKind::new() + .with_cpu() + .with_memory(); + + // We don't refresh CPU on Linux here, because there's a bug in the + // sysinfo crate that causes CPU usage to be reported as 0.0 + // when using ProcessesToUpdate::Some(). Instead, we compute CPU + // usage ourselves using /proc data. + #[cfg(target_os = "linux")] + let refresh_kind = ProcessRefreshKind::new() + .with_memory(); + + system.refresh_processes_specifics( + ProcessesToUpdate::Some(&pids_to_refresh), + refresh_kind, + ); + + // Update the process cache tick counter (Windows only) + process_tree::tick_process_cache(pid); + // Collect metrics for this kernel's process tree - let metrics = collect_tree_metrics(&system, pid); + // On Linux, compute CPU ourselves; on other platforms use sysinfo + #[cfg(target_os = "linux")] + let cpu_percent = { + let cpu = cpu_tracker.compute_cpu_usage(&tree_pids, current_total_cpu); + cpu.round() as u64 + }; + + #[cfg(not(target_os = "linux"))] + let cpu_percent = { + let metrics = collect_tree_metrics(&system, &tree_pids); + metrics.cpu_percent + }; + + let (memory_bytes, thread_count) = collect_memory_and_threads(&system, &tree_pids); // Create the resource update message let update = ResourceUpdate { - cpu_percent: metrics.cpu_percent, - memory_bytes: metrics.memory_bytes, - thread_count: metrics.thread_count, + cpu_percent, + memory_bytes, + thread_count, sampling_period_ms: current_sample_interval_ms, timestamp, }; @@ -152,9 +324,9 @@ pub fn start_global_resource_monitor( let mut state_guard = state.write().await; state_guard.resource_usage = Some(kallichore_api::models::ResourceUsage { - cpu_percent: metrics.cpu_percent as i64, - memory_bytes: metrics.memory_bytes as i64, - thread_count: metrics.thread_count as i64, + cpu_percent: cpu_percent as i64, + memory_bytes: memory_bytes as i64, + thread_count: thread_count as i64, sampling_period_ms: current_sample_interval_ms as i64, timestamp: timestamp as i64, }); @@ -171,6 +343,15 @@ pub fn start_global_resource_monitor( ); } } + + // On Linux, update the tracker's previous total CPU time after processing all sessions + // This ensures all sessions in this tick use the same time delta + #[cfg(target_os = "linux")] + { + cpu_tracker.update_prev_total_cpu(current_total_cpu); + // Clean up stale entries from dead processes to prevent memory leak + cpu_tracker.cleanup_stale_entries(&all_tracked_pids); + } } Some(new_interval_ms) = interval_update_rx.recv() => { log::info!( @@ -203,50 +384,32 @@ pub fn start_global_resource_monitor( }); } -/// Collect metrics for a process and all its descendants. +/// Collect memory and thread count for a set of processes. /// -/// This function walks the process tree starting from the given root PID, -/// summing CPU usage, memory, and thread counts for the entire tree. +/// This function sums memory and thread counts for all processes in the provided set of PIDs. +/// CPU usage is handled separately on Linux due to sysinfo limitations. /// /// # Arguments /// -/// * `system` - The sysinfo System instance (must have been refreshed) -/// * `root_pid` - The root process ID to start from +/// * `system` - The sysinfo System instance (must have been refreshed for the given PIDs) +/// * `pids` - Set of process IDs to collect metrics for /// /// # Returns /// -/// Aggregated metrics for the process tree -fn collect_tree_metrics(system: &System, root_pid: u32) -> ProcessMetrics { - let pid = Pid::from_u32(root_pid); - - let mut total_cpu = 0.0f32; +/// Tuple of (memory_bytes, thread_count) +fn collect_memory_and_threads( + system: &System, + pids: &std::collections::HashSet, +) -> (u64, u64) { let mut total_memory = 0u64; let mut total_threads = 0u64; - // Collect all PIDs in process tree using BFS - let mut pids_to_check = vec![pid]; - let mut visited = HashSet::new(); - - while let Some(check_pid) = pids_to_check.pop() { - if !visited.insert(check_pid) { - continue; // Already visited - } - - // Find children by checking parent_pid - for (child_pid, proc) in system.processes() { - if proc.parent() == Some(check_pid) && !visited.contains(child_pid) { - pids_to_check.push(*child_pid); - } - } - } - - // Sum metrics for all processes in tree (using cached data) - for pid in &visited { - if let Some(proc) = system.process(*pid) { - total_cpu += proc.cpu_usage(); + for &pid in pids { + let sysinfo_pid = Pid::from_u32(pid); + if let Some(proc) = system.process(sysinfo_pid) { total_memory += proc.memory(); // Thread count: use tasks() if available, otherwise assume 1 thread - #[cfg(any(target_os = "linux"))] + #[cfg(target_os = "linux")] { if let Some(tasks) = proc.tasks() { total_threads += tasks.len() as u64; @@ -254,7 +417,7 @@ fn collect_tree_metrics(system: &System, root_pid: u32) -> ProcessMetrics { total_threads += 1; } } - #[cfg(not(any(target_os = "linux")))] + #[cfg(not(target_os = "linux"))] { // On macOS and Windows, tasks() is not available // Assume 1 thread per process as a baseline @@ -263,9 +426,35 @@ fn collect_tree_metrics(system: &System, root_pid: u32) -> ProcessMetrics { } } + (total_memory, total_threads) +} + +/// Collect CPU metrics for a set of processes. +/// +/// This function sums CPU usage for all processes in the provided set of PIDs. +/// Memory and thread counts are collected separately by `collect_memory_and_threads`. +/// +/// # Arguments +/// +/// * `system` - The sysinfo System instance (must have been refreshed for the given PIDs) +/// * `pids` - Set of process IDs to collect metrics for +/// +/// # Returns +/// +/// Aggregated CPU metrics for the processes +#[cfg(not(target_os = "linux"))] +fn collect_tree_metrics(system: &System, pids: &std::collections::HashSet) -> ProcessMetrics { + let mut total_cpu = 0.0f32; + + // Sum CPU for all processes in tree (using cached data) + for &pid in pids { + let sysinfo_pid = Pid::from_u32(pid); + if let Some(proc) = system.process(sysinfo_pid) { + total_cpu += proc.cpu_usage(); + } + } + ProcessMetrics { cpu_percent: total_cpu.round() as u64, - memory_bytes: total_memory, - thread_count: total_threads, } } diff --git a/crates/kcserver/src/server.rs b/crates/kcserver/src/server.rs index 186262c..25ceae5 100644 --- a/crates/kcserver/src/server.rs +++ b/crates/kcserver/src/server.rs @@ -10,6 +10,7 @@ #![allow(unused_imports)] +use crate::process_tree; use crate::resource_monitor; use crate::websocket_service::ApiWebsocketExt; use anyhow::anyhow; @@ -1362,10 +1363,10 @@ where } }; - // Ensure the session is not running - let status = { + // Ensure the session is not running and get the process ID for cache cleanup + let (status, process_id) = { let state = kernel_session.state.read().await; - state.status.clone() + (state.status.clone(), state.process_id) }; if status != models::Status::Exited { let error = KSError::SessionRunning(session_id.clone()); @@ -1375,6 +1376,11 @@ where )); } + // Clear process tree cache for this session's kernel PID + if let Some(pid) = process_id { + process_tree::clear_process_cache(pid); + } + // Ensure we get a write lock on the kernel sessions for the duration of // this function let mut sessions = self.kernel_sessions.write().unwrap(); diff --git a/crates/kcserver/tests/resource_usage_test.rs b/crates/kcserver/tests/resource_usage_test.rs index 8c7c080..95e498a 100644 --- a/crates/kcserver/tests/resource_usage_test.rs +++ b/crates/kcserver/tests/resource_usage_test.rs @@ -321,3 +321,238 @@ async fn test_resource_usage_websocket_messages() { } } } + +/// Verify that multiple concurrent kernel sessions get accurate CPU usage measurements. +#[tokio::test] +async fn test_multi_session_cpu_tracking() { + let test_result = tokio::time::timeout(Duration::from_secs(45), async { + let python_cmd = if let Some(cmd) = get_python_executable().await { + cmd + } else { + println!("Skipping test: No Python executable found"); + return; + }; + + if !is_ipykernel_available().await { + println!("Skipping test: ipykernel not available for {}", python_cmd); + return; + } + + let server = TestServer::start().await; + let client = server.create_client().await; + + // Create two kernel sessions + let session_id_1 = format!("multi-session-1-{}", Uuid::new_v4()); + let session_id_2 = format!("multi-session-2-{}", Uuid::new_v4()); + + let new_session_1 = create_test_session(session_id_1.clone(), &python_cmd); + let new_session_2 = create_test_session(session_id_2.clone(), &python_cmd); + + println!("Creating two kernel sessions..."); + let _created_1 = create_session_with_client(&client, new_session_1).await; + let _created_2 = create_session_with_client(&client, new_session_2).await; + + // Start both sessions + println!("Starting session 1..."); + let start_response_1 = client + .start_session(session_id_1.clone()) + .await + .expect("Failed to start session 1"); + + match &start_response_1 { + kallichore_api::StartSessionResponse::Started(_) => { + println!("Session 1 started successfully"); + } + _ => { + println!("Session 1 failed to start, skipping test"); + return; + } + } + + println!("Starting session 2..."); + let start_response_2 = client + .start_session(session_id_2.clone()) + .await + .expect("Failed to start session 2"); + + match &start_response_2 { + kallichore_api::StartSessionResponse::Started(_) => { + println!("Session 2 started successfully"); + } + _ => { + println!("Session 2 failed to start, skipping test"); + return; + } + } + + // Wait for both kernels to fully start + tokio::time::sleep(Duration::from_millis(2000)).await; + + // Create WebSocket connections for both sessions to trigger resource monitoring + let ws_url_1 = format!( + "ws://localhost:{}/sessions/{}/channels", + server.port(), + session_id_1 + ); + let ws_url_2 = format!( + "ws://localhost:{}/sessions/{}/channels", + server.port(), + session_id_2 + ); + + let mut comm_1 = CommunicationChannel::create_websocket(&ws_url_1) + .await + .expect("Failed to create websocket for session 1"); + + let mut comm_2 = CommunicationChannel::create_websocket(&ws_url_2) + .await + .expect("Failed to create websocket for session 2"); + + // Wait for several sampling periods to collect resource usage data + println!("Waiting for resource usage sampling across both sessions..."); + tokio::time::sleep(Duration::from_millis(3500)).await; + + // Collect resource usage updates from both sessions + let mut session_1_updates = Vec::new(); + let mut session_2_updates = Vec::new(); + + println!("Collecting resource usage updates..."); + let collection_start = std::time::Instant::now(); + + while collection_start.elapsed() < Duration::from_secs(6) { + // Check session 1 + match tokio::time::timeout(Duration::from_millis(100), comm_1.receive_message()).await { + Ok(Ok(Some(message_text))) => { + if let Ok(ws_msg) = serde_json::from_str::(&message_text) { + if let WebsocketMessage::Kernel(KernelMessage::ResourceUsage(update)) = + ws_msg + { + println!( + "Session 1 update: cpu={}%, memory={}B", + update.cpu_percent, update.memory_bytes + ); + session_1_updates.push(update); + } + } + } + _ => {} + } + + // Check session 2 + match tokio::time::timeout(Duration::from_millis(100), comm_2.receive_message()).await { + Ok(Ok(Some(message_text))) => { + if let Ok(ws_msg) = serde_json::from_str::(&message_text) { + if let WebsocketMessage::Kernel(KernelMessage::ResourceUsage(update)) = + ws_msg + { + println!( + "Session 2 update: cpu={}%, memory={}B", + update.cpu_percent, update.memory_bytes + ); + session_2_updates.push(update); + } + } + } + _ => {} + } + + // Break if we have enough updates + if session_1_updates.len() >= 3 && session_2_updates.len() >= 3 { + break; + } + + tokio::time::sleep(Duration::from_millis(200)).await; + } + + println!( + "Collected {} updates from session 1, {} from session 2", + session_1_updates.len(), + session_2_updates.len() + ); + + // Verify we got updates from both sessions + assert!( + !session_1_updates.is_empty(), + "Should have received resource updates from session 1" + ); + assert!( + !session_2_updates.is_empty(), + "Should have received resource updates from session 2" + ); + + // Calculate average CPU percentages for both sessions + let avg_cpu_1: f64 = session_1_updates + .iter() + .map(|u| u.cpu_percent as f64) + .sum::() + / session_1_updates.len() as f64; + + let avg_cpu_2: f64 = session_2_updates + .iter() + .map(|u| u.cpu_percent as f64) + .sum::() + / session_2_updates.len() as f64; + + println!( + "Average CPU usage: Session 1 = {:.1}%, Session 2 = {:.1}%", + avg_cpu_1, avg_cpu_2 + ); + + // Key regression test: verify that session 2 doesn't have artificially inflated CPU + // Both sessions are idle Python kernels, so they should have similar (low) CPU usage + // The bug would cause session 2 to show spikes of hundreds of percent + // + // We allow a reasonable range: idle kernels might use 0-20% CPU occasionally, + // but sustained values over 100% indicate the bug has returned + let max_cpu_2 = session_2_updates + .iter() + .map(|u| u.cpu_percent) + .max() + .unwrap_or(0); + + assert!( + max_cpu_2 < 150, + "Session 2 showed abnormally high CPU usage ({}%), \ + suggesting the multi-session CPU tracking bug has returned. \ + Expected idle kernel to use <150% CPU.", + max_cpu_2 + ); + + // Additional sanity check: both sessions should have comparable average CPU usage + // (within a factor of 10, since they're both idle) + let ratio = if avg_cpu_1 > 1.0 && avg_cpu_2 > 1.0 { + (avg_cpu_1 / avg_cpu_2).max(avg_cpu_2 / avg_cpu_1) + } else { + 1.0 + }; + + assert!( + ratio < 10.0, + "Session CPU usage ratio too large ({:.1}:1), suggesting measurement issue. \ + Session 1 avg: {:.1}%, Session 2 avg: {:.1}%", + ratio, + avg_cpu_1, + avg_cpu_2 + ); + + println!("Multi-session CPU tracking test passed!"); + println!( + "Both sessions showed reasonable CPU usage without artificial spikes." + ); + + // Clean up + let _ = comm_1.close().await; + let _ = comm_2.close().await; + drop(server); + }) + .await; + + match test_result { + Ok(_) => { + println!("Multi-session CPU tracking test completed successfully"); + } + Err(_) => { + panic!("Multi-session CPU tracking test timed out after 45 seconds"); + } + } +}