diff --git a/Cargo.lock b/Cargo.lock index 1c974deda5..f6529b9cd1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1897,6 +1897,12 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "home" version = "0.5.11" @@ -3843,6 +3849,17 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.9" @@ -4712,6 +4729,7 @@ dependencies = [ "directories", "flate2", "futures-util", + "hex", "httpmock", "indoc", "pathdiff", @@ -4722,6 +4740,8 @@ dependencies = [ "serde", "serde_json", "serde_yml", + "sha1", + "sha2", "tar", "tempfile", "test-log", diff --git a/Cargo.toml b/Cargo.toml index a8cf9ac103..22fd8fa824 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,6 +56,7 @@ fspy_shared_unix = { path = "crates/fspy_shared_unix" } futures = "0.3.31" futures-core = "0.3.31" futures-util = "0.3.31" +hex = "0.4.3" httpmock = "0.7" indoc = "2.0.5" itertools = "0.14.0" @@ -77,6 +78,7 @@ serde = "1.0.219" serde_json = "1.0.140" serde_yml = "0.0.12" serial_test = "3.2.0" +sha1 = "0.10.6" sha2 = "0.10.9" shell-escape = "0.1.5" supports-color = "3.0.1" diff --git a/crates/vite_error/Cargo.toml b/crates/vite_error/Cargo.toml index 30d9ca3667..0caee467b8 100644 --- a/crates/vite_error/Cargo.toml +++ b/crates/vite_error/Cargo.toml @@ -21,6 +21,6 @@ serde_json = { workspace = true } serde_yml = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true } -wax = { workspace = true } vite_path = { workspace = true } vite_str = { workspace = true } +wax = { workspace = true } diff --git a/crates/vite_error/src/lib.rs b/crates/vite_error/src/lib.rs index 8dfaefd04b..3552470a3c 100644 --- a/crates/vite_error/src/lib.rs +++ b/crates/vite_error/src/lib.rs @@ -153,6 +153,15 @@ pub enum Error { #[error("User cancelled by Ctrl+C")] UserCancelled, + #[error("Hash mismatch: expected {expected}, got {actual}")] + HashMismatch { expected: Str, actual: Str }, + + #[error("Invalid hash format: {0}")] + InvalidHashFormat(Str), + + #[error("Unsupported hash algorithm: {0}")] + UnsupportedHashAlgorithm(Str), + #[error(transparent)] AnyhowError(#[from] anyhow::Error), } diff --git a/crates/vite_package_manager/Cargo.toml b/crates/vite_package_manager/Cargo.toml index 29574909f7..aa30f2ea97 100644 --- a/crates/vite_package_manager/Cargo.toml +++ b/crates/vite_package_manager/Cargo.toml @@ -13,6 +13,7 @@ backon = { workspace = true } directories = { workspace = true } flate2 = { workspace = true } futures-util = { workspace = true } +hex = { workspace = true } indoc = { workspace = true } pathdiff = { workspace = true } petgraph = { workspace = true, features = ["serde-1"] } @@ -23,6 +24,8 @@ serde = { workspace = true, features = ["derive"] } # use `preserve_order` feature to preserve the order of the fields in `package.json` serde_json = { workspace = true, features = ["preserve_order"] } serde_yml = { workspace = true } +sha1 = { workspace = true } +sha2 = { workspace = true } tar = { workspace = true } tempfile = { workspace = true } tokio = { workspace = true, features = ["full"] } diff --git a/crates/vite_package_manager/src/package_manager.rs b/crates/vite_package_manager/src/package_manager.rs index 4d0bb185f3..726e6cf360 100644 --- a/crates/vite_package_manager/src/package_manager.rs +++ b/crates/vite_package_manager/src/package_manager.rs @@ -15,7 +15,7 @@ use vite_str::Str; use crate::{ config::{get_cache_dir, get_npm_package_tgz_url, get_npm_package_version_url}, - request::{HttpClient, download_and_extract_tgz}, + request::{HttpClient, download_and_extract_tgz_with_hash}, shim, }; @@ -61,6 +61,7 @@ pub struct PackageManager { pub package_manager_type: PackageManagerType, pub package_name: Str, pub version: Str, + pub hash: Option, pub bin_name: Str, pub workspace_root: AbsolutePathBuf, pub install_dir: AbsolutePathBuf, @@ -87,7 +88,7 @@ impl PackageManagerBuilder { /// Detect the package manager from the current working directory. pub async fn build(self) -> Result { let workspace_root = find_workspace_root(&self.cwd)?; - let (package_manager_type, mut version) = + let (package_manager_type, mut version, mut hash) = get_package_manager_type_and_version(&workspace_root, self.package_manager_type)?; let mut package_name = package_manager_type.to_string(); @@ -96,6 +97,7 @@ impl PackageManagerBuilder { if version == "latest" { version = get_latest_version(package_manager_type).await?; should_update_package_manager_field = true; + hash = None; // Reset hash when fetching latest since hash is version-specific } // handle yarn >= 2.0.0 to use `@yarnpkg/cli-dist` as package name @@ -108,8 +110,13 @@ impl PackageManagerBuilder { } // only download the package manager if it's not already downloaded - let install_dir = - download_package_manager(package_manager_type, &package_name, &version).await?; + let install_dir = download_package_manager( + package_manager_type, + &package_name, + &version, + hash.as_deref(), + ) + .await?; if should_update_package_manager_field { // auto set `packageManager` field in package.json @@ -121,6 +128,7 @@ impl PackageManagerBuilder { package_manager_type, package_name: package_name.into(), version, + hash, bin_name: package_manager_type.to_string().into(), workspace_root: workspace_root.path.to_absolute_path_buf(), install_dir, @@ -258,18 +266,26 @@ pub fn find_workspace_root(original_cwd: &AbsolutePath) -> Result, -) -> Result<(PackageManagerType, Str), Error> { +) -> Result<(PackageManagerType, Str, Option), Error> { // check packageManager field in package.json let package_json_path = workspace_root.path.join("package.json"); if let Some(file) = open_exists_file(&package_json_path)? { let package_json: PackageJson = serde_json::from_reader(BufReader::new(&file))?; if !package_json.package_manager.is_empty() - && let Some((name, version)) = package_json.package_manager.split_once('@') + && let Some((name, version_with_hash)) = package_json.package_manager.split_once('@') { + // Parse version and optional hash (format: version+sha512.hash) + let (version, hash) = if let Some((ver, hash_part)) = version_with_hash.split_once("+") + { + (ver, Some(hash_part.into())) + } else { + (version_with_hash, None) + }; + // check if the version is a valid semver semver::Version::parse(version).map_err(|_| Error::PackageManagerVersionInvalid { name: name.into(), @@ -277,9 +293,9 @@ fn get_package_manager_type_and_version( package_json_path: package_json_path.to_absolute_path_buf(), })?; match name { - "pnpm" => return Ok((PackageManagerType::Pnpm, version.into())), - "yarn" => return Ok((PackageManagerType::Yarn, version.into())), - "npm" => return Ok((PackageManagerType::Npm, version.into())), + "pnpm" => return Ok((PackageManagerType::Pnpm, version.into(), hash)), + "yarn" => return Ok((PackageManagerType::Yarn, version.into(), hash)), + "npm" => return Ok((PackageManagerType::Npm, version.into(), hash)), _ => return Err(Error::UnsupportedPackageManager(name.into())), } } @@ -290,43 +306,43 @@ fn get_package_manager_type_and_version( let version = Str::from("latest"); // if pnpm-workspace.yaml exists, use pnpm@latest if matches!(workspace_root.workspace_file, WorkspaceFile::PnpmWorkspaceYaml(_)) { - return Ok((PackageManagerType::Pnpm, version)); + return Ok((PackageManagerType::Pnpm, version, None)); } // if pnpm-lock.yaml exists, use pnpm@latest let pnpm_lock_yaml_path = workspace_root.path.join("pnpm-lock.yaml"); if is_exists_file(&pnpm_lock_yaml_path)? { - return Ok((PackageManagerType::Pnpm, version)); + return Ok((PackageManagerType::Pnpm, version, None)); } // if yarn.lock or .yarnrc.yml exists, use yarn@latest let yarn_lock_path = workspace_root.path.join("yarn.lock"); let yarnrc_yml_path = workspace_root.path.join(".yarnrc.yml"); if is_exists_file(&yarn_lock_path)? || is_exists_file(&yarnrc_yml_path)? { - return Ok((PackageManagerType::Yarn, version)); + return Ok((PackageManagerType::Yarn, version, None)); } // if package-lock.json exists, use npm@latest let package_lock_json_path = workspace_root.path.join("package-lock.json"); if is_exists_file(&package_lock_json_path)? { - return Ok((PackageManagerType::Npm, version)); + return Ok((PackageManagerType::Npm, version, None)); } // if pnpmfile.cjs exists, use pnpm@latest let pnpmfile_cjs_path = workspace_root.path.join("pnpmfile.cjs"); if is_exists_file(&pnpmfile_cjs_path)? { - return Ok((PackageManagerType::Pnpm, version)); + return Ok((PackageManagerType::Pnpm, version, None)); } // if yarn.config.cjs exists, use yarn@latest (yarn 2.0+) let yarn_config_cjs_path = workspace_root.path.join("yarn.config.cjs"); if is_exists_file(&yarn_config_cjs_path)? { - return Ok((PackageManagerType::Yarn, version)); + return Ok((PackageManagerType::Yarn, version, None)); } // if default is specified, use it if let Some(default) = default { - return Ok((default, version)); + return Ok((default, version, None)); } // unrecognized package manager, let user specify the package manager @@ -370,6 +386,7 @@ async fn download_package_manager( package_manager_type: PackageManagerType, package_name: &str, version: &str, + expected_hash: Option<&str>, ) -> Result { let tgz_url = get_npm_package_tgz_url(package_name, version); let cache_dir = get_cache_dir()?; @@ -395,21 +412,23 @@ async fn download_package_manager( tokio::fs::create_dir_all(parent_dir).await?; let target_dir_tmp = tempfile::tempdir_in(parent_dir)?.path().to_path_buf(); - download_and_extract_tgz(&tgz_url, &target_dir_tmp).await.map_err(|err| { - // status 404 means the version is not found, convert to PackageManagerVersionNotFound error - if let Error::ReqwestError(e) = &err - && let Some(status) = e.status() - && status == reqwest::StatusCode::NOT_FOUND - { - Error::PackageManagerVersionNotFound { - name: package_manager_type.to_string().into(), - version: version.into(), - url: tgz_url.into(), + download_and_extract_tgz_with_hash(&tgz_url, &target_dir_tmp, expected_hash).await.map_err( + |err| { + // status 404 means the version is not found, convert to PackageManagerVersionNotFound error + if let Error::ReqwestError(e) = &err + && let Some(status) = e.status() + && status == reqwest::StatusCode::NOT_FOUND + { + Error::PackageManagerVersionNotFound { + name: package_manager_type.to_string().into(), + version: version.into(), + url: tgz_url.into(), + } + } else { + err } - } else { - err - } - })?; + }, + )?; // rename $target_dir_tmp/package to $target_dir_tmp/{bin_name} tracing::debug!("Rename package dir to {}", bin_name); @@ -988,7 +1007,201 @@ mod tests { } #[tokio::test] - #[cfg(not(windows))] // FIXME + async fn test_parse_package_manager_with_hash() { + let temp_dir = create_temp_dir(); + let temp_dir_path = AbsolutePathBuf::new(temp_dir.path().to_path_buf()).unwrap(); + + // Test with sha512 hash + let package_content = r#"{"name": "test-package", "packageManager": "yarn@1.22.22+sha512.a6b2f7906b721bba3d67d4aff083df04dad64c399707841b7acf00f6b133b7ac24255f2652fa22ae3534329dc6180534e98d17432037ff6fd140556e2bb3137e"}"#; + create_package_json(&temp_dir_path, package_content); + + let workspace_root = find_workspace_root(&temp_dir_path).unwrap(); + let (pm_type, version, hash) = + get_package_manager_type_and_version(&workspace_root, None).unwrap(); + + assert_eq!(pm_type, PackageManagerType::Yarn); + assert_eq!(version, "1.22.22"); + assert!(hash.is_some()); + assert_eq!( + hash.unwrap(), + "sha512.a6b2f7906b721bba3d67d4aff083df04dad64c399707841b7acf00f6b133b7ac24255f2652fa22ae3534329dc6180534e98d17432037ff6fd140556e2bb3137e" + ); + } + + #[tokio::test] + async fn test_parse_package_manager_with_sha1_hash() { + let temp_dir = create_temp_dir(); + let temp_dir_path = AbsolutePathBuf::new(temp_dir.path().to_path_buf()).unwrap(); + + // Test with sha1 hash + let package_content = r#"{"name": "test-package", "packageManager": "npm@10.5.0+sha1.abcd1234567890abcdef1234567890abcdef1234"}"#; + create_package_json(&temp_dir_path, package_content); + + let workspace_root = find_workspace_root(&temp_dir_path).unwrap(); + let (pm_type, version, hash) = + get_package_manager_type_and_version(&workspace_root, None).unwrap(); + + assert_eq!(pm_type, PackageManagerType::Npm); + assert_eq!(version, "10.5.0"); + assert!(hash.is_some()); + assert_eq!(hash.unwrap(), "sha1.abcd1234567890abcdef1234567890abcdef1234"); + } + + #[tokio::test] + async fn test_parse_package_manager_with_sha224_hash() { + let temp_dir = create_temp_dir(); + let temp_dir_path = AbsolutePathBuf::new(temp_dir.path().to_path_buf()).unwrap(); + + // Test with sha224 hash + let package_content = r#"{"name": "test-package", "packageManager": "pnpm@8.15.0+sha224.1234567890abcdef1234567890abcdef1234567890abcdef12345678"}"#; + create_package_json(&temp_dir_path, package_content); + + let workspace_root = find_workspace_root(&temp_dir_path).unwrap(); + let (pm_type, version, hash) = + get_package_manager_type_and_version(&workspace_root, None).unwrap(); + + assert_eq!(pm_type, PackageManagerType::Pnpm); + assert_eq!(version, "8.15.0"); + assert!(hash.is_some()); + assert_eq!( + hash.unwrap(), + "sha224.1234567890abcdef1234567890abcdef1234567890abcdef12345678" + ); + } + + #[tokio::test] + async fn test_parse_package_manager_with_sha256_hash() { + let temp_dir = create_temp_dir(); + let temp_dir_path = AbsolutePathBuf::new(temp_dir.path().to_path_buf()).unwrap(); + + // Test with sha256 hash + let package_content = r#"{"name": "test-package", "packageManager": "yarn@4.0.0+sha256.1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"}"#; + create_package_json(&temp_dir_path, package_content); + + let workspace_root = find_workspace_root(&temp_dir_path).unwrap(); + let (pm_type, version, hash) = + get_package_manager_type_and_version(&workspace_root, None).unwrap(); + + assert_eq!(pm_type, PackageManagerType::Yarn); + assert_eq!(version, "4.0.0"); + assert!(hash.is_some()); + assert_eq!( + hash.unwrap(), + "sha256.1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" + ); + } + + #[tokio::test] + async fn test_parse_package_manager_without_hash() { + let temp_dir = create_temp_dir(); + let temp_dir_path = AbsolutePathBuf::new(temp_dir.path().to_path_buf()).unwrap(); + + // Test without hash + let package_content = r#"{"name": "test-package", "packageManager": "pnpm@8.15.0"}"#; + create_package_json(&temp_dir_path, package_content); + + let workspace_root = find_workspace_root(&temp_dir_path).unwrap(); + let (pm_type, version, hash) = + get_package_manager_type_and_version(&workspace_root, None).unwrap(); + + assert_eq!(pm_type, PackageManagerType::Pnpm); + assert_eq!(version, "8.15.0"); + assert!(hash.is_none()); + } + + #[tokio::test] + async fn test_download_success_package_manager_with_hash() { + use std::process::Command; + + let temp_dir = create_temp_dir(); + let temp_dir_path = AbsolutePathBuf::new(temp_dir.path().to_path_buf()).unwrap(); + let package_content = r#"{"name": "test-package", "packageManager": "yarn@1.22.22+sha512.a6b2f7906b721bba3d67d4aff083df04dad64c399707841b7acf00f6b133b7ac24255f2652fa22ae3534329dc6180534e98d17432037ff6fd140556e2bb3137e"}"#; + create_package_json(&temp_dir_path, package_content); + + let result = PackageManager::builder(temp_dir_path) + .build() + .await + .expect("Should detect yarn with version and hash"); + assert_eq!(result.bin_name, "yarn"); + + // check shim files + let bin_prefix = result.get_bin_prefix(); + assert!(is_exists_file(&bin_prefix.join("yarn.js")).unwrap()); + assert!(is_exists_file(&bin_prefix.join("yarn")).unwrap()); + assert!(is_exists_file(&bin_prefix.join("yarn.cmd")).unwrap()); + assert!(is_exists_file(&bin_prefix.join("yarn.ps1")).unwrap()); + assert!(is_exists_file(&bin_prefix.join("yarnpkg")).unwrap()); + assert!(is_exists_file(&bin_prefix.join("yarnpkg.cmd")).unwrap()); + assert!(is_exists_file(&bin_prefix.join("yarnpkg.ps1")).unwrap()); + + // run pnpm --version + let mut paths = + env::split_paths(&env::var_os("PATH").unwrap_or_default()).collect::>(); + paths.insert(0, bin_prefix.into_path_buf()); + let mut cmd = "yarn"; + if cfg!(windows) { + cmd = "yarn.cmd"; + } + let output = Command::new(cmd) + .arg("--version") + .env("PATH", env::join_paths(paths).unwrap()) + .output() + .expect("Failed to run yarn"); + // println!("pnpm --version: {:?}", output); + assert!(output.status.success()); + assert_eq!(String::from_utf8_lossy(&output.stdout).trim(), "1.22.22"); + } + + #[tokio::test] + async fn test_download_failed_package_manager_with_hash() { + let temp_dir = create_temp_dir(); + let temp_dir_path = AbsolutePathBuf::new(temp_dir.path().to_path_buf()).unwrap(); + let package_content = r#"{"name": "test-package", "packageManager": "yarn@1.22.21+sha512.a6b2f7906b721bba3d67d4aff083df04dad64c399707841b7acf00f6b133b7ac24255f2652fa22ae3534329dc6180534e98d17432037ff6fd140556e2bb3137e"}"#; + create_package_json(&temp_dir_path, package_content); + + let result = PackageManager::builder(temp_dir_path).build().await; + assert!(result.is_err()); + // Check if it's the expected error type + if let Err(Error::HashMismatch { expected, actual }) = result { + assert_eq!( + expected, + "sha512.a6b2f7906b721bba3d67d4aff083df04dad64c399707841b7acf00f6b133b7ac24255f2652fa22ae3534329dc6180534e98d17432037ff6fd140556e2bb3137e" + ); + assert_eq!( + actual, + "sha512.ca75da26c00327d26267ce33536e5790f18ebd53266796fbb664d2a4a5116308042dd8ee7003b276a20eace7d3c5561c3577bdd71bcb67071187af124779620a" + ); + } else { + panic!("Expected HashMismatch error"); + } + } + + #[tokio::test] + async fn test_download_success_package_manager_with_sha1_and_sha224() { + let temp_dir = create_temp_dir(); + let temp_dir_path = AbsolutePathBuf::new(temp_dir.path().to_path_buf()).unwrap(); + let package_content = r#"{"name": "test-package", "packageManager": "yarn@1.22.20+sha1.167c8ab8d9c8c3826d3725d9579aaea8b47a2b18"}"#; + create_package_json(&temp_dir_path, package_content); + + let result = PackageManager::builder(temp_dir_path) + .build() + .await + .expect("Should detect yarn with version and hash"); + assert_eq!(result.bin_name, "yarn"); + + let temp_dir = create_temp_dir(); + let temp_dir_path = AbsolutePathBuf::new(temp_dir.path().to_path_buf()).unwrap(); + let package_content = r#"{"name": "test-package", "packageManager": "pnpm@4.11.6+sha224.7783c4b01916b7a69e6ff05d328df6f83cb7f127e9c96be88739386d"}"#; + create_package_json(&temp_dir_path, package_content); + + let result = PackageManager::builder(temp_dir_path) + .build() + .await + .expect("Should detect pnpm with version and hash"); + assert_eq!(result.bin_name, "pnpm"); + } + + #[tokio::test] async fn test_detect_package_manager_with_yarn_package_manager_field() { use std::process::Command; @@ -1022,8 +1235,16 @@ mod tests { assert!(is_exists_file(&bin_prefix.join("yarnpkg.ps1")).unwrap()); // run yarn --version - let output = Command::new(bin_prefix.join("yarn").into_path_buf()) + let mut cmd = "yarn"; + if cfg!(windows) { + cmd = "yarn.cmd"; + } + let mut paths = + env::split_paths(&env::var_os("PATH").unwrap_or_default()).collect::>(); + paths.insert(0, bin_prefix.into_path_buf()); + let output = Command::new(cmd) .arg("--version") + .env("PATH", env::join_paths(paths).unwrap()) .output() .expect("Failed to run yarn"); assert!(output.status.success()); @@ -1195,7 +1416,8 @@ mod tests { #[tokio::test] async fn test_download_package_manager() { let result = - download_package_manager(PackageManagerType::Yarn, "@yarnpkg/cli-dist", "4.9.2").await; + download_package_manager(PackageManagerType::Yarn, "@yarnpkg/cli-dist", "4.9.2", None) + .await; assert!(result.is_ok()); let target_dir = result.unwrap(); println!("result: {:?}", target_dir); @@ -1204,7 +1426,8 @@ mod tests { // again should skip download let result = - download_package_manager(PackageManagerType::Yarn, "@yarnpkg/cli-dist", "4.9.2").await; + download_package_manager(PackageManagerType::Yarn, "@yarnpkg/cli-dist", "4.9.2", None) + .await; assert!(result.is_ok()); let target_dir = result.unwrap(); assert!(is_exists_file(&target_dir.join("bin/yarn")).unwrap()); diff --git a/crates/vite_package_manager/src/request.rs b/crates/vite_package_manager/src/request.rs index 39e8e1fd8f..559822c4e7 100644 --- a/crates/vite_package_manager/src/request.rs +++ b/crates/vite_package_manager/src/request.rs @@ -5,6 +5,8 @@ use flate2::read::GzDecoder; use futures_util::stream::StreamExt; use reqwest::Response; use serde::de::DeserializeOwned; +use sha1::Sha1; +use sha2::{Digest, Sha224, Sha256, Sha512}; use tar::Archive; use tokio::{fs, io::AsyncWriteExt}; @@ -133,23 +135,28 @@ fn extract_tgz(tgz_file: impl AsRef, target_dir: impl AsRef) -> Resu Ok(()) } -/// Download tgz file from url and extract it to the target directory. +/// Download a tgz file from a URL and extract it to a target directory with optional hash verification. /// /// # Arguments -/// -/// * `url` - The url of the tgz file. +/// * `url` - The URL of the tgz file to download. /// * `target_dir` - The directory to extract the tgz file to. +/// * `expected_hash` - Optional expected hash in format "algorithm.hash" (e.g., "sha512.abcd1234...") /// /// # Returns -/// -/// * `Ok(())` - If the tgz file is downloaded and extracted successfully. -/// * `Err(e)` - If the tgz file is not downloaded or extracted successfully. -pub async fn download_and_extract_tgz( +/// * `Ok(())` - If the tgz file is downloaded, verified (if hash provided) and extracted successfully. +/// * `Err(e)` - If the tgz file is not downloaded, verified or extracted successfully. +pub async fn download_and_extract_tgz_with_hash( url: &str, target_dir: impl AsRef, + expected_hash: Option<&str>, ) -> Result<(), Error> { let target_dir = target_dir.as_ref().to_path_buf(); - tracing::debug!("Start download and extract {} to {:?}", url, target_dir); + tracing::debug!( + "Start download and extract {} to {:?}, expected hash: {:?}", + url, + target_dir, + expected_hash + ); // Create target directory fs::create_dir_all(&target_dir).await?; @@ -159,6 +166,11 @@ pub async fn download_and_extract_tgz( let client = HttpClient::new(); client.download_file(url, &tgz_file).await?; + // Verify hash if provided + if let Some(expected_hash) = expected_hash { + verify_file_hash(&tgz_file, expected_hash).await?; + } + // Extract the tgz file to the target directory let tgz_file_for_extract = tgz_file.clone(); let target_dir_for_extract = target_dir.clone(); @@ -169,11 +181,69 @@ pub async fn download_and_extract_tgz( // Remove the temp file fs::remove_file(&tgz_file).await?; - tracing::debug!("Download and extract finished"); Ok(()) } +/// Computes the hash of the given content using the specified digest algorithm. +/// +/// # Type Parameters +/// * `D` - A type that implements the [`Digest`] trait, such as `Sha256`, `Sha512`, etc. +/// +/// # Arguments +/// * `content` - The byte slice to hash. +/// +/// # Returns +/// A hex-encoded string representing the computed digest. +fn compute_hash(content: &[u8]) -> String { + let mut hasher = D::new(); + hasher.update(content); + hex::encode(hasher.finalize()) +} + +/// Verify the hash of a file against an expected hash. +/// +/// # Arguments +/// * `file_path` - Path to the file to verify +/// * `expected_hash` - Expected hash in format "algorithm.hash" (e.g., "sha512.abcd1234...") +/// +/// # Returns +/// * `Ok(())` - If the file hash matches the expected hash +/// * `Err(Error::HashMismatch)` - If the file hash doesn't match +pub async fn verify_file_hash( + file_path: impl AsRef, + expected_hash: &str, +) -> Result<(), Error> { + let file_path = file_path.as_ref(); + let content = fs::read(file_path).await?; + + // Parse the hash format (e.g., "sha512.abcd1234..." or "sha256.abcd1234...") + let (algorithm, expected_hex) = if let Some((algo, hash)) = expected_hash.split_once('.') { + (algo, hash) + } else { + return Err(Error::InvalidHashFormat(expected_hash.into())); + }; + + // Calculate the actual hash based on the algorithm + let actual_hex = match algorithm { + "sha512" => compute_hash::(&content), + "sha256" => compute_hash::(&content), + "sha224" => compute_hash::(&content), + "sha1" => compute_hash::(&content), + _ => return Err(Error::UnsupportedHashAlgorithm(algorithm.into())), + }; + + if actual_hex != expected_hex { + return Err(Error::HashMismatch { + expected: expected_hash.into(), + actual: format!("{}.{}", algorithm, actual_hex).into(), + }); + } + + tracing::debug!("Hash verification successful"); + Ok(()) +} + #[cfg(test)] mod tests { use std::fs; @@ -442,7 +512,7 @@ mod tests { }); let url = format!("{}/test-package.tgz", server.base_url()); - let result = download_and_extract_tgz(&url, &target_dir).await; + let result = download_and_extract_tgz_with_hash(&url, &target_dir, None).await; assert!(result.is_ok(), "Failed to download and extract: {:?}", result); assert!(target_dir.join("package/bin/yarn").exists()); @@ -451,6 +521,63 @@ mod tests { // TempDir automatically cleans up when it goes out of scope } + #[tokio::test] + async fn test_verify_file_hash_sha1() { + use tokio::io::AsyncWriteExt; + + let temp_dir = TempDir::new().unwrap(); + let test_file = temp_dir.path().join("test.txt"); + + // Write test content + let content = b"Hello, World!"; + let mut file = tokio::fs::File::create(&test_file).await.unwrap(); + file.write_all(content).await.unwrap(); + + // Calculate expected SHA1 + use sha1::Sha1; + use sha2::Digest; + let mut hasher = Sha1::new(); + hasher.update(content); + let expected_hash = format!("sha1.{:x}", hasher.finalize()); + + // Test successful verification + let result = verify_file_hash(&test_file, &expected_hash).await; + assert!(result.is_ok()); + + // Test failed verification + let wrong_hash = "sha1.0000000000000000000000000000000000000000"; + let result = verify_file_hash(&test_file, wrong_hash).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_verify_file_hash_sha224() { + use tokio::io::AsyncWriteExt; + + let temp_dir = TempDir::new().unwrap(); + let test_file = temp_dir.path().join("test.txt"); + + // Write test content + let content = b"Test content for SHA224"; + let mut file = tokio::fs::File::create(&test_file).await.unwrap(); + file.write_all(content).await.unwrap(); + + // Calculate expected SHA224 + use sha2::{Digest, Sha224}; + let mut hasher = Sha224::new(); + hasher.update(content); + let expected_hash = format!("sha224.{:x}", hasher.finalize()); + + // Test successful verification + let result = verify_file_hash(&test_file, &expected_hash).await; + assert!(result.is_ok()); + + // Test failed verification + let wrong_hash = "sha224.00000000000000000000000000000000000000000000000000000000"; + let result = verify_file_hash(&test_file, wrong_hash).await; + assert!(result.is_err()); + } + #[tokio::test] async fn test_http_client_download_with_404_error() { let server = MockServer::start();