diff --git a/Cargo.lock b/Cargo.lock index 178ccf34..df6bee50 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5496,8 +5496,10 @@ name = "pluto-cli" version = "1.7.1" dependencies = [ "backon", + "bytes", "chrono", "clap", + "futures", "hex", "humantime", "k256", @@ -5515,11 +5517,13 @@ dependencies = [ "pluto-relay-server", "pluto-ssz", "pluto-tracing", + "quick-xml", "rand 0.8.6", "reqwest 0.13.2", "serde", "serde_json", "serde_with", + "sysinfo", "tempfile", "test-case", "thiserror 2.0.18", diff --git a/Cargo.toml b/Cargo.toml index 41d56f61..f70094df 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ async-trait = "0.1.89" alloy = { version = "1.3", features = ["essentials"] } built = { version = "0.8.0", features = ["git2", "chrono", "cargo-lock"] } blst = "0.3" +bytes = "1" anyhow = "1" axum = "0.8.6" cancellation = "0.1.0" @@ -97,6 +98,8 @@ tree_hash_derive = "0.12" tar = "0.4" flate2 = "1.1" wiremock = "0.6" +sysinfo = "0.33" +quick-xml = { version = "0.39", features = ["serialize"] } # Crates in the workspace pluto-app = { path = "crates/app" } diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index 06a81d60..d7bc4422 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -38,9 +38,13 @@ serde_json.workspace = true serde_with = { workspace = true, features = ["base64"] } rand.workspace = true tempfile.workspace = true +bytes.workspace = true reqwest.workspace = true url.workspace = true chrono.workspace = true +sysinfo.workspace = true +quick-xml.workspace = true +futures.workspace = true [dev-dependencies] tempfile.workspace = true diff --git a/crates/cli/src/commands/test/infra.rs b/crates/cli/src/commands/test/infra.rs index 87a66f8d..242a1e27 100644 --- a/crates/cli/src/commands/test/infra.rs +++ b/crates/cli/src/commands/test/infra.rs @@ -1,9 +1,628 @@ //! Infrastructure and hardware tests. -use super::{TestConfigArgs, helpers::TestCategoryResult}; -use crate::error::Result; +use std::{ + io::Write, + path::{Path, PathBuf}, + time::{Duration, Instant}, +}; + use clap::Args; -use std::io::Write; +use serde::Deserialize; +use tokio_util::sync::CancellationToken; + +use super::{ + AllCategoriesResult, TestCaseName, TestCategory, TestCategoryResult, TestConfigArgs, + TestResult, TestVerdict, calculate_score, evaluate_rtt, filter_tests, + must_output_to_file_on_quiet, publish_result_to_obol_api, sort_tests, write_result_to_file, + write_result_to_writer, +}; +use crate::{ + duration::Duration as CliDuration, + error::{CliError, Result}, +}; + +const FIO_NOT_FOUND: &str = "fio command not found, install fio from https://fio.readthedocs.io/en/latest/fio_doc.html#binary-packages or using the package manager of your choice (apt, yum, brew, etc.)"; + +const DISK_OPS_NUM_OF_JOBS: u32 = 8; +const DISK_OPS_MBS_TOTAL: u32 = 4096; +const DISK_WRITE_SPEED_MBS_AVG: f64 = 1000.0; +const DISK_WRITE_SPEED_MBS_POOR: f64 = 500.0; +const DISK_WRITE_IOPS_AVG: f64 = 2000.0; +const DISK_WRITE_IOPS_POOR: f64 = 1000.0; +const DISK_READ_SPEED_MBS_AVG: f64 = 1000.0; +const DISK_READ_SPEED_MBS_POOR: f64 = 500.0; +const DISK_READ_IOPS_AVG: f64 = 2000.0; +const DISK_READ_IOPS_POOR: f64 = 1000.0; +const AVAILABLE_MEMORY_MBS_AVG: u64 = 4000; +const AVAILABLE_MEMORY_MBS_POOR: u64 = 2000; +const TOTAL_MEMORY_MBS_AVG: u64 = 8000; +const TOTAL_MEMORY_MBS_POOR: u64 = 4000; +const INTERNET_LATENCY_AVG: Duration = Duration::from_millis(20); +const INTERNET_LATENCY_POOR: Duration = Duration::from_millis(50); +const INTERNET_DOWNLOAD_SPEED_MBPS_AVG: f64 = 50.0; +const INTERNET_DOWNLOAD_SPEED_MBPS_POOR: f64 = 15.0; +const INTERNET_UPLOAD_SPEED_MBPS_AVG: f64 = 50.0; +const INTERNET_UPLOAD_SPEED_MBPS_POOR: f64 = 15.0; + +#[derive(Deserialize)] +struct FioResult { + jobs: Vec, +} + +#[derive(Deserialize)] +struct FioResultJob { + read: FioResultSingle, + write: FioResultSingle, +} + +#[derive(Deserialize)] +struct FioResultSingle { + iops: f64, + bw: f64, +} + +#[allow(async_fn_in_trait)] +trait DiskTestTool { + async fn check_availability(&self) -> Result<()>; + async fn write_speed(&self, path: &Path, block_size_kb: i32) -> Result; + async fn write_iops(&self, path: &Path, block_size_kb: i32) -> Result; + async fn read_speed(&self, path: &Path, block_size_kb: i32) -> Result; + async fn read_iops(&self, path: &Path, block_size_kb: i32) -> Result; +} + +struct FioTestTool; + +impl DiskTestTool for FioTestTool { + async fn check_availability(&self) -> Result<()> { + let result = tokio::process::Command::new("fio") + .arg("--version") + .output() + .await; + match result { + Ok(o) if o.status.success() => Ok(()), + _ => Err(CliError::Other(FIO_NOT_FOUND.to_string())), + } + } + + async fn write_speed(&self, path: &Path, block_size_kb: i32) -> Result { + let out = fio_command(path, block_size_kb, "write").await?; + let res: FioResult = serde_json::from_slice(&out) + .map_err(|e| CliError::Other(format!("unmarshal fio result: {e}")))?; + let job = res + .jobs + .into_iter() + .next() + .ok_or_else(|| CliError::Other("fio returned no jobs".to_string()))?; + Ok(job.write.bw / 1024.0) + } + + async fn write_iops(&self, path: &Path, block_size_kb: i32) -> Result { + let out = fio_command(path, block_size_kb, "write").await?; + let res: FioResult = serde_json::from_slice(&out) + .map_err(|e| CliError::Other(format!("unmarshal fio result: {e}")))?; + let job = res + .jobs + .into_iter() + .next() + .ok_or_else(|| CliError::Other("fio returned no jobs".to_string()))?; + Ok(job.write.iops) + } + + async fn read_speed(&self, path: &Path, block_size_kb: i32) -> Result { + let out = fio_command(path, block_size_kb, "read").await?; + let res: FioResult = serde_json::from_slice(&out) + .map_err(|e| CliError::Other(format!("unmarshal fio result: {e}")))?; + let job = res + .jobs + .into_iter() + .next() + .ok_or_else(|| CliError::Other("fio returned no jobs".to_string()))?; + Ok(job.read.bw / 1024.0) + } + + async fn read_iops(&self, path: &Path, block_size_kb: i32) -> Result { + let out = fio_command(path, block_size_kb, "read").await?; + let res: FioResult = serde_json::from_slice(&out) + .map_err(|e| CliError::Other(format!("unmarshal fio result: {e}")))?; + let job = res + .jobs + .into_iter() + .next() + .ok_or_else(|| CliError::Other("fio returned no jobs".to_string()))?; + Ok(job.read.iops) + } +} + +fn can_write_to_dir(dir: &Path) -> bool { + let test_file = dir.join(".perm_test_tmp"); + match std::fs::File::create(&test_file) { + Ok(_) => { + let _ = std::fs::remove_file(&test_file); + true + } + Err(_) => false, + } +} + +async fn fio_command(path: &Path, block_size_kb: i32, operation: &str) -> Result> { + let tmp = tempfile::Builder::new() + .prefix("fiotest") + .tempfile_in(path) + .map_err(|e| CliError::Other(format!("create fio temp file: {e}")))?; + let filename_str = tmp + .path() + .to_str() + .ok_or_else(|| CliError::Other("fio temp file path is not valid UTF-8".to_string()))?; + let size_per_job = DISK_OPS_MBS_TOTAL / DISK_OPS_NUM_OF_JOBS; + + let output = tokio::process::Command::new("fio") + .arg("--name=fioTest") + .arg(format!("--filename={filename_str}")) + .arg(format!("--size={size_per_job}Mb")) + .arg(format!("--blocksize={block_size_kb}k")) + .arg(format!("--numjobs={DISK_OPS_NUM_OF_JOBS}")) + .arg(format!("--rw={operation}")) + .arg("--direct=1") + .arg("--runtime=60s") + .arg("--group_reporting") + .arg("--output-format=json") + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .kill_on_drop(true) + .spawn() + .map_err(|e| CliError::Other(format!("exec fio command: {e}")))? + .wait_with_output() + .await + .map_err(|e| CliError::Other(format!("exec fio command: {e}")))?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(CliError::Other(format!("exec fio command: {stderr}"))); + } + + Ok(output.stdout) +} + +async fn disk_write_speed_test( + args: &TestInfraArgs, + disk_dir: &Path, + tool: &impl DiskTestTool, +) -> TestResult { + let mut result = TestResult::new("DiskWriteSpeed"); + + tracing::info!( + test_file_size_mb = DISK_OPS_MBS_TOTAL, + jobs = DISK_OPS_NUM_OF_JOBS, + test_file_path = %disk_dir.display(), + "Testing disk write speed..." + ); + + if let Err(e) = tool.check_availability().await { + return result.fail(e); + } + + match tool.write_speed(disk_dir, args.disk_io_block_size_kb).await { + Err(e) => result.fail(e), + Ok(speed) => { + result.verdict = if speed < DISK_WRITE_SPEED_MBS_POOR { + TestVerdict::Poor + } else if speed < DISK_WRITE_SPEED_MBS_AVG { + TestVerdict::Avg + } else { + TestVerdict::Good + }; + result.measurement = format!("{speed:.2}MB/s"); + result + } + } +} + +async fn disk_write_iops_test( + args: &TestInfraArgs, + disk_dir: &Path, + tool: &impl DiskTestTool, +) -> TestResult { + let mut result = TestResult::new("DiskWriteIOPS"); + + tracing::info!( + test_file_size_mb = DISK_OPS_MBS_TOTAL, + jobs = DISK_OPS_NUM_OF_JOBS, + test_file_path = %disk_dir.display(), + "Testing disk write IOPS..." + ); + + if let Err(e) = tool.check_availability().await { + return result.fail(e); + } + + match tool.write_iops(disk_dir, args.disk_io_block_size_kb).await { + Err(e) => result.fail(e), + Ok(iops) => { + result.verdict = if iops < DISK_WRITE_IOPS_POOR { + TestVerdict::Poor + } else if iops < DISK_WRITE_IOPS_AVG { + TestVerdict::Avg + } else { + TestVerdict::Good + }; + result.measurement = format!("{iops:.0}"); + result + } + } +} + +async fn disk_read_speed_test( + args: &TestInfraArgs, + disk_dir: &Path, + tool: &impl DiskTestTool, +) -> TestResult { + let mut result = TestResult::new("DiskReadSpeed"); + + tracing::info!( + test_file_size_mb = DISK_OPS_MBS_TOTAL, + jobs = DISK_OPS_NUM_OF_JOBS, + test_file_path = %disk_dir.display(), + "Testing disk read speed..." + ); + + if let Err(e) = tool.check_availability().await { + return result.fail(e); + } + + match tool.read_speed(disk_dir, args.disk_io_block_size_kb).await { + Err(e) => result.fail(e), + Ok(speed) => { + result.verdict = if speed < DISK_READ_SPEED_MBS_POOR { + TestVerdict::Poor + } else if speed < DISK_READ_SPEED_MBS_AVG { + TestVerdict::Avg + } else { + TestVerdict::Good + }; + result.measurement = format!("{speed:.2}MB/s"); + result + } + } +} + +/// Go bug parity: the original Go implementation (testinfra.go:377) calls +/// ReadSpeed instead of ReadIOPS for this test, then compares the bandwidth +/// result against IOPS thresholds. Fixed here to call read_iops() correctly; +/// the Go behaviour was clearly unintentional. +async fn disk_read_iops_test( + args: &TestInfraArgs, + disk_dir: &Path, + tool: &impl DiskTestTool, +) -> TestResult { + let mut result = TestResult::new("DiskReadIOPS"); + + tracing::info!( + test_file_size_mb = DISK_OPS_MBS_TOTAL, + jobs = DISK_OPS_NUM_OF_JOBS, + test_file_path = %disk_dir.display(), + "Testing disk read IOPS..." + ); + + if let Err(e) = tool.check_availability().await { + return result.fail(e); + } + + match tool.read_iops(disk_dir, args.disk_io_block_size_kb).await { + Err(e) => result.fail(e), + Ok(iops) => { + result.verdict = if iops < DISK_READ_IOPS_POOR { + TestVerdict::Poor + } else if iops < DISK_READ_IOPS_AVG { + TestVerdict::Avg + } else { + TestVerdict::Good + }; + result.measurement = format!("{iops:.0}"); + result + } + } +} + +fn apply_memory_result(result: &mut TestResult, mb: u64, poor: u64, avg: u64) { + result.verdict = if mb < poor { + TestVerdict::Poor + } else if mb < avg { + TestVerdict::Avg + } else { + TestVerdict::Good + }; + result.measurement = format!("{mb}MB"); +} + +async fn available_memory_test() -> TestResult { + let mut result = TestResult::new("AvailableMemory"); + let sys = sysinfo::System::new_with_specifics( + sysinfo::RefreshKind::nothing() + .with_memory(sysinfo::MemoryRefreshKind::nothing().with_ram()), + ); + let mb = sys.available_memory() / 1024 / 1024; + apply_memory_result( + &mut result, + mb, + AVAILABLE_MEMORY_MBS_POOR, + AVAILABLE_MEMORY_MBS_AVG, + ); + result +} + +async fn total_memory_test() -> TestResult { + let mut result = TestResult::new("TotalMemory"); + let sys = sysinfo::System::new_with_specifics( + sysinfo::RefreshKind::nothing() + .with_memory(sysinfo::MemoryRefreshKind::nothing().with_ram()), + ); + let mb = sys.total_memory() / 1024 / 1024; + apply_memory_result(&mut result, mb, TOTAL_MEMORY_MBS_POOR, TOTAL_MEMORY_MBS_AVG); + result +} + +async fn internet_latency_test(args: &TestInfraArgs, client: &reqwest::Client) -> TestResult { + let result = TestResult::new("InternetLatency"); + + let mut server = match super::speedtest::fetch_best_server( + &args.internet_test_servers_only, + &args.internet_test_servers_exclude, + client, + ) + .await + { + Err(e) => return result.fail(e), + Ok(s) => s, + }; + + tracing::info!( + server_name = %server.name, + server_country = %server.country, + server_distance_km = server.distance, + server_id = %server.id, + "Testing internet latency..." + ); + + if let Err(e) = server.ping_test(client).await { + return result.fail(e); + } + + evaluate_rtt( + server.latency, + result, + INTERNET_LATENCY_AVG, + INTERNET_LATENCY_POOR, + ) +} + +fn apply_speed_result(result: &mut TestResult, speed: f64, poor: f64, avg: f64) { + result.verdict = if speed < poor { + TestVerdict::Poor + } else if speed < avg { + TestVerdict::Avg + } else { + TestVerdict::Good + }; + result.measurement = format!("{speed:.2}Mb/s"); +} + +async fn internet_download_speed_test( + args: &TestInfraArgs, + client: &reqwest::Client, +) -> TestResult { + let mut result = TestResult::new("InternetDownloadSpeed"); + + let mut server = match super::speedtest::fetch_best_server( + &args.internet_test_servers_only, + &args.internet_test_servers_exclude, + client, + ) + .await + { + Err(e) => return result.fail(e), + Ok(s) => s, + }; + + tracing::info!( + server_name = %server.name, + server_country = %server.country, + server_distance_km = server.distance, + server_id = %server.id, + "Testing internet download speed..." + ); + + if let Err(e) = server.download_test(client).await { + return result.fail(e); + } + + let speed = server.dl_speed_mbps; + apply_speed_result( + &mut result, + speed, + INTERNET_DOWNLOAD_SPEED_MBPS_POOR, + INTERNET_DOWNLOAD_SPEED_MBPS_AVG, + ); + result +} + +async fn internet_upload_speed_test(args: &TestInfraArgs, client: &reqwest::Client) -> TestResult { + let mut result = TestResult::new("InternetUploadSpeed"); + + let mut server = match super::speedtest::fetch_best_server( + &args.internet_test_servers_only, + &args.internet_test_servers_exclude, + client, + ) + .await + { + Err(e) => return result.fail(e), + Ok(s) => s, + }; + + tracing::info!( + server_name = %server.name, + server_country = %server.country, + server_distance_km = server.distance, + server_id = %server.id, + "Testing internet upload speed..." + ); + + if let Err(e) = server.upload_test(client).await { + return result.fail(e); + } + + let speed = server.ul_speed_mbps; + apply_speed_result( + &mut result, + speed, + INTERNET_UPLOAD_SPEED_MBPS_POOR, + INTERNET_UPLOAD_SPEED_MBPS_AVG, + ); + result +} + +/// Returns the ordered list of supported infra test case names. +pub(crate) fn supported_infra_test_cases() -> Vec { + vec![ + TestCaseName::new("DiskWriteSpeed", 1), + TestCaseName::new("DiskWriteIOPS", 2), + TestCaseName::new("DiskReadSpeed", 3), + TestCaseName::new("DiskReadIOPS", 4), + TestCaseName::new("AvailableMemory", 5), + TestCaseName::new("TotalMemory", 6), + TestCaseName::new("InternetLatency", 7), + TestCaseName::new("InternetDownloadSpeed", 8), + TestCaseName::new("InternetUploadSpeed", 9), + ] +} + +async fn run_single_test( + name: &str, + args: &TestInfraArgs, + disk_dir: &Path, + tool: &impl DiskTestTool, + client: &reqwest::Client, +) -> TestResult { + match name { + "DiskWriteSpeed" => disk_write_speed_test(args, disk_dir, tool).await, + "DiskWriteIOPS" => disk_write_iops_test(args, disk_dir, tool).await, + "DiskReadSpeed" => disk_read_speed_test(args, disk_dir, tool).await, + "DiskReadIOPS" => disk_read_iops_test(args, disk_dir, tool).await, + "AvailableMemory" => available_memory_test().await, + "TotalMemory" => total_memory_test().await, + "InternetLatency" => internet_latency_test(args, client).await, + "InternetDownloadSpeed" => internet_download_speed_test(args, client).await, + "InternetUploadSpeed" => internet_upload_speed_test(args, client).await, + _ => TestResult::new(name).fail(CliError::Other(format!("unknown test: {name}"))), + } +} + +async fn run_tests_with_timeout( + args: &TestInfraArgs, + tests: &[TestCaseName], + disk_dir: &Path, + client: &reqwest::Client, + ct: CancellationToken, +) -> Vec { + let tool = FioTestTool; + let mut results = Vec::new(); + let start = Instant::now(); + + for test_case in tests { + let remaining = args.test_config.timeout.saturating_sub(start.elapsed()); + tokio::select! { + result = run_single_test(test_case.name, args, disk_dir, &tool, client) => { + results.push(result); + } + () = tokio::time::sleep(remaining) => { + results.push(TestResult::new(test_case.name).fail(CliError::TimeoutInterrupted)); + break; + } + () = ct.cancelled() => { + results.push(TestResult::new(test_case.name).fail(CliError::TimeoutInterrupted)); + break; + } + } + } + + results +} + +/// Runs the infrastructure tests. +pub async fn run( + args: TestInfraArgs, + writer: &mut dyn Write, + ct: CancellationToken, +) -> Result { + pluto_tracing::init( + &pluto_tracing::TracingConfig::builder() + .with_default_console() + .build(), + ) + .expect("Failed to initialize tracing"); + + must_output_to_file_on_quiet(args.test_config.quiet, &args.test_config.output_json)?; + + tracing::info!("Starting hardware performance and network connectivity test"); + + let disk_dir = match &args.disk_io_test_file_dir { + Some(dir) => PathBuf::from(dir), + None => std::env::var("HOME") + .or_else(|_| std::env::var("USERPROFILE")) + .map(PathBuf::from) + .map_err(|_| CliError::Other("get user home directory".to_string()))?, + }; + + if !can_write_to_dir(&disk_dir) { + return Err(CliError::Other(format!( + "no write permissions to disk IO test file directory: {}", + disk_dir.display() + ))); + } + + let client = super::speedtest::build_client()?; + + let all_cases = supported_infra_test_cases(); + let mut queued = filter_tests(&all_cases, args.test_config.test_cases.as_deref()); + if queued.is_empty() { + return Err(CliError::TestCaseNotSupported); + } + sort_tests(&mut queued); + + let start = Instant::now(); + let test_results = run_tests_with_timeout(&args, &queued, &disk_dir, &client, ct).await; + let elapsed = start.elapsed(); + + let score = calculate_score(&test_results); + + let mut res = TestCategoryResult::new(TestCategory::Infra); + res.targets.insert("local".to_string(), test_results); + res.execution_time = Some(CliDuration::new(elapsed)); + res.score = Some(score); + + if !args.test_config.quiet { + write_result_to_writer(&res, writer)?; + } + + if !args.test_config.output_json.is_empty() { + write_result_to_file(&res, args.test_config.output_json.as_ref()).await?; + } + + if args.test_config.publish { + let all = AllCategoriesResult { + infra: Some(res.clone()), + ..Default::default() + }; + publish_result_to_obol_api( + all, + &args.test_config.publish_addr, + &args.test_config.publish_private_key_file, + ) + .await?; + } + + Ok(res) +} /// Arguments for the infra test command. #[derive(Args, Clone, Debug)] @@ -42,18 +661,3 @@ pub struct TestInfraArgs { )] pub internet_test_servers_exclude: Vec, } - -/// Runs the infrastructure tests. -pub async fn run(_args: TestInfraArgs, _writer: &mut dyn Write) -> Result { - // TODO: Implement infra tests - // - DiskWriteSpeed - // - DiskWriteIOPS - // - DiskReadSpeed - // - DiskReadIOPS - // - AvailableMemory - // - TotalMemory - // - InternetLatency - // - InternetDownloadSpeed - // - InternetUploadSpeed - unimplemented!("infra test not yet implemented") -} diff --git a/crates/cli/src/commands/test/mod.rs b/crates/cli/src/commands/test/mod.rs index 84e7dd1a..2fbf3dee 100644 --- a/crates/cli/src/commands/test/mod.rs +++ b/crates/cli/src/commands/test/mod.rs @@ -14,6 +14,7 @@ pub mod helpers; pub mod infra; pub mod mev; pub mod peers; +pub(super) mod speedtest; pub mod validator; pub(crate) use helpers::*; @@ -89,10 +90,10 @@ fn list_test_cases(category: TestCategory) -> Vec { // supported_self_test_cases() vec![] } - TestCategory::Infra => { - // TODO: Extract from infra::supported_infra_test_cases() - vec![] - } + TestCategory::Infra => infra::supported_infra_test_cases() + .into_iter() + .map(|tc| tc.name.to_string()) + .collect(), TestCategory::All => { // TODO: Combine all test cases from all categories vec![] diff --git a/crates/cli/src/commands/test/speedtest.rs b/crates/cli/src/commands/test/speedtest.rs new file mode 100644 index 00000000..6f135f4d --- /dev/null +++ b/crates/cli/src/commands/test/speedtest.rs @@ -0,0 +1,335 @@ +//! Ookla Speedtest.net client for latency, download, and upload measurements. + +use std::time::{Duration, Instant}; + +use serde::Deserialize; + +use crate::error::{CliError, Result}; + +const SPEEDTEST_SERVERS_URL: &str = + "https://www.speedtest.net/api/js/servers?engine=js&https_functional=true&limit=10"; +const SPEEDTEST_SERVERS_FALLBACK_URL: &str = + "https://www.speedtest.net/speedtest-servers-static.php"; +const FETCH_PING_TIMEOUT: Duration = Duration::from_secs(4); +const PING_COUNT: u32 = 10; +const PING_INTERVAL: Duration = Duration::from_millis(200); +const SPEED_TEST_DURATION: Duration = Duration::from_secs(15); +// Matches Go's ulSizes[4]=1000: chunkSize = (1000*100-51)*10 +const UPLOAD_CHUNK_BYTES: usize = 999_490; + +fn speed_test_concurrency() -> usize { + match std::thread::available_parallelism() { + Ok(n) => n.get(), + Err(e) => { + tracing::warn!(error = %e, "failed to query CPU count, defaulting to 1 concurrent stream"); + 1 + } + } +} + +#[derive(Deserialize)] +struct OoklaServerResponse { + id: String, + name: String, + country: String, + url: String, + #[serde(default)] + distance: f64, +} + +#[derive(Deserialize)] +#[serde(rename = "settings")] +struct XmlServerList { + servers: XmlServersWrapper, +} + +#[derive(Deserialize)] +struct XmlServersWrapper { + server: Vec, +} + +#[derive(Deserialize)] +struct XmlServer { + #[serde(rename = "@url")] + url: String, + #[serde(rename = "@name")] + name: String, + #[serde(rename = "@country")] + country: String, + #[serde(rename = "@id")] + id: String, +} + +impl From for OoklaServerResponse { + fn from(s: XmlServer) -> Self { + Self { + id: s.id, + name: s.name, + country: s.country, + url: s.url, + distance: 0.0, + } + } +} + +pub(super) struct SpeedtestServer { + pub(super) id: String, + pub(super) name: String, + pub(super) country: String, + pub(super) distance: f64, + pub(super) latency: Duration, + pub(super) dl_speed_mbps: f64, + pub(super) ul_speed_mbps: f64, + url: String, +} + +impl SpeedtestServer { + fn from_response(r: OoklaServerResponse) -> Self { + Self { + id: r.id, + name: r.name, + country: r.country, + url: r.url, + distance: r.distance, + latency: Duration::ZERO, + dl_speed_mbps: 0.0, + ul_speed_mbps: 0.0, + } + } + + fn base_url(&self) -> &str { + match self.url.strip_suffix("upload.php") { + Some(base) => base, + None => { + tracing::warn!(url = %self.url, "Ookla server URL does not end in 'upload.php'; subsequent requests may fail"); + &self.url + } + } + } + + async fn quick_ping(&mut self, client: &reqwest::Client) -> Result<()> { + let latency_url = format!("{}latency.txt", self.base_url()); + let start = Instant::now(); + let response = client.get(&latency_url).send().await?; + // Read and discard the body so the connection is left in a clean state + // for the connection pool; dropping Response without reading closes the + // underlying TCP socket and corrupts pool state for subsequent requests. + let _ = response.bytes().await?; + self.latency = start.elapsed(); + Ok(()) + } + + pub(super) async fn ping_test(&mut self, client: &reqwest::Client) -> Result<()> { + let latency_url = format!("{}latency.txt", self.base_url()); + let mut samples = Vec::with_capacity(PING_COUNT as usize); + let mut ticker = tokio::time::interval(PING_INTERVAL); + for _ in 0..PING_COUNT { + ticker.tick().await; + let start = Instant::now(); + let response = client.get(&latency_url).send().await?; + let _ = response.bytes().await?; + samples.push(start.elapsed()); + } + let total: Duration = samples.iter().sum(); + self.latency = total + .checked_div(PING_COUNT) + .expect("PING_COUNT is non-zero"); + Ok(()) + } + + pub(super) async fn download_test(&mut self, client: &reqwest::Client) -> Result<()> { + let download_url = format!("{}random1000x1000.jpg", self.base_url()); + let start = Instant::now(); + let deadline = start + .checked_add(SPEED_TEST_DURATION) + .expect("deadline does not overflow"); + + // Go measures throughput via a Welford EWMA sampled every 50ms. Here we use + // total_bytes/elapsed, which is simpler but equally valid for a single + // measurement. + let mut set = tokio::task::JoinSet::new(); + for _ in 0..speed_test_concurrency() { + let client = client.clone(); + let url = download_url.clone(); + set.spawn(async move { + let mut bytes: usize = 0; + while Instant::now() < deadline { + let Ok(resp) = client.get(&url).send().await else { + break; + }; + if !resp.status().is_success() { + break; + } + if let Ok(body) = resp.bytes().await { + bytes = bytes + .checked_add(body.len()) + .expect("download byte count does not overflow"); + } + } + bytes + }); + } + + let total_bytes: usize = set.join_all().await.into_iter().sum(); + self.dl_speed_mbps = bytes_to_mbps(total_bytes, start.elapsed()); + Ok(()) + } + + pub(super) async fn upload_test(&mut self, client: &reqwest::Client) -> Result<()> { + let upload_url = self.url.clone(); + let start = Instant::now(); + let deadline = start + .checked_add(SPEED_TEST_DURATION) + .expect("deadline does not overflow"); + + let chunk = bytes::Bytes::from(vec![0u8; UPLOAD_CHUNK_BYTES]); + let mut set = tokio::task::JoinSet::new(); + for _ in 0..speed_test_concurrency() { + let client = client.clone(); + let url = upload_url.clone(); + let chunk = chunk.clone(); + set.spawn(async move { + let mut bytes: usize = 0; + while Instant::now() < deadline { + let Ok(resp) = client + .post(&url) + .header("Content-Type", "application/octet-stream") + .body(chunk.clone()) + .send() + .await + else { + break; + }; + if !resp.status().is_success() { + break; + } + let _ = resp.bytes().await; + bytes = bytes + .checked_add(UPLOAD_CHUNK_BYTES) + .expect("upload byte count does not overflow"); + } + bytes + }); + } + + let total_bytes: usize = set.join_all().await.into_iter().sum(); + self.ul_speed_mbps = bytes_to_mbps(total_bytes, start.elapsed()); + Ok(()) + } +} + +/// Builds a shared reqwest client configured for Ookla Speedtest servers. +pub(super) fn build_client() -> Result { + reqwest::Client::builder() + .user_agent("showwin/speedtest-go 1.7.10") + .build() + .map_err(|e| CliError::Other(format!("build HTTP client: {e}"))) +} + +async fn fetch_server_list(client: &reqwest::Client) -> Result> { + let response = client + .get(SPEEDTEST_SERVERS_URL) + .send() + .await + .map_err(|e| CliError::Other(format!("fetch Ookla servers: {e}")))?; + + if response.content_length() == Some(0) { + return fetch_server_list_xml(client).await; + } + + response + .json() + .await + .map_err(|e| CliError::Other(format!("fetch Ookla servers: {e}"))) +} + +async fn fetch_server_list_xml(client: &reqwest::Client) -> Result> { + let body = client + .get(SPEEDTEST_SERVERS_FALLBACK_URL) + .send() + .await + .map_err(|e| CliError::Other(format!("fetch Ookla servers (XML fallback): {e}")))? + .bytes() + .await + .map_err(|e| CliError::Other(format!("fetch Ookla servers (XML fallback): {e}")))?; + + let list: XmlServerList = quick_xml::de::from_reader(body.as_ref()) + .map_err(|e| CliError::Other(format!("parse Ookla servers XML: {e}")))?; + + Ok(list + .servers + .server + .into_iter() + .map(OoklaServerResponse::from) + .collect()) +} + +/// Fetches the Ookla server list, applies filters, pings all candidates +/// concurrently, and returns the lowest-latency reachable server. +pub(super) async fn fetch_best_server( + servers_only: &[String], + servers_exclude: &[String], + client: &reqwest::Client, +) -> Result { + let servers = fetch_server_list(client).await?; + + // Go bug parity: the original Go implementation (testinfra.go) appends both + // servers_only and servers_exclude filter results independently (union), so + // excluded servers can still appear if they also match servers_only. The Rust + // implementation correctly chains the filters as intersection, which is the + // intended behaviour. This intentional divergence from Go is kept. + let candidates: Vec<_> = servers + .into_iter() + .filter(|s| servers_only.is_empty() || servers_only.contains(&s.name)) + .filter(|s| !servers_exclude.contains(&s.name)) + .collect(); + + if candidates.is_empty() { + return Err(CliError::Other( + "fetch Ookla servers: no servers match the specified filters".to_string(), + )); + } + + let ping_futures: Vec<_> = candidates + .into_iter() + .map(|r| { + let client = client.clone(); + async move { + let mut server = SpeedtestServer::from_response(r); + let result = + tokio::time::timeout(FETCH_PING_TIMEOUT, server.quick_ping(&client)).await; + match result { + Ok(Ok(())) => Some(server), + _ => None, + } + } + }) + .collect(); + + let mut reachable: Vec = futures::future::join_all(ping_futures) + .await + .into_iter() + .flatten() + .collect(); + + reachable.sort_by_key(|s| s.latency); + reachable + .into_iter() + .next() + .ok_or_else(|| CliError::Other("find Ookla server: no reachable servers".to_string())) +} + +pub(super) fn bytes_to_mbps(bytes: usize, elapsed: Duration) -> f64 { + let secs = elapsed.as_secs_f64(); + if secs == 0.0 { + return 0.0; + } + + #[allow( + clippy::cast_precision_loss, + clippy::arithmetic_side_effects, + reason = "precision loss requires >8PB transferred; arithmetic overflow is impossible for realistic network speeds" + )] + let bytes: f64 = bytes as f64; + bytes * 8.0 / secs / 1_000_000.0 +} diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index c14d51f7..e52b0613 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -85,7 +85,7 @@ async fn run() -> std::result::Result<(), CliError> { TestCommands::Mev(args) => commands::test::mev::run(args, &mut stdout, ct) .await .map(|_| ()), - TestCommands::Infra(args) => commands::test::infra::run(args, &mut stdout) + TestCommands::Infra(args) => commands::test::infra::run(args, &mut stdout, ct) .await .map(|_| ()), TestCommands::All(args) => commands::test::all::run(*args, &mut stdout).await,