diff --git a/.gitignore b/.gitignore index c9474082..cb7ddcea 100644 --- a/.gitignore +++ b/.gitignore @@ -172,3 +172,6 @@ gtdb-rs214-reps.k31_0.9995_pretrained/ # added by mahmudhera src/cpp/main.o .gitignore +src/cpp/utils.o +src/yacht/run_yacht_train_core +.vscode/settings.json diff --git a/MANIFEST.in b/MANIFEST.in index 2774067a..6df8f35d 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,7 @@ # Include C++ source files include src/cpp/*.cpp include src/cpp/*.hpp +include src/cpp/*.h # Include other necessary files include LICENSE.txt diff --git a/Makefile b/Makefile index 8248777f..a7dc5995 100644 --- a/Makefile +++ b/Makefile @@ -7,32 +7,29 @@ SRC_DIR = src/cpp BIN_DIR = src/yacht # Source files -SRC_FILES = $(SRC_DIR)/main.cpp +SRC_FILES = $(SRC_DIR)/yacht_train_core.cpp $(SRC_DIR)/utils.cpp $(SRC_DIR)/yacht_run_compute_similarity.cpp # Object files OBJ_FILES = $(SRC_FILES:.cpp=.o) -# Target executable -TARGET = $(BIN_DIR)/run_yacht_train_core +# Target executables +TARGET1 = $(BIN_DIR)/yacht_train_core +TARGET2 = $(BIN_DIR)/yacht_run_compute_similarity -# Default target -all: $(TARGET) +# Build rules +all: $(TARGET1) $(TARGET2) -# Create the bin directory if it doesn't exist -$(BIN_DIR): - echo "Creating directory: $(BIN_DIR)" - mkdir -p $(BIN_DIR) +$(TARGET1): $(SRC_DIR)/yacht_train_core.cpp $(SRC_DIR)/utils.cpp + $(CXX) $(CXXFLAGS) -pthread $(SRC_DIR)/yacht_train_core.cpp $(SRC_DIR)/utils.cpp -o $(TARGET1) -# build the object files -$(OBJ_FILES): %.o: %.cpp - echo "Compiling: $<" - $(CXX) $(CXXFLAGS) -c $< -o $@ +$(TARGET2): $(SRC_DIR)/yacht_run_compute_similarity.cpp $(SRC_DIR)/utils.cpp + $(CXX) $(CXXFLAGS) -pthread $(SRC_DIR)/yacht_run_compute_similarity.cpp $(SRC_DIR)/utils.cpp -o $(TARGET2) -# build the target executable -$(TARGET): $(OBJ_FILES) | $(BIN_DIR) - echo "Linking to create executable: $(TARGET)" - $(CXX) $(CXXFLAGS) $(OBJ_FILES) -o $(TARGET) -lpthread +%.o: %.cpp + $(CXX) $(CXXFLAGS) -pthread -c $< -o $@ -# clean up clean: - rm -f $(OBJ_FILES) $(TARGET) \ No newline at end of file + rm -f $(OBJ_FILES) $(TARGET1) $(TARGET2) + +.PHONY: all clean + diff --git a/build_windows.bat b/build_windows.bat index e8eef706..6fd616d1 100644 --- a/build_windows.bat +++ b/build_windows.bat @@ -1,21 +1,41 @@ @echo off +setlocal enabledelayedexpansion -REM Set up paths for directories +rem Set up paths for directories set SRC_DIR=src\cpp set BIN_DIR=src\yacht -REM Create bin directory if it doesn't exist +rem Create bin directory if it doesn't exist if not exist %BIN_DIR% ( mkdir %BIN_DIR% ) -REM Compile the main.cpp file using g++ from MinGW or another suitable compiler -g++ -std=c++17 -Wsign-compare -Wall -O3 -o %BIN_DIR%\run_yacht_train_core.exe %SRC_DIR%\main.cpp +rem Compile source files into object files +set OBJ_FILES= +for %%f in (%SRC_DIR%\*.cpp) do ( + g++ -std=c++17 -Wall -O3 -Wsign-compare -c %%f -o %%~nf.o + if %errorlevel% neq 0 ( + echo Compilation failed for %%f! + exit /b %errorlevel% + ) + set OBJ_FILES=!OBJ_FILES! %%~nf.o +) + +rem Create yacht_train_core.exe +g++ %OBJ_FILES% -o %BIN_DIR%\yacht_train_core.exe +if %errorlevel% neq 0 ( + echo Linking failed for yacht_train_core.exe! + exit /b %errorlevel% +) -REM Check if compilation succeeded +rem Create yacht_run_compute_similarity.exe +g++ %OBJ_FILES% -o %BIN_DIR%\yacht_run_compute_similarity.exe if %errorlevel% neq 0 ( - echo Compilation failed! + echo Linking failed for yacht_run_compute_similarity.exe! exit /b %errorlevel% ) -echo Compilation successful. Executable created at %BIN_DIR%\run_yacht_train_core.exe \ No newline at end of file +rem Clean up object files +del *.o + +echo Compilation successful. Executables created in %BIN_DIR%. \ No newline at end of file diff --git a/conda_recipe/meta.yaml b/conda_recipe/meta.yaml index d3344d87..538cee21 100644 --- a/conda_recipe/meta.yaml +++ b/conda_recipe/meta.yaml @@ -1,4 +1,4 @@ -{% set version = "1.3.0" %} +{% set version = "1.4.0" %} package: name: yacht @@ -6,11 +6,10 @@ package: source: url: https://github.com/KoslickiLab/YACHT/releases/download/v{{ version }}/yacht-{{ version }}.tar.gz - sha256: 68d272daeb70ed7390aa2d468934dc4bf0aa9a021f99fe99847b8a664e8ac8cf + sha256: 3558abd6d1084f0679ffbd8b1a8592ec7e0e642e201b5fcc0e34eaa62ae7e705 build: number: 0 - skip: True # [osx] script: "{{ PYTHON }} -m pip install . --no-deps --no-build-isolation --no-cache-dir -vvv" run_exports: - {{ pin_subpackage('yacht') }} @@ -52,7 +51,6 @@ requirements: - pytaxonkit - openpyxl - ruff - - sourmash_plugin_branchwater test: commands: diff --git a/env/yacht_env.yml b/env/yacht_env.yml index e69a9716..0b04259b 100644 --- a/env/yacht_env.yml +++ b/env/yacht_env.yml @@ -20,7 +20,6 @@ dependencies: - pytaxonkit - requests - pip - - sourmash_plugin_branchwater - pip: - openpyxl - ruff \ No newline at end of file diff --git a/setup.py b/setup.py index b32557fd..fcc905dc 100644 --- a/setup.py +++ b/setup.py @@ -30,15 +30,24 @@ def run(self): print(f"Error during Unix compilation: {e}") raise e - # Move the compiled binary to the correct location for packaging - compiled_binary = os.path.join('src', 'yacht', 'run_yacht_train_core') - if os.path.exists(compiled_binary): + # Move the compiled binary files to the correct location for packaging + compiled_binary1 = os.path.join('src', 'yacht', 'yacht_train_core') + compiled_binary2 = os.path.join('src', 'yacht', 'yacht_run_compute_similarity') + if os.path.exists(compiled_binary1): destination = os.path.join(self.build_lib, 'yacht') os.makedirs(destination, exist_ok=True) - shutil.move(compiled_binary, destination) + shutil.move(compiled_binary1, destination) else: print("Compiled binary not found after build step.") - raise FileNotFoundError("The executable 'run_yacht_train_core' was not generated successfully.") + raise FileNotFoundError("The executable 'yacht_train_core' was not generated successfully.") + + if os.path.exists(compiled_binary2): + destination = os.path.join(self.build_lib, 'yacht') + os.makedirs(destination, exist_ok=True) + shutil.move(compiled_binary2, destination) + else: + print("Compiled binary not found after build step.") + raise FileNotFoundError("The executable 'yacht_run_compute_similarity' was not generated successfully.") # Run the usual build_ext logic (necessary to continue with setuptools) super().run() diff --git a/src/cpp/main.cpp b/src/cpp/main.cpp deleted file mode 100644 index 39fc606e..00000000 --- a/src/cpp/main.cpp +++ /dev/null @@ -1,499 +0,0 @@ -/* - * Author: Mahmudur Rahman Hera (mahmudhera93@gmail.com) - * Date: November 1, 2024 - * Description: yacht train core using indexing of sketches to do genome comparison - */ - -#include "argparse.hpp" -#include "json.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace std; -using json = nlohmann::json; - - -struct Arguments { - string file_list; - string working_directory; - string output_filename; - int number_of_threads; - int num_of_passes; - double containment_threshold; -}; - - -typedef Arguments Arguments; -typedef unsigned long long int hash_t; - - -Arguments arguments; -std::vector sketch_names; -uint num_sketches; -vector> sketches; -vector> genome_id_size_pairs; -unordered_map> hash_index; -int count_empty_sketch = 0; -mutex mutex_count_empty_sketch; -vector empty_sketch_ids; -int ** intersectionMatrix; -vector> similars; - - - -vector read_min_hashes(const string& json_filename) { - - // Open the JSON file - ifstream inputFile(json_filename); - - // Check if the file is open - if (!inputFile.is_open()) { - cerr << "Could not open the file!" << endl; - return {}; - } - - // Parse the JSON data - json jsonData; - inputFile >> jsonData; - - // Access and print values - vector min_hashes = jsonData[0]["signatures"][0]["mins"]; - - // Close the file - inputFile.close(); - - return min_hashes; -} - - - - -void read_sketches_one_chunk(int start_index, int end_index) { - for (int i = start_index; i < end_index; i++) { - auto min_hashes_genome_name = read_min_hashes(sketch_names[i]); - sketches[i] = min_hashes_genome_name; - if (sketches[i].size() == 0) { - mutex_count_empty_sketch.lock(); - count_empty_sketch++; - empty_sketch_ids.push_back(i); - mutex_count_empty_sketch.unlock(); - } - genome_id_size_pairs[i] = {i, sketches[i].size()}; - } -} - - - -void read_sketches() { - for (uint i = 0; i < num_sketches; i++) { - sketches.push_back( vector() ); - genome_id_size_pairs.push_back({-1, 0}); - } - - int num_threads = arguments.number_of_threads; - - int chunk_size = num_sketches / num_threads; - vector threads; - for (int i = 0; i < num_threads; i++) { - int start_index = i * chunk_size; - int end_index = (i == num_threads - 1) ? num_sketches : (i + 1) * chunk_size; - threads.push_back(thread(read_sketches_one_chunk, start_index, end_index)); - } - for (int i = 0; i < num_threads; i++) { - threads[i].join(); - } - -} - - -void get_sketch_names(const std::string& filelist) { - // the filelist is a file, where each line is a path to a sketch file - std::ifstream file(filelist); - if (!file.is_open()) { - std::cerr << "Could not open the filelist: " << filelist << std::endl; - return; - } - std::string line; - while (std::getline(file, line)) { - sketch_names.push_back(line); - } - num_sketches = sketch_names.size(); -} - - -void parse_arguments(int argc, char *argv[]) { - - argparse::ArgumentParser parser("yacht train using indexing of sketches"); - - parser.add_argument("file_list") - .help("file containing list of files to be processed") - .store_into(arguments.file_list); - parser.add_argument("working_directory") - .help("working directory (where temp files are generated)") - .store_into(arguments.working_directory); - parser.add_argument("output_filename") - .help("output filename (where the reduced ref filenames will be written)") - .store_into(arguments.output_filename); - parser.add_argument("-t", "--threads") - .help("number of threads") - .scan<'i', int>() - .default_value(1) - .store_into(arguments.number_of_threads); - parser.add_argument("-p", "--passes") - .help("number of passes") - .scan<'i', int>() - .default_value(1) - .store_into(arguments.num_of_passes); - parser.add_argument("-c", "--containment_threshold") - .help("containment threshold") - .scan<'g', double>() - .default_value(0.9) - .store_into(arguments.containment_threshold); - parser.parse_args(argc, argv); - - if (arguments.number_of_threads < 1) { - throw std::runtime_error("number of threads must be at least 1"); - } - - if (arguments.num_of_passes < 1) { - throw std::runtime_error("number of passes must be at least 1"); - } - - if (arguments.containment_threshold < 0.0 || arguments.containment_threshold > 1.0) { - throw std::runtime_error("containment threshold must be between 0.0 and 1.0"); - } - -} - - -void show_arguments() { - cout << "Working with the following parameters:" << endl; - cout << "**************************************" << endl; - cout << "*" << endl; - cout << "* file_list: " << arguments.file_list << endl; - cout << "* working_directory: " << arguments.working_directory << endl; - cout << "* output_filename: " << arguments.output_filename << endl; - cout << "* number_of_threads: " << arguments.number_of_threads << endl; - cout << "* num_of_passes: " << arguments.num_of_passes << endl; - cout << "* containment_threshold: " << arguments.containment_threshold << endl; - cout << "*" << endl; - cout << "**************************************" << endl; -} - - -void show_empty_sketches() { - cout << "Number of empty sketches: " << count_empty_sketch << endl; - if (count_empty_sketch == 0) { - return; - } - cout << "Empty sketch ids: "; - for (int i : empty_sketch_ids) { - cout << i << " "; - } - cout << endl; -} - - -void compute_index_from_sketches() { - // create the index using all the hashes - for (uint i = 0; i < sketches.size(); i++) { - for (uint j = 0; j < sketches[i].size(); j++) { - hash_t hash = sketches[i][j]; - if (hash_index.find(hash) == hash_index.end()) { - hash_index[hash] = vector(); - } - hash_index[hash].push_back(i); - } - } - - size_t num_hashes = hash_index.size(); - - // remove the hashes that only appear in one sketch - vector hashes_to_remove; - for (auto it = hash_index.begin(); it != hash_index.end(); it++) { - if (it->second.size() == 1) { - hashes_to_remove.push_back(it->first); - } - } - for (uint i = 0; i < hashes_to_remove.size(); i++) { - hash_index.erase(hashes_to_remove[i]); - } - - size_t num_hashes_after_removal = hash_index.size(); - - cout << "Total number of distinct hashes: " << num_hashes << endl; - cout << "Total number of distinct hashes that appear in only one sketch: " << num_hashes - num_hashes_after_removal << endl; - cout << "Size of the index: " << num_hashes_after_removal << endl; - -} - - -void compute_intersection_matrix_by_sketches(int sketch_start_index, int sketch_end_index, int thread_id, string out_dir, int pass_id, int negative_offset) { - - // process the sketches in the range [sketch_start_index, sketch_end_index) - for (uint i = sketch_start_index; i < sketch_end_index; i++) { - for (int j = 0; j < sketches[i].size(); j++) { - hash_t hash = sketches[i][j]; - if (hash_index.find(hash) != hash_index.end()) { - vector sketch_indices = hash_index[hash]; - for (uint k = 0; k < sketch_indices.size(); k++) { - intersectionMatrix[i-negative_offset][sketch_indices[k]]++; - } - } - } - } - - // write the similarity values to file. filename: out_dir/passid_threadid.txt, where id is thread id in 3 digits - string id_in_three_digits_str = to_string(thread_id); - while (id_in_three_digits_str.size() < 3) { - id_in_three_digits_str = "0" + id_in_three_digits_str; - } - string pass_id_str = to_string(pass_id); - string filename = out_dir + "/" + pass_id_str + "_" + id_in_three_digits_str + ".txt"; - ofstream outfile(filename); - - // only write the values if larger than the threshold - for (int i = sketch_start_index; i < sketch_end_index; i++) { - for (uint j = 0; j < num_sketches; j++) { - // skip obvious cases - if ((uint)i == (uint)j) { - continue; - } - - // if nothing in the intersection, then skip - if (intersectionMatrix[i-negative_offset][j] == 0) { - continue; - } - - // if either of the sketches is empty, then skip - if (sketches[i].size() == 0 || sketches[j].size() == 0) { - continue; - } - - // if the divisor in the jaccard calculation is 0, then skip - if (sketches[i].size() + sketches[j].size() - intersectionMatrix[i-negative_offset][j] == 0) { - continue; - } - - double jaccard = 1.0 * intersectionMatrix[i-negative_offset][j] / ( sketches[i].size() + sketches[j].size() - intersectionMatrix[i-negative_offset][j] ); - double containment_i_in_j = 1.0 * intersectionMatrix[i-negative_offset][j] / sketches[i].size(); - double containment_j_in_i = 1.0 * intersectionMatrix[i-negative_offset][j] / sketches[j].size(); - - // containment_i_in_j is the containment of query in target, i is the query - if (containment_i_in_j < arguments.containment_threshold) { - continue; - } - - outfile << i << "," << j << "," << jaccard << "," << containment_i_in_j << "," << containment_j_in_i << endl; - similars[i].push_back(j); - } - } - - outfile.close(); - -} - - - -void compute_intersection_matrix() { - // allocate memory for the intersection matrix - int num_sketches_each_pass = ceil(1.0 * num_sketches / arguments.num_of_passes); - intersectionMatrix = new int*[num_sketches_each_pass + 1]; - for (int i = 0; i < num_sketches_each_pass + 1; i++) { - intersectionMatrix[i] = new int[num_sketches]; - } - - // allocate memory for the similars - for (int i = 0; i < num_sketches; i++) { - similars.push_back(vector()); - } - - for (int pass_id = 0; pass_id < arguments.num_of_passes; pass_id++) { - // set zeros in the intersection matrix - for (int i = 0; i < num_sketches_each_pass+1; i++) { - for (uint j = 0; j < num_sketches; j++) { - intersectionMatrix[i][j] = 0; - } - } - - // prepare the indices which will be processed in this pass - int sketch_idx_start_this_pass = pass_id * num_sketches_each_pass; - int sketch_idx_end_this_pass = (pass_id == arguments.num_of_passes - 1) ? num_sketches : (pass_id + 1) * num_sketches_each_pass; - int negative_offset = pass_id * num_sketches_each_pass; - int num_sketches_this_pass = sketch_idx_end_this_pass - sketch_idx_start_this_pass; - - // create threads - vector threads; - int chunk_size = num_sketches_this_pass / arguments.number_of_threads; - for (int i = 0; i < arguments.number_of_threads; i++) { - int start_index_this_thread = sketch_idx_start_this_pass + i * chunk_size; - int end_index_this_thread = (i == arguments.number_of_threads - 1) ? sketch_idx_end_this_pass : sketch_idx_start_this_pass + (i + 1) * chunk_size; - threads.push_back( thread(compute_intersection_matrix_by_sketches, start_index_this_thread, end_index_this_thread, i, arguments.working_directory, pass_id, negative_offset) ); - } - - // join threads - for (int i = 0; i < arguments.number_of_threads; i++) { - threads[i].join(); - } - - // show progress - std::cout << "Pass " << pass_id+1 << "/" << arguments.num_of_passes << " done." << std::endl; - } - - // free the memory allocated for the intersection matrix - for (int i = 0; i < num_sketches_each_pass + 1; i++) { - delete[] intersectionMatrix[i]; - } - delete[] intersectionMatrix; -} - - - - -void do_yacht_train() { - - cout << "Starting yacht train..." << endl; - - vector selected_genome_ids; - vector genome_id_to_exclude(num_sketches, false); - - // sort the genome ids by size - sort(genome_id_size_pairs.begin(), genome_id_size_pairs.end(), [](const pair& a, const pair& b) { - return a.second < b.second; - }); - - for (int i = 0; i < num_sketches; i++) { - - cout << "Processing " << i << "..." << '\r'; - int genome_id_this = genome_id_size_pairs[i].first; - int size_this = genome_id_size_pairs[i].second; - bool select_this = true; - - // show my size - for (int j = 0; j < similars[genome_id_this].size(); j++) { - int genome_id_other = similars[genome_id_this][j]; - if (genome_id_to_exclude[genome_id_other]) { - continue; - } - int size_other = sketches[genome_id_other].size(); - if (size_other >= size_this) { - select_this = false; - break; - } - } - if (select_this) { - selected_genome_ids.push_back(genome_id_this); - } else { - genome_id_to_exclude[genome_id_this] = true; - } - } - - cout << "Writing to output file.." << endl; - - // write the selected genome ids to file - ofstream outfile(arguments.output_filename); - for (int i = 0; i < selected_genome_ids.size(); i++) { - int genome_id = selected_genome_ids[i]; - string sketch_name = sketch_names[genome_id]; - outfile << sketch_name << endl; - } - outfile.close(); - -} - - - - -int main(int argc, char *argv[]) { - - // ********************************************************* - // ***** parse command line arguments ****** - // ********************************************************* - try { - parse_arguments(argc, argv); - } catch (const std::runtime_error &e) { - std::cerr << e.what() << std::endl; - cout << "Usage: " << argv[0] << " -h" << endl; - return 1; - } - - // show the arguments - show_arguments(); - - // ********************************************************* - // ************ read the input sketches ************ - // ********************************************************* - auto read_start = chrono::high_resolution_clock::now(); - cout << "Reading all sketches in filelist using all " << arguments.number_of_threads << " threads..." << endl; - get_sketch_names(arguments.file_list); - cout << "Total number of sketches to read: " << num_sketches << endl; - read_sketches(); - auto read_end = chrono::high_resolution_clock::now(); - - cout << "All sketches read" << endl; - - // show empty sketches - show_empty_sketches(); - - // show time taken to read all sketches - auto read_duration = chrono::duration_cast(read_end - read_start); - cout << "Time taken to read all sketches: " << read_duration.count() << " milliseconds" << endl; - - - - - // **************************************************************** - // ************* reading complete, now creating index ************* - // **************************************************************** - auto index_build_start = chrono::high_resolution_clock::now(); - cout << "Building index from sketches..." << endl; - compute_index_from_sketches(); - auto index_build_end = chrono::high_resolution_clock::now(); - auto index_build_duration = chrono::duration_cast(index_build_end - index_build_start); - cout << "Time taken to build index: " << index_build_duration.count() << " milliseconds" << endl; - - - - // ********************************************************************** - // ************* compute intersection matrix ************* - // ********************************************************************** - auto mat_computation_start = chrono::high_resolution_clock::now(); - cout << "Computing intersection matrix..." << endl; - compute_intersection_matrix(); - auto mat_computation_end = chrono::high_resolution_clock::now(); - auto mat_computation_duration = chrono::duration_cast(mat_computation_end - mat_computation_start); - cout << "Time taken to compute intersection matrix: " << mat_computation_duration.count() << " milliseconds" << endl; - - - - // ********************************************************************** - // ************* yacht train ************* - // ********************************************************************** - auto yacht_train_start = chrono::high_resolution_clock::now(); - cout << "Starting yacht train..." << endl; - do_yacht_train(); - auto yacht_train_end = chrono::high_resolution_clock::now(); - auto yacht_train_duration = chrono::duration_cast(yacht_train_end - yacht_train_start); - cout << "Time taken to do yacht train: " << yacht_train_duration.count() << " milliseconds" << endl; - - - return 0; -} \ No newline at end of file diff --git a/src/cpp/utils.cpp b/src/cpp/utils.cpp new file mode 100644 index 00000000..65e09220 --- /dev/null +++ b/src/cpp/utils.cpp @@ -0,0 +1,270 @@ +#include "utils.h" + +std::vector read_min_hashes(const std::string& json_filename) { + + // Open the JSON file + std::ifstream inputFile(json_filename); + + // Check if the file is open + if (!inputFile.is_open()) { + std::cerr << "Could not open the file!" << std::endl; + return {}; + } + + // Parse the JSON data + json jsonData; + inputFile >> jsonData; + + // Access and print values + std::vector min_hashes = jsonData[0]["signatures"][0]["mins"]; + + // Close the file + inputFile.close(); + + return min_hashes; +} + + +void compute_index_from_sketches(std::vector>& sketches, std::unordered_map>& hash_index) { + // create the index using all the hashes + for (uint i = 0; i < sketches.size(); i++) { + for (uint j = 0; j < sketches[i].size(); j++) { + hash_t hash = sketches[i][j]; + if (hash_index.find(hash) == hash_index.end()) { + hash_index[hash] = std::vector(); + } + hash_index[hash].push_back(i); + } + } + + size_t num_hashes = hash_index.size(); + + // cannot remove hashes that are in only one sketch, because + // we do not have an index for the query sketches + +} + + + +void get_sketch_paths(const std::string& filelist, std::vector& sketch_paths) { + // the filelist is a file, where each line is a path to a sketch file + std::ifstream file(filelist); + if (!file.is_open()) { + std::cerr << "Could not open the filelist: " << filelist << std::endl; + return; + } + std::string line; + while (std::getline(file, line)) { + sketch_paths.push_back(line); + } +} + + + +void read_sketches_one_chunk(int start_index, int end_index, + std::vector& sketch_paths, + std::vector>& sketches, + std::mutex& mutex_count_empty_sketch, + std::vector& empty_sketch_ids) { + + for (int i = start_index; i < end_index; i++) { + auto min_hashes = read_min_hashes(sketch_paths[i]); + sketches[i] = min_hashes; + if (sketches[i].size() == 0) { + mutex_count_empty_sketch.lock(); + empty_sketch_ids.push_back(i); + mutex_count_empty_sketch.unlock(); + } + } +} + + + +void read_sketches(std::vector& sketch_paths, + std::vector>& sketches, + std::vector& empty_sketch_ids, + const uint num_threads) { + + uint num_sketches = sketch_paths.size(); + for (uint i = 0; i < num_sketches; i++) { + sketches.push_back( std::vector() ); + } + + std::mutex mutex_count_empty_sketch; + int chunk_size = num_sketches / num_threads; + std::vector threads; + for (int i = 0; i < num_threads; i++) { + int start_index = i * chunk_size; + int end_index = (i == num_threads - 1) ? num_sketches : (i + 1) * chunk_size; + threads.push_back(std::thread(read_sketches_one_chunk, + start_index, end_index, + std::ref(sketch_paths), std::ref(sketches), + std::ref(mutex_count_empty_sketch), + std::ref(empty_sketch_ids))); + } + for (int i = 0; i < num_threads; i++) { + threads[i].join(); + } + +} + + + + +void show_empty_sketches(const std::vector& empty_sketch_ids) { + int count_empty_sketch = empty_sketch_ids.size(); + std::cout << "Number of empty sketches: " << count_empty_sketch << std::endl; + if (count_empty_sketch == 0) { + return; + } + std::cout << "Empty sketch ids: "; + for (int i : empty_sketch_ids) { + std::cout << i << " "; + } + std::cout << std::endl; +} + + + + + +void compute_intersection_matrix_by_sketches(int query_sketch_start_index, int query_sketch_end_index, + int thread_id, std::string out_dir, + int pass_id, int negative_offset, + const std::vector>& sketches_query, + const std::vector>& sketches_ref, + const std::unordered_map>& hash_index_ref, + int** intersectionMatrix, + double containment_threshold, + std::vector>& similars) { + + const int num_sketches_ref = sketches_ref.size(); + const int num_sketches_query = sketches_query.size(); + + // process the sketches in the range [sketch_start_index, sketch_end_index) + for (uint i = query_sketch_start_index; i < query_sketch_end_index; i++) { + for (int j = 0; j < sketches_query[i].size(); j++) { + hash_t hash = sketches_query[i][j]; + if (hash_index_ref.find(hash) != hash_index_ref.end()) { + std::vector ref_sketch_indices = hash_index_ref.at(hash); + for (uint k = 0; k < ref_sketch_indices.size(); k++) { + intersectionMatrix[i-negative_offset][ref_sketch_indices[k]]++; + } + } + } + } + + // write the similarity values to file. filename: out_dir/passid_threadid.txt, where id is thread id in 3 digits + std::string id_in_three_digits_str = std::to_string(thread_id); + while (id_in_three_digits_str.size() < 3) { + id_in_three_digits_str = "0" + id_in_three_digits_str; + } + std::string pass_id_str = std::to_string(pass_id); + std::string filename = out_dir + "/" + pass_id_str + "_" + id_in_three_digits_str + ".txt"; + std::ofstream outfile(filename); + + // only write the values if larger than the threshold + for (int i = query_sketch_start_index; i < query_sketch_end_index; i++) { + for (uint j = 0; j < num_sketches_ref; j++) { + // if nothing in the intersection, then skip + if (intersectionMatrix[i-negative_offset][j] == 0) { + continue; + } + + // if either of the sketches is empty, then skip + if (sketches_query[i].size() == 0 || sketches_ref[j].size() == 0) { + continue; + } + + // if the divisor in the jaccard calculation is 0, then skip + if (sketches_query[i].size() + sketches_ref[j].size() - intersectionMatrix[i-negative_offset][j] == 0) { + continue; + } + + double jaccard = 1.0 * intersectionMatrix[i-negative_offset][j] / ( sketches_query[i].size() + sketches_ref[j].size() - intersectionMatrix[i-negative_offset][j] ); + double containment_i_in_j = 1.0 * intersectionMatrix[i-negative_offset][j] / sketches_query[i].size(); + double containment_j_in_i = 1.0 * intersectionMatrix[i-negative_offset][j] / sketches_ref[j].size(); + + // containment_i_in_j is the containment of query in target, i is the query + if (containment_i_in_j < containment_threshold) { + continue; + } + + outfile << i << "," << j << "," << jaccard << "," << containment_i_in_j << "," << containment_j_in_i << std::endl; + similars[i].push_back(j); + } + } + + outfile.close(); + +} + + + +void compute_intersection_matrix(const std::vector>& sketches_query, + const std::vector>& sketches_ref, + const std::unordered_map>& hash_index_ref, + const std::string& out_dir, + std::vector>& similars, + double containment_threshold, + const int num_passes, const int num_threads) { + + int num_sketches_query = sketches_query.size(); + int num_sketches_ref = sketches_ref.size(); + + // allocate memory for the intersection matrix + int num_query_sketches_each_pass = ceil(1.0 * num_sketches_query / num_passes); + int** intersectionMatrix = new int*[num_query_sketches_each_pass + 1]; + for (int i = 0; i < num_query_sketches_each_pass + 1; i++) { + intersectionMatrix[i] = new int[num_sketches_ref]; + } + + // allocate memory for the similars + for (int i = 0; i < num_sketches_query; i++) { + similars.push_back(std::vector()); + } + + for (int pass_id = 0; pass_id < num_passes; pass_id++) { + // set zeros in the intersection matrix + for (int i = 0; i < num_query_sketches_each_pass+1; i++) { + for (uint j = 0; j < num_sketches_ref; j++) { + intersectionMatrix[i][j] = 0; + } + } + + // prepare the indices which will be processed in this pass + int sketch_idx_start_this_pass = pass_id * num_query_sketches_each_pass; + int sketch_idx_end_this_pass = (pass_id == num_passes - 1) ? num_sketches_query : (pass_id + 1) * num_query_sketches_each_pass; + int negative_offset = pass_id * num_query_sketches_each_pass; + int num_query_sketches_this_pass = sketch_idx_end_this_pass - sketch_idx_start_this_pass; + + // create threads + std::vector threads; + int chunk_size = num_query_sketches_this_pass / num_threads; + for (int i = 0; i < num_threads; i++) { + int start_query_index_this_thread = sketch_idx_start_this_pass + i * chunk_size; + int end_query_index_this_thread = (i == num_threads - 1) ? sketch_idx_end_this_pass : sketch_idx_start_this_pass + (i + 1) * chunk_size; + threads.push_back(std::thread(compute_intersection_matrix_by_sketches, + start_query_index_this_thread, end_query_index_this_thread, + i, out_dir, pass_id, negative_offset, + std::ref(sketches_query), std::ref(sketches_ref), + std::ref(hash_index_ref), intersectionMatrix, + containment_threshold, + std::ref(similars))); + } + + // join threads + for (int i = 0; i < num_threads; i++) { + threads[i].join(); + } + + // show progress + std::cout << "Pass " << pass_id+1 << "/" << num_passes << " done." << std::endl; + } + + // free the memory allocated for the intersection matrix + for (int i = 0; i < num_query_sketches_each_pass + 1; i++) { + delete[] intersectionMatrix[i]; + } + delete[] intersectionMatrix; +} \ No newline at end of file diff --git a/src/cpp/utils.h b/src/cpp/utils.h new file mode 100644 index 00000000..432f04c0 --- /dev/null +++ b/src/cpp/utils.h @@ -0,0 +1,117 @@ +#ifndef UTILS_H +#define UTILS_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "json.hpp" + +using json = nlohmann::json; + +typedef unsigned long long int hash_t; + +/** + * @brief Read the min-hashes from a FMH sketch file + * + * Assumption: the file is a json file, and its not gzipped + * + * @param sketch_path The path to the sketch file + */ +std::vector read_min_hashes(const std::string& sketch_path); + + + + + + +/** + * @brief Compute the index from the sketches + * + * @param sketches The sketches + * @param hash_index The reference to the hash index (where the index will be stored) + */ +void compute_index_from_sketches(std::vector>& sketches, std::unordered_map>& hash_index); + + + + + + +/** + * @brief Get the sketch paths + * + * @param filelist The file containing the paths of the sketches + * @param sketch_paths The vector to store the paths + */ +void get_sketch_paths(const std::string& filelist, std::vector& sketch_paths); + + + + + + +/** + * @brief Read the sketches from the sketch paths + * + * @param sketch_paths The paths to the sketches + * @param sketches The vector to store the sketches + * @param empty_sketch_ids The vector to store the ids of empty sketches + * @param num_threads The number of threads to use + */ +void read_sketches(std::vector& sketch_paths, + std::vector>& sketches, + std::vector& empty_sketch_ids, + const uint num_threads); + + + + + +/** + * @brief Show the empty sketches + * + * @param empty_sketch_ids The ids of the empty sketches + */ +void show_empty_sketches(const std::vector&); + + + + +/** + * @brief Compute the intersection matrix + * + * @param sketches_query The query sketches + * @param sketches_ref The reference (target) sketches + * @param hash_index_ref The index of the reference (target) sketches + * @param out_dir The output directory to store the results + * @param similars The vector to store the similar sketches + * @param containment_threshold The containment threshold + * @param num_passes The number of passes to use + * @param num_threads The number of threads to use + */ +void compute_intersection_matrix(const std::vector>& sketches_query, + const std::vector>& sketches_ref, + const std::unordered_map>& hash_index_ref, + const std::string& out_dir, + std::vector>& similars, + double containment_threshold, + const int num_passes, const int num_threads); + +#endif \ No newline at end of file diff --git a/src/cpp/yacht_run_compute_similarity.cpp b/src/cpp/yacht_run_compute_similarity.cpp new file mode 100644 index 00000000..6dd04a67 --- /dev/null +++ b/src/cpp/yacht_run_compute_similarity.cpp @@ -0,0 +1,213 @@ +/* + * Author: Mahmudur Rahman Hera (mahmudhera93@gmail.com) + * Date: November 1, 2024 + * Description: This code reads the query and target sketches from the files, builds an index from the target sketches, and computes the similarity matrix. + * All query vs all target pairs are written if containment(query,target) >= provided threshold. + * + * Output files are written in the output directory. Many output files are + * generated, in the form a_bcd.txt, where a is the pass id, and bcd is the thread id. + * By default, the output files are not combined. If you want to combine the output files, + * use the -C flag. The combined output file will be written to the output filename provided. + * Each line in the output file contains the query and target sketch ids, and the similarity value. + * A typical line in the output file looks like this: 12 34 0.2 0.3 0.4 + * This means that the (12+1)-th query sketch is similar to the (34+1)-th target sketch, + * and the Jaccard, containment(query,target), and containment(target,query) values are 0.2, 0.3, and 0.4. + */ + +#include "argparse.hpp" +#include "json.hpp" +#include "utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +using json = nlohmann::json; + +struct Arguments { + string file_list_query; + string file_list_target; + string output_directory; + int number_of_threads; + int num_of_passes; + double containment_threshold; + bool combine; + string combined_output_filename; +}; + + +typedef Arguments Arguments; +typedef unsigned long long int hash_t; + + +void parse_arguments(int argc, char *argv[], Arguments &arguments) { + + argparse::ArgumentParser parser("compute similarity of targets with queries"); + + parser.add_description("This code reads the query and target sketches from the files, builds an index from the target sketches, and computes the similarity matrix.\n" + "All query vs all target pairs are written if containment(query,target) >= provided threshold.\n" + "Output files are written in the output directory. Many output files are generated, in the form a_bcd.txt, where a is the pass id, and bcd is the thread id.\n" + "By default, the output files are not combined. If you want to combine the output files, use the -C flag. The combined output file will be written to the output filename provided.\n" + "Each line in the output file contains the query and target sketch ids, and the similarity value.\n" + "A typical line in the output file looks like this: 12,34,0.2,0.3,0.4\n" + "This means that the (12+1)-th query sketch is similar to the (34+1)-th target sketch, and the Jaccard, containment(query,target), and containment(target,query) values are 0.2, 0.3, and 0.4."); + + parser.add_argument("file_list_query") + .help("file containing paths of query sketches") + .store_into(arguments.file_list_query); + + parser.add_argument("file_list_target") + .help("file containing paths of target sketches") + .store_into(arguments.file_list_target); + + parser.add_argument("output_directory") + .help("output directory (where similarity values will be written)") + .store_into(arguments.output_directory); + + parser.add_argument("-t", "--threads") + .help("number of threads") + .scan<'i', int>() + .default_value(1) + .store_into(arguments.number_of_threads); + + parser.add_argument("-p", "--passes") + .help("number of passes") + .scan<'i', int>() + .default_value(1) + .store_into(arguments.num_of_passes); + + parser.add_argument("-c", "--containment_threshold") + .help("containment threshold") + .scan<'g', double>() + .default_value(0.9) + .store_into(arguments.containment_threshold); + + // argument: combine files (store true if present) + parser.add_argument("-C", "--combine") + .help("combine the output files") + .default_value(false) + .implicit_value(true) + .store_into(arguments.combine); + + // argument: combined output filename + parser.add_argument("-o", "--output") + .help("output filename (where the combined output will be written. Not used if -C is not present)") + .default_value("combined_output.txt") + .store_into(arguments.combined_output_filename); + + parser.parse_args(argc, argv); + + if (arguments.number_of_threads < 1) { + throw std::runtime_error("number of threads must be at least 1"); + } + + if (arguments.num_of_passes < 1) { + throw std::runtime_error("number of passes must be at least 1"); + } + + if (arguments.containment_threshold < 0.0 || arguments.containment_threshold > 1.0) { + throw std::runtime_error("containment threshold must be between 0.0 and 1.0"); + } + +} + + +void show_arguments(Arguments& args) { + cout << "**************************************" << endl; + cout << "*" << endl; + cout << "* file_list_query: " << args.file_list_query << endl; + cout << "* file_list_target: " << args.file_list_target << endl; + cout << "* output_directory: " << args.output_directory << endl; + cout << "* number_of_threads: " << args.number_of_threads << endl; + cout << "* num_of_passes: " << args.num_of_passes << endl; + cout << "* containment_threshold: " << args.containment_threshold << endl; + cout << "* combine: " << bool(args.combine) << endl; + cout << "* combined_output_filename: " << args.combined_output_filename << endl; + cout << "*" << endl; + cout << "**************************************" << endl; +} + + + +int main(int argc, char** argv) { + Arguments args; + + // parse the arguments + try { + parse_arguments(argc, argv, args); + } catch (const std::runtime_error &e) { + std::cerr << e.what() << std::endl; + cout << "Usage: " << argv[0] << " -h" << endl; + return 1; + } + + // show the arguments + show_arguments(args); + + // read the query sketches + cout << "Reading query sketches..." << endl; + vector sketch_paths_query; + vector> sketches_query; + vector empty_sketch_ids_query; + get_sketch_paths(args.file_list_query, sketch_paths_query); + read_sketches(sketch_paths_query, sketches_query, empty_sketch_ids_query, args.number_of_threads); + cout << "All query sketches read" << endl; + + // read the target sketches + cout << "Reading target sketches..." << endl; + vector sketch_paths_target; + vector> sketches_target; + vector empty_sketch_ids_target; + get_sketch_paths(args.file_list_target, sketch_paths_target); + read_sketches(sketch_paths_target, sketches_target, empty_sketch_ids_target, args.number_of_threads); + + // show empty sketches + cout << "Empty sketches in query:" << endl; + show_empty_sketches(empty_sketch_ids_query); + cout << "Empty sketches in target:" << endl; + show_empty_sketches(empty_sketch_ids_target); + + // compute the index from the target sketches + cout << "Building index from target sketches..." << endl; + unordered_map> hash_index_target; + compute_index_from_sketches(sketches_target, hash_index_target); + + // compute the similarity matrix + cout << "Computing similarity matrix..." << endl; + vector> similars; + compute_intersection_matrix(sketches_query, sketches_target, + hash_index_target, + args.output_directory, similars, + args.containment_threshold, + args.num_of_passes, + args.number_of_threads); + + cout << "similarity computation completed, results are here: " << args.output_directory << endl; + + if (args.combine) { + cout << "Combining the output files..." << endl; + // cat all the files in the output directory + string command = "cat " + args.output_directory + "/*.txt > " + args.combined_output_filename; + system(command.c_str()); + cout << "Combined output written to: " << args.combined_output_filename << endl; + } + + return 0; +} \ No newline at end of file diff --git a/src/cpp/yacht_train_core.cpp b/src/cpp/yacht_train_core.cpp new file mode 100644 index 00000000..5b7f3cbd --- /dev/null +++ b/src/cpp/yacht_train_core.cpp @@ -0,0 +1,268 @@ +/* + * Author: Mahmudur Rahman Hera (mahmudhera93@gmail.com) + * Date: November 1, 2024 + * Description: yacht train core using indexing of sketches to do genome comparison + */ + +#include "argparse.hpp" +#include "json.hpp" +#include "utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +using json = nlohmann::json; + + +struct Arguments { + string file_list; + string working_directory; + string output_filename; + int number_of_threads; + int num_of_passes; + double containment_threshold; +}; + + +typedef Arguments Arguments; +typedef unsigned long long int hash_t; + + + +void parse_arguments(int argc, char *argv[], Arguments &arguments) { + + argparse::ArgumentParser parser("yacht train using indexing of sketches"); + + parser.add_argument("file_list") + .help("file containing list of files to be processed") + .store_into(arguments.file_list); + parser.add_argument("working_directory") + .help("working directory (where temp files are generated)") + .store_into(arguments.working_directory); + parser.add_argument("output_filename") + .help("output filename (where the reduced ref filenames will be written)") + .store_into(arguments.output_filename); + parser.add_argument("-t", "--threads") + .help("number of threads") + .scan<'i', int>() + .default_value(1) + .store_into(arguments.number_of_threads); + parser.add_argument("-p", "--passes") + .help("number of passes") + .scan<'i', int>() + .default_value(1) + .store_into(arguments.num_of_passes); + parser.add_argument("-c", "--containment_threshold") + .help("containment threshold") + .scan<'g', double>() + .default_value(0.9) + .store_into(arguments.containment_threshold); + parser.parse_args(argc, argv); + + if (arguments.number_of_threads < 1) { + throw std::runtime_error("number of threads must be at least 1"); + } + + if (arguments.num_of_passes < 1) { + throw std::runtime_error("number of passes must be at least 1"); + } + + if (arguments.containment_threshold < 0.0 || arguments.containment_threshold > 1.0) { + throw std::runtime_error("containment threshold must be between 0.0 and 1.0"); + } + +} + + +void show_arguments(Arguments& arguments) { + cout << "Working with the following parameters:" << endl; + cout << "**************************************" << endl; + cout << "*" << endl; + cout << "* file_list: " << arguments.file_list << endl; + cout << "* working_directory: " << arguments.working_directory << endl; + cout << "* output_filename: " << arguments.output_filename << endl; + cout << "* number_of_threads: " << arguments.number_of_threads << endl; + cout << "* num_of_passes: " << arguments.num_of_passes << endl; + cout << "* containment_threshold: " << arguments.containment_threshold << endl; + cout << "*" << endl; + cout << "**************************************" << endl; +} + + + + +void do_yacht_train(const vector>& sketches, + const vector>& similars, + const vector& sketch_paths, + const string& output_filename) { + + cout << "Starting yacht train..." << endl; + + int num_sketches = sketches.size(); + + vector selected_genome_ids; + vector genome_id_to_exclude(num_sketches, false); + + // sort the genome ids by size + vector> genome_id_size_pairs; + for (int i = 0; i < num_sketches; i++) { + genome_id_size_pairs.push_back(make_pair(i, sketches[i].size())); + } + sort(genome_id_size_pairs.begin(), genome_id_size_pairs.end(), [](const pair& a, const pair& b) { + return a.second < b.second; + }); + + for (int i = 0; i < num_sketches; i++) { + + cout << "Processing " << i << "..." << '\r'; + int genome_id_this = genome_id_size_pairs[i].first; + int size_this = genome_id_size_pairs[i].second; + bool select_this = true; + + // show my size + for (int j = 0; j < similars[genome_id_this].size(); j++) { + int genome_id_other = similars[genome_id_this][j]; + + if (genome_id_other == genome_id_this) { + continue; + } + + if (genome_id_to_exclude[genome_id_other]) { + continue; + } + int size_other = sketches[genome_id_other].size(); + if (size_other >= size_this) { + select_this = false; + break; + } + } + if (select_this) { + selected_genome_ids.push_back(genome_id_this); + } else { + genome_id_to_exclude[genome_id_this] = true; + } + } + + cout << "Writing to output file.." << endl; + + // write the selected sketch paths to file + ofstream outfile(output_filename); + for (int i = 0; i < selected_genome_ids.size(); i++) { + int genome_id = selected_genome_ids[i]; + string sketch_name = sketch_paths[genome_id]; + outfile << sketch_name << endl; + } + outfile.close(); + +} + + + + +int main(int argc, char *argv[]) { + + Arguments arguments; + + std::vector sketch_paths; + vector> sketches; + unordered_map> hash_index; + mutex mutex_count_empty_sketch; + vector empty_sketch_ids; + int ** intersectionMatrix; + vector> similars; + + // ********************************************************* + // ***** parse command line arguments ****** + // ********************************************************* + try { + parse_arguments(argc, argv, arguments); + } catch (const std::runtime_error &e) { + std::cerr << e.what() << std::endl; + cout << "Usage: " << argv[0] << " -h" << endl; + return 1; + } + + // show the arguments + show_arguments(arguments); + + // ********************************************************* + // ************ read the input sketches ************ + // ********************************************************* + auto read_start = chrono::high_resolution_clock::now(); + cout << "Reading all sketches in filelist using all " << arguments.number_of_threads << " threads..." << endl; + get_sketch_paths(arguments.file_list, sketch_paths); + cout << "Total number of sketches to read: " << sketch_paths.size() << endl; + read_sketches(sketch_paths, sketches, empty_sketch_ids, arguments.number_of_threads); + auto read_end = chrono::high_resolution_clock::now(); + + cout << "All sketches read" << endl; + + + // show empty sketches + show_empty_sketches(empty_sketch_ids); + + // show time taken to read all sketches + auto read_duration = chrono::duration_cast(read_end - read_start); + cout << "Time taken to read all sketches: " << read_duration.count() << " milliseconds" << endl; + + + + + // **************************************************************** + // ************* reading complete, now creating index ************* + // **************************************************************** + auto index_build_start = chrono::high_resolution_clock::now(); + cout << "Building index from sketches..." << endl; + compute_index_from_sketches(sketches, hash_index); + auto index_build_end = chrono::high_resolution_clock::now(); + auto index_build_duration = chrono::duration_cast(index_build_end - index_build_start); + cout << "Time taken to build index: " << index_build_duration.count() << " milliseconds" << endl; + + + + // ********************************************************************** + // ************* compute intersection matrix ************* + // ********************************************************************** + auto mat_computation_start = chrono::high_resolution_clock::now(); + cout << "Computing intersection matrix..." << endl; + compute_intersection_matrix(sketches, sketches, hash_index, + arguments.working_directory, similars, + arguments.containment_threshold, arguments.num_of_passes, + arguments.number_of_threads); + auto mat_computation_end = chrono::high_resolution_clock::now(); + auto mat_computation_duration = chrono::duration_cast(mat_computation_end - mat_computation_start); + cout << "Time taken to compute intersection matrix: " << mat_computation_duration.count() << " milliseconds" << endl; + + + + // ********************************************************************** + // ************* yacht train ************* + // ********************************************************************** + auto yacht_train_start = chrono::high_resolution_clock::now(); + cout << "Starting yacht train..." << endl; + do_yacht_train(sketches, similars, sketch_paths, arguments.output_filename); + auto yacht_train_end = chrono::high_resolution_clock::now(); + auto yacht_train_duration = chrono::duration_cast(yacht_train_end - yacht_train_start); + cout << "Time taken to do yacht train: " << yacht_train_duration.count() << " milliseconds" << endl; + + + return 0; +} \ No newline at end of file diff --git a/src/yacht/hypothesis_recovery_src.py b/src/yacht/hypothesis_recovery_src.py index 77a0218d..6b1fdb07 100644 --- a/src/yacht/hypothesis_recovery_src.py +++ b/src/yacht/hypothesis_recovery_src.py @@ -10,6 +10,7 @@ from multiprocessing import Pool import sourmash import glob +import shutil from typing import List, Set, Tuple from .utils import load_signature_with_ksize, decompress_all_sig_files # Configure Loguru logger @@ -24,17 +25,17 @@ sys.stdout, format="{time:YYYY-MM-DD HH:mm:ss} - {level} - {message}", level="INFO" ) +# Set up contants SIG_SUFFIX = ".sig" - +FILE_LOCATION = os.path.dirname(os.path.realpath(__file__)) def get_organisms_with_nonzero_overlap( manifest: pd.DataFrame, sample_file: str, - scale: int, - ksize: int, num_threads: int, path_to_genome_temp_dir: str, path_to_sample_temp_dir: str, + num_genome_threshold: int = 1000000 ) -> List[str]: """ This function runs the sourmash multisearch to find the organisms that have non-zero overlap with the sample. @@ -49,10 +50,9 @@ def get_organisms_with_nonzero_overlap( 'sample_scale_factor', 'min_coverage' :param sample_file: string (path to the sample signature file) - :param scale: int (scale factor) - :param ksize: string (size of kmer) :param num_threads: int (number of threads to use for parallelization) :param path_to_genome_temp_dir: string (path to the genome temporary directory generated by the training step) + :param num_genome_threshold : int (a threshold to detmine the number of passes/block size for the similarity algorithm) :param path_to_sample_temp_dir: string (path to the sample temporary directory) :return: a list of organism names that have non-zero overlap with the sample """ @@ -75,7 +75,7 @@ def get_organisms_with_nonzero_overlap( ) ] ) - sample_sig_file_path = os.path.join(path_to_sample_temp_dir, "sample_sig_file.txt") + sample_sig_file_path = os.path.join(path_to_sample_temp_dir, "sample_sig_file.tsv") sample_sig_file.to_csv(sample_sig_file_path, header=False, index=False) organism_sig_file = pd.DataFrame( @@ -85,26 +85,36 @@ def get_organisms_with_nonzero_overlap( ] ) organism_sig_file_path = os.path.join( - path_to_sample_temp_dir, "organism_sig_file.txt" + path_to_sample_temp_dir, "organism_sig_file.tsv" ) organism_sig_file.to_csv(organism_sig_file_path, header=False, index=False) - # run the sourmash multisearch - cmd = f"sourmash scripts multisearch {sample_sig_file_path} {organism_sig_file_path} -s {scale} -k {ksize} -c {num_threads} -t 0 -o {os.path.join(path_to_sample_temp_dir, 'sample_multisearch_result.csv')}" - logger.info(f"Running sourmash multisearch with command: {cmd}") + # run algorithm to calculate similarity between sample and organisms + total_sig_files = len(organism_sig_file[0]) + if total_sig_files <= num_genome_threshold: + passes = 1 + else: + passes = int(total_sig_files / num_genome_threshold) + 1 + cmd = f"{FILE_LOCATION}/yacht_run_compute_similarity {sample_sig_file_path} {organism_sig_file_path} {path_to_sample_temp_dir} -t {num_threads} -p {passes} -c 0 -C -o {path_to_sample_temp_dir}/sample_comparison_result.csv" + + logger.info(f"Running similarity algorithm with command: {cmd}") exit_code = os.system(cmd) if exit_code != 0: - raise ValueError(f"Error running sourmash multisearch with command: {cmd}") + raise ValueError(f"Error running similarity algorithm with command: {cmd}") + + # remove all split comparison files ("*.txt") but only keep the combined ones + for file in glob.glob(os.path.join(path_to_sample_temp_dir, "*.txt")): + os.remove(file) - # read the multisearch result - multisearch_result = pd.read_csv( - os.path.join(path_to_sample_temp_dir, "sample_multisearch_result.csv"), + # read the similarity algorithm result + comparison_result = pd.read_csv( + os.path.join(path_to_sample_temp_dir, "sample_comparison_result.csv"), sep=",", - header=0, + header=None, ) - multisearch_result = multisearch_result.drop_duplicates().reset_index(drop=True) + comparison_result = comparison_result.drop_duplicates().reset_index(drop=True) - return multisearch_result["match_name"].to_list() + return [manifest['organism_name'].to_list()[i] for i in comparison_result[1].to_list()] def get_exclusive_hashes( @@ -355,8 +365,6 @@ def hypothesis_recovery( nontrivial_organism_names = get_organisms_with_nonzero_overlap( manifest, sample_file, - scale, - ksize, num_threads, path_to_genome_temp_dir, path_to_sample_temp_dir, diff --git a/src/yacht/utils.py b/src/yacht/utils.py index d4026229..81a010b2 100644 --- a/src/yacht/utils.py +++ b/src/yacht/utils.py @@ -22,7 +22,7 @@ FILE_LOCATION = os.path.dirname(os.path.realpath(__file__)) # Set up global variables -__version__ = "1.3.0" +__version__ = "1.4.0" GITHUB_API_URL = "https://api.github.com/repos/KoslickiLab/YACHT/contents/demo/{path}" GITHUB_RAW_URL = "https://raw.githubusercontent.com/KoslickiLab/YACHT/main/demo/{path}" BASE_URL = "https://farm.cse.ucdavis.edu/~ctbrown/sourmash-db/" @@ -119,6 +119,8 @@ def run_yacht_train_core( :param ani_thresh: float (threshold for ANI, below which we consider two organisms to be distinct) :param ksize: int (size of kmer) :param path_to_temp_dir: string (path to the folder to store the intermediate files) + :param sig_info_dict: a dictionary mapping signature name to a tuple (md5sum, minhash mean abundance, minhash_hashes_len, minhash scaled, raw file path) + :param num_genome_threshold : int (a threshold to detmine the number of passes/block size for the similarity algorithm) :return: a dataframe containing the selected reference signature information """ @@ -140,16 +142,15 @@ def run_yacht_train_core( passes = 1 else: passes = int(total_sig_files / num_genome_threshold) + 1 - cmd = f"{FILE_LOCATION}/run_yacht_train_core -t {num_threads} -c {containment_thresh} -p {passes} {sig_files_path} {path_to_temp_dir} {os.path.join(path_to_temp_dir, 'selected_result.tsv')}" + cmd = f"{FILE_LOCATION}/yacht_train_core -t {num_threads} -c {containment_thresh} -p {passes} {sig_files_path} {path_to_temp_dir} {os.path.join(path_to_temp_dir, 'selected_result.tsv')}" logger.info(f"Running comparison algorithm with command: {cmd}") exit_code = os.system(cmd) if exit_code != 0: raise ValueError(f"Error running comparison algorithm with command: {cmd}") - # move all split comparison files to a single foldr - os.makedirs(os.path.join(path_to_temp_dir, "comparison_files"), exist_ok=True) + # remove all split comparison files ("*.txt") but only keep the combined ones for file in glob(os.path.join(path_to_temp_dir, "*.txt")): - shutil.move(file, os.path.join(path_to_temp_dir, "comparison_files")) + os.remove(file) # get info from the signature files of selected genomes selected_sig_files = pd.read_csv(os.path.join(path_to_temp_dir, 'selected_result.tsv'), sep="\t", header=None) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..1159d616 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,14 @@ +import pytest + +def pytest_addoption(parser): + parser.addoption( + "--runslow", action="store_true", default=False, help="Run slow tests" + ) + +def pytest_collection_modifyitems(config, items): + if not config.getoption("--runslow"): + # Skip slow tests unless --runslow is specified + skip_slow = pytest.mark.skip(reason="Need --runslow option to run") + for item in items: + if "slow" in item.keywords: + item.add_marker(skip_slow) \ No newline at end of file diff --git a/tests/test_hypothesis_recovery_src.py b/tests/test_hypothesis_recovery_src.py index 903db258..8227d360 100644 --- a/tests/test_hypothesis_recovery_src.py +++ b/tests/test_hypothesis_recovery_src.py @@ -40,6 +40,7 @@ def test_hypothesis_recovery(self, mock_get_exclusive_hashes, mock_get_organisms self.assertIsInstance(result, list) self.assertEqual(len(result), 2) + @patch('yacht.hypothesis_recovery_src.os.remove') @patch('pandas.read_csv') @patch('yacht.hypothesis_recovery_src.os.system') @patch('yacht.hypothesis_recovery_src.os.path.join') @@ -48,17 +49,22 @@ def test_hypothesis_recovery(self, mock_get_exclusive_hashes, mock_get_organisms @patch('yacht.hypothesis_recovery_src.zipfile.ZipFile') @patch('glob.glob') def test_get_organisms_with_nonzero_overlap(self, mock_glob, mock_zipfile, _, mock_os_listdir, - mock_os_path_join, mock_os_system, mock_read_csv): + mock_os_path_join, mock_os_system, mock_read_csv, mock_os_remove): mock_glob.return_value = ['training_sig_file_1.sig'] mock_os_listdir.return_value = ['sig_file'] mock_os_path_join.return_value = 'joined_path' mock_os_system.return_value = 0 - mock_read_csv.return_value = pd.DataFrame({'match_name': ['org1', 'org2']}) + mock_os_remove.return_value = None + # Mock reading CSV with the expected column structure + mock_read_csv.return_value = pd.DataFrame({ + 0: ['dummy_value1', 'dummy_value2'], # Unused in function logic + 1: [0, 1] # Indices that map to `manifest['organism_name']` + }) mock_zip_file_instance = mock_zipfile.return_value.__enter__.return_value mock_zip_file_instance.extractall.return_value = None - result = get_organisms_with_nonzero_overlap(self.mock_manifest, 'sample_file.zip', 10, 31, 4, + result = get_organisms_with_nonzero_overlap(self.mock_manifest, 'sample_file.zip', 4, '/path/to/genome_temp_dir', '/path/to/sample_temp_dir') self.assertIsInstance(result, list) diff --git a/tests/test_y_integration_tests.py b/tests/test_y_integration_tests.py index 851bad1e..d7f2cf8a 100644 --- a/tests/test_y_integration_tests.py +++ b/tests/test_y_integration_tests.py @@ -2,6 +2,7 @@ import tempfile import json from os.path import exists +import pytest cpath = os.path.dirname(os.path.realpath(__file__)) project_path = os.path.join(cpath,'..') @@ -73,12 +74,13 @@ def test_run_yacht(): assert exists('result.xlsx') def test_run_pretrained_ref_db(): - cmd = "yacht download pretrained_ref_db --database gtdb --db_version rs214 --k 31 --ani_thresh 0.9995 --outfolder ./" + cmd = "yacht download pretrained_ref_db --database gtdb --db_version rs214 --k 31 --ani_thresh 0.80 --outfolder ./" res = subprocess.run(cmd, shell=True, check=True) assert res.returncode == 0 +# @pytest.mark.slow def test_run_yacht_pretrained_ref_db(): - cmd = f"yacht run --json ./gtdb-rs214-reps.k31_0.9995_pretrained/gtdb-rs214-reps.k31_0.9995_config.json --sample_file '{project_path}/tests/testdata/sample.sig.zip' --significance 0.99 --num_threads 32 --min_coverage_list 1 0.6 0.2 0.1 --out ./result_pretrained.xlsx" + cmd = f"yacht run --json ./gtdb-rs214-reps.k31_0.80_pretrained/gtdb-rs214-reps.k31_0.80_config.json --sample_file '{project_path}/tests/testdata/sample.sig.zip' --significance 0.99 --num_threads 32 --min_coverage_list 1 0.6 0.2 0.1 --out ./result_pretrained.xlsx" res = subprocess.run(cmd, shell=True, check=True) assert res.returncode == 0