diff --git a/ChangeLog.md b/ChangeLog.md index 53119cb2..6e2d724a 100644 --- a/ChangeLog.md +++ b/ChangeLog.md @@ -1,3 +1,100 @@ +# NEAT v4.4.3 +Major performance and memory overhaul focused on making NEAT viable for large +genomes on supercomputing hardware. No user-visible API changes (other than +chunk size now auto-tuning by default). + +**Benchmark (ecoli 10× coverage, 4 threads, identical configs):** + +| Metric | v4.4.2 | v4.4.3 | Improvement | +|-------------------------|------------|------------|-----------------| +| ecoli SE wall time | 14:55 | 1:35 | 9.4× faster | +| ecoli PE wall time | 14:46 | 1:35 | 9.4× faster | +| ecoli SE total CPU | 3,227 s | 331 s | 9.7× less | +| ecoli PE total CPU | 3,168 s | 338 s | 9.4× less | +| Peak resident memory | 549 MB | 175 MB | 3.1× less | +| Peak heap (memray) | 1.27 GB | 0.32 GB | 4× less | +| Per-worker memory | O(N×cov) | O(1) | bounded | +| `pysam.sort` calls | 2 | 0 | gone | +| BAM correctness | 0.06% dups | strict | fixed | + +**Versus NEAT 2.1 (single-threaded baseline):** +- SE: 12:28 → 1:35 (7.9× faster, 56% less CPU) +- PE: 20:12 → 1:35 (12.8× faster, 72% less CPU) + +**Scale-test (c_elegans 10× coverage, 4 threads, 100 Mb genome — ~7× the +ecoli reference):** + +| Metric | Value | +|-------------------------|---------------------------------------------------------| +| Wall time | 19:16 | +| Total CPU | 4,085 s | +| Peak resident memory | 304 MB | +| BAM records | 6,685,764 | +| BAM sort violations | 0 | +| Stitch step (parallel) | 5.3 s | +| Auto-tuned chunk size | 3.1 Mb (35 chunks) | + +Scaling behavior vs ecoli is ~linear in genome size as expected. The stitch +step is bounded by raw disk I/O via `pysam.cat`, so it stays at single-digit +seconds even as the BAM grows. Per-worker peak RSS is 304 MB ÷ 4 ≈ 76 MB, +which is the reference segment + models — independent of chunk size and +coverage. + +**What changed in the hot path:** +- Vectorized error sampling in `get_sequencing_errors` — replaced a ~150-iteration + per-read Python loop with batched numpy. Eliminated 28M `np.prod` calls per + 185k-read run. +- Vectorized `get_quality_scores` — replaced per-base scalar `rng.normal` with + one batched call. +- Replaced per-read `PairwiseAligner.align()` in `make_cigar` with a direct + walker that builds the CIGAR from known error/mutation positions in O(L). + 99% of reads now skip alignment entirely. +- Rewrote `apply_errors` as a single ascending-position pass. The previous + implementation did one `np.concatenate` and one `MutableSeq` slice/concat + per error — quadratic in errors-per-read. The new pass is linear regardless + of error count. +- Removed redundant `deepcopy(self.reference_segment)` calls in + `convert_masking` and `finalize_read_and_write`. Biopython `Seq` is + immutable; the downstream operations make their own working copies. + +**What changed in the I/O path:** +- Removed both `pysam.sort` calls. Per-worker BAMs are emitted coordinate-sorted + by construction; `pysam.merge` of sorted inputs already produces sorted output. + The final sort allocated a 1 GB buffer that dominated peak memory. +- Replaced `pysam.merge` with `pysam.cat` for the final stitch. cat does a raw + BGZF concatenation (no decompression / re-encode), bounded by raw disk I/O + instead of BGZF rate. At human-30× scale this is the difference between a + multi-hour stitch and a multi-minute one. +- Each chunk now owns a non-overlapping reference range for read1 placement + (`responsibility_length`), enabling the cat-based stitch and eliminating + ~0.06% over-coverage in chunk-overlap regions. +- Streamed FASTQ and BAM records directly to output during read generation. + Workers no longer accumulate `reads_to_write` — per-worker memory is now + bounded by reference segment + models, not by chunk size × coverage. +- Stitch steps (FASTQ concat, VCF dedup, BAM cat) now run concurrently in + threads. On a single-disk system the wall is bounded by the BAM cat alone; + on parallel filesystems the overlap is more pronounced. +- FASTQ stitch is now byte-level: per-chunk gzip streams are concatenated + without decompression / re-encode (concatenated gzip streams form a valid + gzip file per the spec). + +**Defaults and ergonomics:** +- `parallel_block_size` now auto-tunes from genome length and thread count + (target: ~8 chunks per thread). For small bacterial genomes this matches the + old hardcoded 500 kb; for human-scale genomes it produces ~6 Mb chunks + instead of ~500 kb, dramatically reducing stitch overhead. Specify the option + explicitly to override. +- FASTQ output is no longer shuffled; reads come out in the natural sampling + order. Pipe through `seqkit shuffle` if you need a uniform shuffle (documented + in README). +- Added a "Multi-node deployment on HPC clusters" section to the README + showing a SLURM array-job pattern for whole-genome simulation across nodes. + +**Caveats:** +- Several of the vectorization fixes change how the PRNG stream is consumed. + Same seed will produce statistically equivalent reads, but not bit-identical + to v4.4.2. Re-baseline any regression tests that compared exact output. + # NEAT v4.4.2 - Added GC bias modeling to generate reads and a function to create a GC bias model from real data. - Added improvements and efficiency upgrades to generate-reads. diff --git a/README.md b/README.md index 3d414c0f..24f39272 100755 --- a/README.md +++ b/README.md @@ -251,8 +251,25 @@ Features: - Output a BAM file with the 'golden' set of aligned reads. These indicate where each read originated and how it should be aligned with the reference - Create paired tumour/normal datasets using characteristics learned from real tumour data +### Output ordering + +The BAM file NEAT produces is coordinate-sorted (by construction at write time — no separate sort pass is run, which used to allocate ~1 GB of sort buffer at stitch time). + +The FASTQ files are written in the natural fragment-sampling order. This is *roughly* random — fragments are drawn from batched random positions — but is not a strict uniform shuffle. If your downstream tooling assumes a real-sequencer-style shuffle, pipe the FASTQ through [`seqkit shuffle`](https://bioinf.shenwei.me/seqkit/usage/#shuffle): + +```sh +# Single-end +seqkit shuffle reads.fastq.gz -o reads.shuffled.fastq.gz + +# Paired-end — use a shared seed so the two files stay mate-aligned +seqkit shuffle -s 42 reads_r1.fastq.gz -o reads_r1.shuffled.fastq.gz +seqkit shuffle -s 42 reads_r2.fastq.gz -o reads_r2.shuffled.fastq.gz +``` + ### Estimated runtimes +> **Note:** The tables below are from the original NEAT 4.4 (v4.4.0) benchmark. NEAT v4.4.3 is roughly **9× faster on multi-threaded ecoli** and uses **~3× less memory** thanks to the performance work landed in that release. The relative shape of the tables (size scaling, contig vs. size mode tradeoffs) remains accurate, but absolute runtimes should be divided by ~5–10× for v4.4.3+. See ChangeLog v4.4.3 for a detailed before/after table on ecoli and a c_elegans scale-test. + To give users a sense of how long `neat read-simulator` runs may take, we benchmarked NEAT 4.4 on several reference genomes. All runs were paired-end, with read length of 150 bp, coverage of 10, fragment mean of 300 bp, and fragment standard deviation of 50 bp. Runtimes are reported as the average across three unique runs (`Avg. time (ms)`) and the corresponding runtime in minutes. Cells marked with N/A indicate that NEAT was not able to run to completion. Benchmarks were run on a System76 Meerkat with a 13th Gen Intel Core i3-1315U (8 logical cores, up to 4.50 GHz) and 16 GiB RAM, using a 512 GB SSD and Ubuntu 24.04.3 LTS (Linux kernel 6.14). Actual runtimes will vary depending on your hardware. @@ -347,7 +364,7 @@ neat read-simulator \ ``` ### Parallelizing simulation -In this case, you would want to split the contig into blocks, rather than reading by contig. Even in single-threaded mode, this is likely to be faster. The default block size of 500,000 yields results quickly on a variety of datasets and can be easily modified to meet your requirements. +Split the contig into blocks rather than reading by contig. Even in single-threaded mode this is likely to be faster. The chunk size auto-tunes from total genome length and thread count, targeting roughly 8 chunks per thread for good load balancing — on small bacterial genomes you get ~500 kb chunks (similar to NEAT's old hardcoded default), on human-scale genomes you get ~6 Mb chunks (a few hundred total instead of thousands). Specify `parallel_block_size` explicitly if you want to override. Also, we demonstrate the situation where you do not want any logs produced: @@ -361,7 +378,8 @@ fragment_mean: 350 fragment_st_dev: 50 threads: 12 parallel_mode: size -parallel_block_size: 500000 +# parallel_block_size omitted: auto-tuned from genome length and thread count. +# Set explicitly (e.g. parallel_block_size: 1000000) only if you have a reason. ``` Then run with the command: ``` @@ -371,6 +389,61 @@ neat read-simulator \ -o /home/me/simulated_reads/ ``` +### Multi-node deployment on HPC clusters + +NEAT runs on a single node using Python's `multiprocessing`. To use multiple nodes on a supercomputer, dispatch one NEAT job per contig (or contig group) as a job-array element and concatenate the outputs afterwards. Each array task gets its own node and uses all available cores on it. NEAT itself doesn't need to know about the cluster. + +SLURM example with one task per human chromosome (24 tasks): + +```bash +#!/bin/bash +#SBATCH --job-name=neat-array +#SBATCH --array=1-24 +#SBATCH --cpus-per-task=64 +#SBATCH --mem=64G +#SBATCH --time=04:00:00 + +CHROM=$(awk "NR==$SLURM_ARRAY_TASK_ID" chroms.txt) # e.g. chr1, chr2, ... +samtools faidx hg38.fa "$CHROM" > "ref_${CHROM}.fa" # one chromosome per task + +cat > "config_${CHROM}.yml" < all_r1.fastq.gz +cat out/*/chr*_r2.fastq.gz > all_r2.fastq.gz + +# VCFs: use bcftools concat for proper merging +bcftools concat -o all.vcf.gz -Oz out/*/chr*_golden.vcf.gz +bcftools index all.vcf.gz +``` + +This gives you whole-genome simulation in roughly `single-chromosome-time / nodes` wall time. For a 30× human genome with 24 nodes (each 64 cores), a per-chromosome run takes ~10 min and the array finishes in ~10 min wall, plus a few minutes for the final concat. + ### Insert specific variants Simulate a whole genome dataset with only the variants in the provided VCF file using `-v` and setting mutation rate to 0 with `-M`. diff --git a/neat/models/error_models.py b/neat/models/error_models.py index f7036337..37cec861 100644 --- a/neat/models/error_models.py +++ b/neat/models/error_models.py @@ -81,32 +81,27 @@ def get_quality_scores( :return: An array of quality scores. """ if self.uniform_quality_score: - return np.array([self.uniform_quality_score] * length) - else: - if length == model_read_length: - quality_index_map = np.arange(model_read_length) - else: - # This is basically a way to evenly spread the distribution across the number of bases in the read - quality_index_map = np.array( - [max([0, model_read_length * n // length]) for n in range(length)] - ) - - temp_qual_array = [] - for i in quality_index_map: - score = rng.normal( - self.quality_score_probabilities[i][0], - scale=self.quality_score_probabilities[i][1] - ) - # make sure score is in range and an int - score = round(score) - if score > 42: - score = 42 - if score < 1: - score = 1 + return np.full(length, self.uniform_quality_score, dtype=int) - temp_qual_array.append(score) + # Map each position in the read onto a row of the (model_read_length, 2) score + # parameters table. When read length matches the model exactly this is just an + # identity; otherwise we evenly spread the model distribution across the read. + if length == model_read_length: + quality_index_map = np.arange(model_read_length, dtype=np.int64) + else: + quality_index_map = np.maximum( + 0, (model_read_length * np.arange(length, dtype=np.int64)) // length + ) - return np.array(temp_qual_array) + # Batched per-position normal draws — one rng.normal call returning `length` draws + # with element-wise (mean, scale) parameters. Replaces a ~150-iteration Python + # loop calling rng.normal scalar-per-base. Same statistical distribution but the + # PRNG stream is consumed in a different order, so seeded outputs are not + # bit-identical to the prior scalar-loop implementation. + means = self.quality_score_probabilities[quality_index_map, 0] + scales = self.quality_score_probabilities[quality_index_map, 1] + scores = rng.normal(means, scales) + return np.clip(np.rint(scores).astype(int), 1, 42) class MarkovQualityModel: @@ -182,30 +177,34 @@ def get_sequencing_errors( :return: Modified sequence and associated quality scores """ - error_indexes = [] introduced_errors = [] - # pre-compute the error rate for each quality score. This is the inverse of the phred score equation - quality_score_error_rate: dict[int, float] = {x: 10. ** (-x / 10) for x in quality_scores} # The use case here would be someone running a simulation where they want no sequencing errors. # No need to run any loops in this case. if self.average_error == 0: return introduced_errors - else: - i = len(quality_scores) - while len(error_indexes) < num_errors and i > 0: - index = rng.choice(list(range(len(quality_scores)))) - if rng.random() < quality_score_error_rate[quality_scores[index]]: - error_indexes.append(index) - i -= 1 - # Fallback: if quality scores are too high to naturally reach num_errors, - # force errors at positions with at-or-below-median quality scores. - # Using <= so that uniform quality arrays (all scores equal) always make progress. + + n = len(quality_scores) + # Batched rejection sampling: draw n candidate indices and n uniform deviates in two numpy + # calls, then accept the first num_errors candidates where the deviate is below the + # quality-derived error rate. Equivalent in distribution to the per-iteration scalar loop + # but ~150x cheaper in Python overhead. Statistical caveat: this changes the order in + # which the underlying PRNG stream is consumed, so seeded runs are not bit-identical to + # the prior interleaved-draw implementation. + candidate_indices = rng.integers(n, size=n) + candidate_randoms = rng.random(size=n) + rates_at_candidates = 10.0 ** (-quality_scores[candidate_indices].astype(float) / 10.0) + accepted = candidate_indices[candidate_randoms < rates_at_candidates] + error_indexes = accepted[:num_errors].tolist() + + if len(error_indexes) < num_errors: + # Fallback: if quality scores are too high to naturally reach num_errors, force errors + # at positions with at-or-below-median quality scores. Using <= so that uniform + # quality arrays (all scores equal) always make progress. median_score = median(quality_scores) - while len(error_indexes) < num_errors: - index = rng.integers(len(quality_scores)) - if quality_scores[index] <= median_score: - error_indexes.append(index) + eligible = np.flatnonzero(quality_scores <= median_score) + needed = num_errors - len(error_indexes) + error_indexes.extend(rng.choice(eligible, size=needed, replace=True).tolist()) total_indel_length = 0 # To prevent deletion collisions diff --git a/neat/read_simulator/runner.py b/neat/read_simulator/runner.py index 069a98f3..509d6cd4 100644 --- a/neat/read_simulator/runner.py +++ b/neat/read_simulator/runner.py @@ -87,6 +87,26 @@ def read_simulator_runner(config: str, output_dir: str, file_prefix: str): # Use the default value average_error = 0.009228843915252066 + # Auto-tune the chunk size from total genome length and thread count when the user + # left it at the default (parallel_block_size <= 0). Target ~8 chunks per thread — + # enough chunks to balance load across workers when chunk durations vary, few enough + # that the stitch step doesn't drown in file-handle overhead. Clamp to [100 kb, 50 Mb] + # to avoid pathological cases on very small or very large genomes. The FASTA index + # (.fai) gives us per-contig lengths without parsing sequences; pysam.FastaFile + # builds the .fai on first access if needed. + if (options.threads > 1 + and options.parallel_mode == "size" + and options.parallel_block_size <= 0): + with pysam.FastaFile(str(options.reference)) as _fa: + total_bp = sum(_fa.get_reference_length(c) for c in _fa.references) + target_chunks = options.threads * 8 + auto_size = max(100_000, min(50_000_000, total_bp // max(1, target_chunks))) + _LOG.info( + f"Auto-tuned parallel_block_size to {auto_size:,} bp " + f"({total_bp:,} bp genome / {options.threads} threads x 8 chunks/thread)" + ) + options.parallel_block_size = auto_size + # Split file by chunk for parallel analysis or by contig for either parallel or single analysis _LOG.info("Splitting reference...") (splits_files_dict, count, reference_keys_with_lens) = split_main( @@ -165,10 +185,25 @@ def read_simulator_runner(config: str, output_dir: str, file_prefix: str): contig_dict = {contig: contig_list.index(contig) for contig in reference_keys_with_lens.keys()} for contig in splits_files_dict: contig_index = contig_dict[contig] - for ((start, length), splits_file) in splits_files_dict[contig].items(): - block_percentage = length / reference_keys_with_lens[contig] + contig_chunks = list(splits_files_dict[contig].items()) + # Precompute each chunk's non-overlapping responsibility length: the distance from + # this chunk's start to the next chunk's start, or the chunk's full length for the + # last chunk of the contig. Reads with r1.position in [0, responsibility_length) + # belong to this chunk; the trailing overlap region (if any) provides reference + # context for boundary-spanning reads but is owned by the next chunk for placement. + # This is what lets the stitch step be a raw BGZF concatenation rather than a sort. + chunk_responsibility = {} + for i, ((c_start, c_end), _) in enumerate(contig_chunks): + if i + 1 < len(contig_chunks): + next_start = contig_chunks[i + 1][0][0] + chunk_responsibility[(c_start, c_end)] = next_start - c_start + else: + chunk_responsibility[(c_start, c_end)] = c_end - c_start + for ((start, length), splits_file) in contig_chunks: + responsibility_length = chunk_responsibility[(start, length)] + block_percentage = responsibility_length / reference_keys_with_lens[contig] block_errors = errors_per_contig[contig] * block_percentage - estimated_number_of_reads = (length // options.read_len) * options.coverage + estimated_number_of_reads = (responsibility_length // options.read_len) * options.coverage errors_per_read = round(block_errors / estimated_number_of_reads) if errors_per_read < 1.0 and block_errors > 0: # We know we need a few errors, but it's a small number total @@ -209,6 +244,7 @@ def read_simulator_runner(config: str, output_dir: str, file_prefix: str): discard_regions_dict[contig], mutation_rate_dict[contig], errors_per_read, + responsibility_length, ) _LOG.info(f"Completed simulating contig {contig}.") output_files.append((thread_idx, files_written)) @@ -230,6 +266,7 @@ def read_simulator_runner(config: str, output_dir: str, file_prefix: str): thread_discard_regions, thread_mutation_regions, errors_per_read, + responsibility_length, )) thread_idx += 1 diff --git a/neat/read_simulator/single_runner.py b/neat/read_simulator/single_runner.py index 3b6ea72a..dab470c0 100644 --- a/neat/read_simulator/single_runner.py +++ b/neat/read_simulator/single_runner.py @@ -2,10 +2,8 @@ Runner for read-simulator in single-ended mode """ import gzip -import os import pickle -import pysam from Bio import SeqIO, bgzf import logging from pathlib import Path @@ -35,6 +33,7 @@ def read_simulator_single( discard_regions: list, mutation_regions: list, errors_per_read: int, + responsibility_length: int | None = None, ) -> tuple[int, str, ContigVariants, dict[str, Path], ]: """ inputs: @@ -115,7 +114,11 @@ def read_simulator_single( ) if local_options.produce_fastq or local_options.produce_bam: - reads_to_write = generate_reads( + # generate_reads streams FASTQ and (if requested) BAM records directly to the + # output handles on local_output_file_writer. It no longer accumulates Read + # objects per chunk — the coordinate-sorted BAM is emitted inline using a + # bounded min-heap to interleave PE mate records. See generate_reads docstring. + generate_reads( thread_idx, local_seq_record, seq_error_model, @@ -131,33 +134,8 @@ def read_simulator_single( contig_name, contig_index, coords[0], + responsibility_length, ) - if local_options.produce_bam: - # Writing an intermediate bam that is sorted, to make compiling them together at the end easier. - bam_handle = local_output_file_writer.files_to_write[local_output_file_writer.bam] - for read_data in reads_to_write: - read1 = read_data[0] - read2 = read_data[1] - if read1: - local_output_file_writer.write_bam_record( - read1, - contig_index, - bam_handle, - local_options.read_len - ) - if read2: - local_output_file_writer.write_bam_record( - read2, - contig_index, - bam_handle, - local_options.read_len - ) - bam_handle.flush() - bam_handle.close() - sorted_bam = local_output_file_writer.bam.with_suffix(".sorted.bam") - pysam.sort("-@", str(local_options.threads), "-o", str(sorted_bam), str(local_output_file_writer.bam)) - os.rename(str(sorted_bam), str(local_output_file_writer.bam)) - _LOG.info(f"bam for thread {thread_idx} written") if local_options.produce_vcf: write_block_vcf(local_variants, contig_name, block_start, local_ref_index, local_output_file_writer) diff --git a/neat/read_simulator/utils/generate_reads.py b/neat/read_simulator/utils/generate_reads.py index 74485bb1..a8ca23b3 100644 --- a/neat/read_simulator/utils/generate_reads.py +++ b/neat/read_simulator/utils/generate_reads.py @@ -1,3 +1,4 @@ +import heapq import logging import pickle import time @@ -29,6 +30,8 @@ def cover_dataset( options: Options, fragment_model: FragmentLengthModel | None, gc_model: GCBiasModel | None, + *, + responsibility_length: int | None = None, ) -> list: """ Covers a dataset to the desired depth in the paired ended case. This is the main algorithm for creating the reads @@ -38,35 +41,48 @@ def cover_dataset( :param options: The options for the run :param fragment_model: The fragment model used for to generate random fragment lengths :param gc_model: The GC bias model used for fragment selection + :param responsibility_length: The number of bases at the start of `reference` that this + chunk is responsible for placing read1 starts in. Reads still extend into the + trailing overlap region for context. Defaults to len(reference) (the chunk owns + its full reference). For non-final chunks under sub-contig parallelism, this is + the chunk step (chunk_size - overlap) — restricting read1 positions to the + non-overlapping range so that per-chunk BAMs can be byte-concatenated into a + coordinate-sorted whole without re-sorting. """ final_reads = [] span_length = len(reference) + # Number of bases this chunk owns for read1 placement. Defaults to the full chunk. + if responsibility_length is None: + responsibility_length = span_length + # Last valid read1 start position. Bounded by both the chunk's responsibility (so + # reads don't appear in the next chunk's range) and by the requirement that the read + # fit in available reference (so a read of length read_len has its tail in-bounds). + max_start = min(responsibility_length - 1, span_length - options.read_len) # sanity check if span_length / fragment_model.fragment_mean < 5: _LOG.warning("The fragment mean is relatively large compared to the chromosome size. You may need to increase " "standard deviation, or decrease fragment mean, if NEAT cannot complete successfully.") # precompute how many reads we want - # The numerator is the total number of base pair calls needed. - # Divide that by read length gives the number of reads needed + # The numerator is the total number of base pair calls needed. Coverage is scaled by + # responsibility_length, not span_length, so chunks don't over-sample their overlap + # region (which is also covered by the next chunk). if options.paired_ended: # TODO use gc bias to skew this number. Calculate at the runner level. - number_reads = ceil(span_length * options.coverage / (2 * options.read_len)) + number_reads = ceil(responsibility_length * options.coverage / (2 * options.read_len)) else: - number_reads = ceil(span_length * options.coverage / options.read_len) + number_reads = ceil(responsibility_length * options.coverage / options.read_len) if gc_model and not gc_model.is_uniform: # CDF-based sampling for GC bias window_size = gc_model.window_size if span_length <= window_size: # Fallback to uniform if region is too short - return _uniform_sampling(span_length, number_reads, options, fragment_model) + return _uniform_sampling(span_length, number_reads, options, fragment_model, + max_start=max_start) - # Build prefix sum of weights - # We need weights for every possible start position that can produce a read. - # For single-ended, start must be in [0, span_length - read_len] - max_start = span_length - options.read_len + # Build prefix sum of weights only over the positions this chunk owns. if max_start < 0: return [] @@ -109,6 +125,15 @@ def cover_dataset( # Batch CDF sampling with adaptive retry (same pattern as _uniform_sampling). min_frag = options.read_len + (10 if options.paired_ended else 0) + # For PE, the read2 record must stay within this chunk's responsibility so the + # cat-stitched output remains coordinate-sorted (read2.position = e - read_len). + # Cap e at responsibility_length + read_len so read2.position <= responsibility_length, + # which lies at-or-before the next chunk's first read1 position. Falls back to + # span_length for SE and for the final chunk (where responsibility = span_length). + if options.paired_ended: + e_limit = min(responsibility_length + options.read_len, span_length) + else: + e_limit = span_length acc_starts: list[np.ndarray] = [] acc_ends: list[np.ndarray] = [] collected = 0 @@ -118,7 +143,7 @@ def cover_dataset( uv = options.rng.random(n_batch) * total_weight s = np.clip(np.searchsorted(prefix_sum, uv).astype(int), 0, max_start) fl = np.array(fragment_model.generate_fragments(n_batch, options.rng)) - e = np.minimum(s + fl, span_length) + e = np.minimum(s + fl, e_limit) mask = e - s >= min_frag acc_starts.append(s[mask]) acc_ends.append(e[mask]) @@ -135,19 +160,39 @@ def cover_dataset( else: # Uniform sampling - final_reads = _uniform_sampling(span_length, number_reads, options, fragment_model) + final_reads = _uniform_sampling( + span_length, number_reads, options, fragment_model, + max_start=max_start, responsibility_length=responsibility_length, + ) - # Now we shuffle them to add some randomness - options.rng.shuffle(final_reads) + # FASTQ is written in the natural fragment-sampling order (no explicit shuffle). + # The BAM is sorted to coordinate order at the BAM-write boundary in single_runner, + # which interleaves PE mate reads correctly. Users who want randomized FASTQ + # ordering can pipe the output through `seqkit shuffle` — see README "FASTQ + # output order". return final_reads -def _uniform_sampling(span_length, number_reads, options, fragment_model): +def _uniform_sampling(span_length, number_reads, options, fragment_model, *, + max_start=None, responsibility_length=None): if span_length <= options.read_len: return [] - max_start = span_length - options.read_len + # If caller didn't restrict the read1 placement range, default to the full + # in-bounds span (last valid r1.position = span_length - read_len). + if max_start is None: + max_start = span_length - options.read_len + if max_start < 0: + return [] + if responsibility_length is None: + responsibility_length = span_length min_frag = options.read_len + (10 if options.paired_ended else 0) + # For PE, cap e so read2.position stays within this chunk's responsibility (see GC path + # for the full rationale). + if options.paired_ended: + e_limit = min(responsibility_length + options.read_len, span_length) + else: + e_limit = span_length # First batch: 2× candidates covers >99 % of cases when frag_mean >> read_len. # Retry in small increments only when fragment_mean < read_len (rare). @@ -159,7 +204,7 @@ def _uniform_sampling(span_length, number_reads, options, fragment_model): while collected < number_reads: s = options.rng.integers(0, max_start + 1, size=n_batch) fl = np.array(fragment_model.generate_fragments(n_batch, options.rng)) - e = np.minimum(s + fl, span_length) + e = np.minimum(s + fl, e_limit) mask = e - s >= min_frag acc_starts.append(s[mask]) acc_ends.append(e[mask]) @@ -232,6 +277,7 @@ def generate_reads( contig_name: str, contig_index: int, ref_start: int, + responsibility_length: int | None = None, ): """ This will generate reads given a set of parameters for the run. The reads will output in a fastq. @@ -252,7 +298,8 @@ def generate_reads( :param contig_index: The index of the above chromosome within the overall bam header :param ref_start: The start point for this reference segment. Default is 0 and this is currently not fully implemented, to be used for parallelization. - :return: A tuple of the filenames for the temp files created + :return: None. FASTQ and BAM records are streamed directly to the output handles + on `ofw` as reads are generated; no per-chunk accumulation of Read objects. """ # _LOG.info(f'Sampling reads for thread {thread_index}...') start_time = time.time() @@ -262,7 +309,7 @@ def generate_reads( f"Contig '{contig_name}' (length {len(reference)}) is shorter than read_len " f"({options.read_len}). Skipping contig." ) - return [] + return # _LOG.debug("Covering dataset.") t = time.time() @@ -271,13 +318,24 @@ def generate_reads( options, fraglen_model, gc_model, + responsibility_length=responsibility_length, ) # _LOG.debug(f"Dataset coverage took: {(time.time() - t)/60:.2f} m") + # Process fragments in read1-start order so the BAM emerges coordinate-sorted by + # construction. In paired-end mode the read2 of each fragment is at a different + # (typically later) position, so we hold those in a min-heap keyed by position and + # flush each one before writing the next read1 that would precede it. The heap is + # bounded by ~(fragment_length / read_length) entries — single-digit reads in + # practice — so per-worker memory stays constant in chunk size and coverage. + reads.sort(key=lambda r: r[0]) + # _LOG.debug("Writing fastq(s) and optional bam, if indicated") t = time.time() - reads_to_write = [] + bam_handle = ofw.files_to_write[ofw.bam] if options.produce_bam else None + r2_buffer: list[tuple[int, int, "Read"]] = [] # (position, counter, read) + r2_counter = 0 for i in range(len(reads)): # First thing we'll do is check to see if this read is filtered out by a bed file @@ -364,6 +422,16 @@ def generate_reads( errors_per_read, options.rng ) + + # Stream BAM in coordinate order: flush any buffered read2 records whose + # position lies before this read1, then write read1 itself. Since fragments + # are sorted by read1.position, read1 positions arrive monotonically. + if bam_handle is not None: + while r2_buffer and r2_buffer[0][0] < read_1.position: + _, _, buffered_r2 = heapq.heappop(r2_buffer) + ofw.write_bam_record(buffered_r2, contig_index, bam_handle, options.read_len) + ofw.write_bam_record(read_1, contig_index, bam_handle, options.read_len) + # skip over read 2 for single ended reads. if options.paired_ended: # Padding, as above @@ -403,9 +471,16 @@ def generate_reads( errors_per_read, options.rng ) - reads_to_write.append((read_1, read_2)) - else: - reads_to_write.append((read_1, None)) + if bam_handle is not None: + heapq.heappush(r2_buffer, (read_2.position, r2_counter, read_2)) + r2_counter += 1 + + # Flush any read2 records still in the buffer — these all have positions at or + # after the last read1 we wrote, so popping them in heap order gives the correct + # coordinate-sorted tail. + if bam_handle is not None: + while r2_buffer: + _, _, buffered_r2 = heapq.heappop(r2_buffer) + ofw.write_bam_record(buffered_r2, contig_index, bam_handle, options.read_len) _LOG.info(f"Finished sampling reads for thread {thread_index} in {(time.time() - start_time)/60:.2f} m") - return reads_to_write diff --git a/neat/read_simulator/utils/options.py b/neat/read_simulator/utils/options.py index 4e6c3007..627eed34 100644 --- a/neat/read_simulator/utils/options.py +++ b/neat/read_simulator/utils/options.py @@ -82,7 +82,7 @@ def __init__(self, produce_fastq: bool = True, min_mutations: int = 0, parallel_mode: str = "contig", - parallel_block_size: int = 500000, + parallel_block_size: int = 0, cleanup_splits: bool = True, splits_dir: Path | None = None, reuse_splits: bool = False, @@ -134,7 +134,9 @@ def __init__(self, :param produce_fastq: False to turn off default fastq creation :param min_mutations: If you wish to have a minimunm number of mutations per block, enter it here :param parallel_mode: If you wish to use block size method, enter 'size' here - :param parallel_block_size: If you use size method, specify any value but 500000 to change the block size + :param parallel_block_size: If you use size method, specify a positive integer to set the per-chunk + size in basepairs. The default (0) auto-tunes from total genome length and thread count, targeting + ~8 chunks per thread. Specify a value explicitly to override. :param cleanup_splits: Set to False in order to preserve splits after run :param reuse_splits: Attempts to reuse existing splits file """ @@ -240,7 +242,7 @@ def from_cli(output_dir: Path, 'min_mutations': (int, 0, None, None), 'overwrite_output': (bool, False, None, None), 'parallel_mode': (str, 'size', 'choice', ['size', 'contig']), - 'parallel_block_size': (int, 500000, None, None), + 'parallel_block_size': (int, 0, None, None), 'threads': (int, 1, 1, 1000), 'cleanup_splits': (bool, True, None, None), 'reuse_splits': (bool, False, None, None), @@ -446,7 +448,10 @@ def log_configuration(self): if self.parallel_mode == 'size': _LOG.info(f'Splitting reference into chunks.') - _LOG.info(f' - splitting input into size {self.parallel_block_size}') + if self.parallel_block_size > 0: + _LOG.info(f' - splitting input into size {self.parallel_block_size}') + else: + _LOG.info(f' - chunk size will be auto-tuned from genome length and thread count') elif self.parallel_mode == 'contig': _LOG.info(f'Splitting input by contig.') if self.reuse_splits: diff --git a/neat/read_simulator/utils/read.py b/neat/read_simulator/utils/read.py index 22611de2..0c992228 100644 --- a/neat/read_simulator/utils/read.py +++ b/neat/read_simulator/utils/read.py @@ -170,26 +170,49 @@ def update_quality_array( def apply_errors(self, quality_model: TraditionalQualityModel): """ - This function applies errors to a sequence and calls the update_quality_array function after - - :param quality_model: The error model for this run, - :return: None, The sequence, with errors applied + Apply this read's stored sequencing errors to read_sequence and quality_array in + a single pass. + + self.errors is in descending position order (set up by get_sequencing_errors). + The original implementation iterated in that order, doing one np.concatenate + + one MutableSeq slice/concat per error — quadratic in errors-per-read. Here we + iterate ascending and append unchanged ranges plus error alternates into chunk + lists, then join once at the end. The result preserves the prior per-error + quality-array semantics (including the convention that insertion anchors lose + their original quality and gain alt_len-1 low-quality scores). """ - mutated_sequence = MutableSeq(self.read_sequence) - for error in self.errors: - # Replace the entire ref sequence with the entire alt sequence - mutated_sequence = \ - mutated_sequence[:error.location] + error.alt + mutated_sequence[error.location+len(error.ref):] - # update quality score for error - self.update_quality_array( - len(error.ref), - error.alt, - error.location, - "error", - list(quality_model.quality_scores), - ) + if not self.errors: + return + + errors_asc = list(reversed(self.errors)) + low_score = min(quality_model.quality_scores) + + seq_str = str(self.read_sequence) + seq_chunks: list[str] = [] + q_chunks: list[np.ndarray] = [] + prev_end = 0 + + for error in errors_asc: + loc = error.location + ref_len = len(error.ref) + alt_str = str(error.alt) + alt_len = len(alt_str) + seq_chunks.append(seq_str[prev_end:loc]) + q_chunks.append(self.quality_array[prev_end:loc]) + seq_chunks.append(alt_str) + if alt_len > 1: + q_chunks.append(np.full(alt_len - 1, low_score, dtype=int)) + elif ref_len > 1 and alt_len == 1: + pass # deletion — no new quality scores + else: + q_chunks.append(np.array([low_score], dtype=int)) + prev_end = loc + ref_len - self.read_sequence = Seq(mutated_sequence) + seq_chunks.append(seq_str[prev_end:]) + q_chunks.append(self.quality_array[prev_end:]) + + self.read_sequence = Seq(''.join(seq_chunks)) + self.quality_array = np.concatenate(q_chunks) def apply_mutations(self, quality_scores: list, rng: Generator): """ @@ -341,8 +364,10 @@ def finalize_read_and_write( # It updates the quality array and reference segment in place, including reversing them, if appropriate self.convert_masking(qual_model) - # set the read sequence to match the reference, then modify - self.read_sequence = deepcopy(self.reference_segment) + # Start the read sequence as an alias of the masked reference. Biopython Seq is + # immutable; apply_mutations and apply_errors below produce new Seq objects via + # their own working copies, so no independent deepcopy is needed here. + self.read_sequence = self.reference_segment # Get errors for the read and update the quality score self.errors, self.padding = err_model.get_sequencing_errors( @@ -385,7 +410,9 @@ def convert_masking(self, quality_model: TraditionalQualityModel): bad_score = min(quality_model.quality_scores) # we'll use generic human repeats, as commonly found in masked regions. We may refine this to make configurable repeat_bases = list("TTAGGG") - raw_sequence = deepcopy(self.reference_segment) + # Immutable Biopython Seq; the MutableSeq below is the working copy, so no + # deepcopy needed here. + raw_sequence = self.reference_segment start = raw_sequence.find('N') if start != -1: @@ -404,38 +431,108 @@ def convert_masking(self, quality_model: TraditionalQualityModel): def make_cigar(self): """ - Aligns the reference and mutated sequences. + Build the CIGAR string describing how this read aligns to its reference window. + + Three paths, in priority order: + 1. No indels (the common case, ~98% of reads on default error models) — the CIGAR + is just `{run_read_length}M`. + 2. Forward read whose indels all come from sequencing errors — walk the sorted + error positions in O(read length) to emit M/I/D ops directly. No alignment + needed; each error's stored location is already the anchor position in + reference-window coordinates. + 3. Anything else (reverse reads with indels, or any read where a mutation indel + may have been applied) — fall back to pairwise alignment. Mutation indels are + applied before sequencing errors and can shift the read coordinate frame, and + reverse reads truncate the segment from the opposite end; neither case fits + the simple walker. """ - # These parameters were set to minimize breaks in the mutated sequence and find the best - # alignment from there. + error_indels = [ + e for e in self.errors + if e.error_type is Insertion or e.error_type is Deletion + ] + has_mutation_indels = any( + type(v) is Insertion or type(v) is Deletion + for variants_at_loc in self.mutations.values() + for v in variants_at_loc + ) + + if not error_indels and not has_mutation_indels: + return f"{self.run_read_length}M" + + if not has_mutation_indels and not self.is_reverse: + return self._cigar_from_error_indels(error_indels) + return self._cigar_via_alignment() + + def _cigar_from_error_indels(self, error_indels): """ - The sequence alignment. We restrict the alignment to the section of the reference where we know the read - came from and try to generate a minimal cigar string. The cigar string part may still need tweaking. + Build the CIGAR for a forward read whose only indels come from sequencing errors. + + Walks the error list in reference-window order, emitting M/I/D ops and stopping + once the CIGAR's query side reaches run_read_length. Run-length-encodes the result + as it goes. """ + events = sorted(error_indels, key=lambda e: e.location) + ops: list[list] = [] + + def emit(op_char, count): + if count <= 0: + return + if ops and ops[-1][0] == op_char: + ops[-1][1] += count + else: + ops.append([op_char, count]) + + ref_pos = 0 + remaining_query = self.run_read_length + + for ev in events: + if remaining_query == 0: + break + # M's leading up to and including the anchor base at ev.location. + m_count = ev.location - ref_pos + 1 + m_emitted = min(m_count, remaining_query) + emit('M', m_emitted) + remaining_query -= m_emitted + ref_pos = ev.location + 1 + if remaining_query == 0: + break + if ev.error_type is Insertion: + i_emitted = min(ev.length, remaining_query) + emit('I', i_emitted) + remaining_query -= i_emitted + else: # Deletion + emit('D', ev.length) + ref_pos += ev.length + + if remaining_query > 0: + emit('M', remaining_query) + + return ''.join(f"{count}{op}" for op, count in ops) + + def _cigar_via_alignment(self): + """ + Pairwise-alignment fallback for cases the direct walker cannot model: reverse reads + with indels, or any read where a mutation indel preceded the sequencing errors. + """ template = self.reference_segment if self.is_reverse: template = template.reverse_complement() query = self.read_sequence cigar = ["M"] * self.run_read_length - aligner2 = PairwiseAligner() - aligner2.mode = "fogsaa" - alignments2 = aligner2.align(template, query) - aligned_template = alignments2[0][0] - aligned_query = alignments2[0][1] + aligner = PairwiseAligner() + aligner.mode = "fogsaa" + alignments = aligner.align(template, query) + aligned_template = alignments[0][0] + aligned_query = alignments[0][1] start_point = self.run_read_length - 1 if self.is_reverse else 0 for i in range(self.run_read_length): - if self.is_reverse: - index = start_point - i - else: - index = i + index = start_point - i if self.is_reverse else i if aligned_template[index] == "-": cigar[index] = "I" elif aligned_query[index] == "-": - # Ds are special because they don't count toward the final total and we need to maintain a consistent - # cigar length, so we insert. cigar.insert(index, "D") if self.is_reverse: cigar.reverse() diff --git a/neat/read_simulator/utils/stitch_outputs.py b/neat/read_simulator/utils/stitch_outputs.py index b9f06c20..5e415137 100644 --- a/neat/read_simulator/utils/stitch_outputs.py +++ b/neat/read_simulator/utils/stitch_outputs.py @@ -4,6 +4,8 @@ import resource import shutil import pysam +import time +from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import List @@ -17,15 +19,24 @@ _LOG = logging.getLogger(__name__) -def concat(files_to_join: List[Path], dest_file: gzip.GzipFile) -> None: +def concat(files_to_join: List[Path], dest_path: Path) -> None: + """ + Byte-level concatenation of gzip-compressed inputs into dest_path. + + Concatenated gzip streams are themselves a valid gzip file (gzip spec section 2.2 — + a gzip file is a sequence of "members"; decompression yields the concatenation of + contents). Skipping zlib entirely is much faster than the prior path that + decompressed each input and re-compressed into the destination. The caller must + ensure any gzip handle on dest_path is closed before invoking this (we open + dest_path in 'wb' mode, which truncates). + """ if not files_to_join: - # Nothing to do, and no error to throw - _LOG.warning(f"Concat called but there are no files to join: {files_to_join}" ) + _LOG.warning(f"Concat called but there are no files to join: {files_to_join}") return - - for f in files_to_join: - with gzip.open(f, 'rt') as in_f: - shutil.copyfileobj(in_f, dest_file) + with open(dest_path, 'wb') as out: + for f in files_to_join: + with open(f, 'rb') as in_f: + shutil.copyfileobj(in_f, out, length=4 * 1024 * 1024) def merge_vcfs(vcfs: List[Path], ofw: OutputFileWriter) -> None: dest = ofw.files_to_write[ofw.vcf] @@ -45,16 +56,15 @@ def merge_vcfs(vcfs: List[Path], ofw: OutputFileWriter) -> None: _LOG.warning(f"merge_vcfs: removed {n_duplicates} duplicate VCF line(s) during merge.") def merge_bam(bam_files: List[Path], ofw: OutputFileWriter, threads: int): - merged_file = ofw.tmp_dir / "temp_merged.bam" - intermediate_files = [] - # Note 1000 is arbitrary. May need to be a user parameter/adjustable/a function - for i in range(0, len(bam_files), 500): - temp_file = str(ofw.tmp_dir / f"temp_merged_{i}.bam") - pysam.merge("--no-PG", "-f", temp_file, *map(str, bam_files[i:i+500])) - intermediate_files.append(temp_file) - pysam.merge("--no-PG", "-f", str(merged_file), *intermediate_files) - pysam.sort("-@", str(threads), "-m", "1G", "-o", str(ofw.bam), str(merged_file)) - merged_file.unlink(missing_ok=True) + # Per-worker BAMs are coordinate-sorted within themselves, and each chunk owns a + # non-overlapping reference range for r1 placement (see cover_dataset's + # responsibility_length), so the BAMs concatenate into a globally coordinate-sorted + # output without any sort or k-way merge. pysam.cat does a raw BGZF concatenation — + # no decompression / re-encode — and is typically 10-30x faster than pysam.merge on + # large outputs. At supercomputer scale this is the difference between the stitch + # step being I/O-bound (fast) and BGZF-bound (slow). + pysam.cat("-o", str(ofw.bam), *map(str, bam_files)) + # The .bai index is produced by runner.py after stitching. def main( ofw: OutputFileWriter, @@ -76,14 +86,37 @@ def main( vcf_list.append(file_dict["vcf"]) if file_dict["bam"]: bam_list.append(file_dict["bam"]) - # concatenate all files of each type. An empty list will result in no action + # The byte-level FASTQ concat opens dest_path in 'wb' mode (truncates), so we close + # any existing gzip handles on those paths first. flush_and_close_files (called later + # from runner) will skip handles it finds already closed. + for path_attr in (ofw.fq1, ofw.fq2): + if path_attr is not None and path_attr in ofw.files_to_write: + try: + ofw.files_to_write[path_attr].close() + except Exception: + pass + + # Run the per-output-type stitches concurrently. They write to independent files and + # spend most of their time in I/O (raw byte copy for FASTQ/BAM, gzip read for VCF + # dedup) — Python's GIL is released during those calls, so threads overlap. Total + # stitch wall ≈ max(fq, vcf, bam) instead of their sum. At supercomputer scale where + # BAM dominates and FASTQ/VCF are small, the saving is bounded by the BAM cat alone, + # but on smaller-output workloads the overlap matters. + stitch_start = time.time() + work = [] if fq1_list: - concat(fq1_list, ofw.files_to_write[ofw.fq1]) + work.append(("fq1", lambda: concat(fq1_list, ofw.fq1))) if fq2_list: - concat(fq2_list, ofw.files_to_write[ofw.fq2]) + work.append(("fq2", lambda: concat(fq2_list, ofw.fq2))) if vcf_list: - merge_vcfs(vcf_list, ofw) + work.append(("vcf", lambda: merge_vcfs(vcf_list, ofw))) if bam_list: - merge_bam(bam_list, ofw, threads) - # Final success message via logging - _LOG.info("Stitching complete!") + work.append(("bam", lambda: merge_bam(bam_list, ofw, threads))) + + if work: + with ThreadPoolExecutor(max_workers=len(work)) as exe: + futures = {exe.submit(fn): label for label, fn in work} + for future in futures: + # .result() re-raises any exception from the worker. + future.result() + _LOG.info(f"Stitching complete! ({time.time() - stitch_start:.1f} s parallel stitch)") diff --git a/pyproject.toml b/pyproject.toml index 09d66eba..ec50b6db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "neat-genreads" -version = "4.4.2" +version = "4.4.3" description = "NGS Simulation toolkit" readme = "README.md" authors = ["Joshua Allen ", "Keshav Gandhi "] diff --git a/tests/test_read_simulator/test_generate_reads.py b/tests/test_read_simulator/test_generate_reads.py index d20ca1e1..7b63ed94 100644 --- a/tests/test_read_simulator/test_generate_reads.py +++ b/tests/test_read_simulator/test_generate_reads.py @@ -321,13 +321,37 @@ def _make_options(paired=False, seed=0): opts.read_len = _READ_LEN opts.paired_ended = paired opts.coverage = 5 + # Enable BAM so generate_reads streams Read objects to the collecting OFW for + # inspection. FASTQ stays off — we don't need its bytes for these tests. opts.produce_fastq = False - opts.produce_bam = False + opts.produce_bam = True opts.produce_vcf = False opts.overwrite_output = True return opts +class _CollectingOFW: + """ + Minimal OutputFileWriter stand-in for generate_reads tests. + + generate_reads now streams Read objects to its output_file_writer instead of + returning a list. Tests that want to inspect the generated reads enable + produce_bam=True and pass an instance of this class as the ofw. Every + ``write_bam_record`` call is captured into ``bam_records`` for assertion. + """ + def __init__(self): + # generate_reads dereferences ofw.bam and ofw.files_to_write[ofw.bam]. + # Provide a stub handle so those lookups succeed; we don't use the handle. + self.bam = "_collecting_bam" + self.fq1 = None + self.fq2 = None + self.files_to_write = {self.bam: SimpleNamespace(write=lambda *a, **kw: None)} + self.bam_records = [] + + def write_bam_record(self, read, contig_id, bam_handle, read_length): + self.bam_records.append(read) + + def _make_models(read_len=_READ_LEN, frag_mean=300): error_model = SequencingErrorModel(read_length=read_len) qual_model = TraditionalQualityModel() @@ -407,16 +431,17 @@ def test_generate_reads_single_ended_returns_read_none_pairs(): err, qual, frag = _make_models() opts = _make_options(paired=False) cv = ContigVariants() + ofw = _CollectingOFW() - results = generate_reads(0, ref, err, 3, qual, frag, get_uniform_gc_model(), cv, - _all_span_targeted(), _nothing_discarded(), - opts, None, "chr1", 0, 0) + generate_reads(0, ref, err, 3, qual, frag, get_uniform_gc_model(), cv, + _all_span_targeted(), _nothing_discarded(), + opts, ofw, "chr1", 0, 0) - assert isinstance(results, list) - assert len(results) > 0 - for read1, read2 in results: - assert isinstance(read1, Read) - assert read2 is None + # In SE mode only read1 records are emitted — all forward strand. + assert len(ofw.bam_records) > 0 + for read in ofw.bam_records: + assert isinstance(read, Read) + assert read.is_reverse is False def test_generate_reads_paired_ended_returns_read_read_pairs(): @@ -424,15 +449,21 @@ def test_generate_reads_paired_ended_returns_read_read_pairs(): err, qual, frag = _make_models() opts = _make_options(paired=True) cv = ContigVariants() + ofw = _CollectingOFW() - results = generate_reads(0, ref, err, 3, qual, frag, get_uniform_gc_model(), cv, - _all_span_targeted(), _nothing_discarded(), - opts, None, "chr1", 0, 0) + generate_reads(0, ref, err, 3, qual, frag, get_uniform_gc_model(), cv, + _all_span_targeted(), _nothing_discarded(), + opts, ofw, "chr1", 0, 0) - assert len(results) > 0 - for read1, read2 in results: - assert isinstance(read1, Read) - assert isinstance(read2, Read) + # PE mode emits both forward (r1) and reverse (r2) records, interleaved by + # position via the heap-buffer in generate_reads. Verify both strands appear. + assert len(ofw.bam_records) > 0 + forwards = [r for r in ofw.bam_records if not r.is_reverse] + reverses = [r for r in ofw.bam_records if r.is_reverse] + assert len(forwards) > 0 + assert len(reverses) > 0 + # Each fragment contributes one r1 and one r2. + assert len(forwards) == len(reverses) def test_generate_reads_read_length_matches_options(): @@ -440,13 +471,15 @@ def test_generate_reads_read_length_matches_options(): err, qual, frag = _make_models() opts = _make_options(paired=False) cv = ContigVariants() + ofw = _CollectingOFW() - results = generate_reads(0, ref, err, 3, qual, frag, get_uniform_gc_model(), cv, - _all_span_targeted(), _nothing_discarded(), - opts, None, "chr1", 0, 0) + generate_reads(0, ref, err, 3, qual, frag, get_uniform_gc_model(), cv, + _all_span_targeted(), _nothing_discarded(), + opts, ofw, "chr1", 0, 0) - for read1, _ in results: - assert len(read1.read_sequence) == _READ_LEN + assert len(ofw.bam_records) > 0 + for read in ofw.bam_records: + assert len(read.read_sequence) == _READ_LEN # --------------------------------------------------------------------------- @@ -459,13 +492,14 @@ def test_generate_reads_targeted_region_flag_false_filters_all(): err, qual, frag = _make_models() opts = _make_options(paired=False) cv = ContigVariants() + ofw = _CollectingOFW() no_target = [(0, _SPAN, False)] - results = generate_reads(0, ref, err, 3, qual, frag, get_uniform_gc_model(), cv, - no_target, _nothing_discarded(), - opts, None, "chr1", 0, 0) + generate_reads(0, ref, err, 3, qual, frag, get_uniform_gc_model(), cv, + no_target, _nothing_discarded(), + opts, ofw, "chr1", 0, 0) - assert results == [] + assert ofw.bam_records == [] def test_generate_reads_discard_region_removes_all(): @@ -474,13 +508,14 @@ def test_generate_reads_discard_region_removes_all(): err, qual, frag = _make_models() opts = _make_options(paired=False) cv = ContigVariants() + ofw = _CollectingOFW() discard_all = [(0, _SPAN, True)] - results = generate_reads(0, ref, err, 3, qual, frag, get_uniform_gc_model(), cv, - _all_span_targeted(), discard_all, - opts, None, "chr1", 0, 0) + generate_reads(0, ref, err, 3, qual, frag, get_uniform_gc_model(), cv, + _all_span_targeted(), discard_all, + opts, ofw, "chr1", 0, 0) - assert results == [] + assert ofw.bam_records == [] def test_generate_reads_discard_flag_false_keeps_reads(): @@ -489,12 +524,13 @@ def test_generate_reads_discard_flag_false_keeps_reads(): err, qual, frag = _make_models() opts = _make_options(paired=False) cv = ContigVariants() + ofw = _CollectingOFW() - results = generate_reads(0, ref, err, 3, qual, frag, get_uniform_gc_model(), cv, - _all_span_targeted(), _nothing_discarded(), - opts, None, "chr1", 0, 0) + generate_reads(0, ref, err, 3, qual, frag, get_uniform_gc_model(), cv, + _all_span_targeted(), _nothing_discarded(), + opts, ofw, "chr1", 0, 0) - assert len(results) > 0 + assert len(ofw.bam_records) > 0 # --------------------------------------------------------------------------- @@ -506,6 +542,7 @@ def test_generate_reads_variants_populated_on_reads(): ref = _make_reference() err, qual, frag = _make_models() opts = _make_options(paired=False) + ofw = _CollectingOFW() cv = ContigVariants() snv = SingleNucleotideVariant( @@ -516,11 +553,11 @@ def test_generate_reads_variants_populated_on_reads(): ) cv.add_variant(snv) - results = generate_reads(0, ref, err, 3, qual, frag, get_uniform_gc_model(), cv, - _all_span_targeted(), _nothing_discarded(), - opts, None, "chr1", 0, 0) + generate_reads(0, ref, err, 3, qual, frag, get_uniform_gc_model(), cv, + _all_span_targeted(), _nothing_discarded(), + opts, ofw, "chr1", 0, 0) - reads_with_mutations = [r1 for r1, _ in results if r1.mutations] + reads_with_mutations = [r for r in ofw.bam_records if r.mutations] assert len(reads_with_mutations) > 0 @@ -540,30 +577,33 @@ def test_generate_reads_paired_discard_region_removes_all(): err, qual, frag = _make_models() opts = _make_options(paired=True) cv = ContigVariants() + ofw = _CollectingOFW() discard_all = [(0, _SPAN, True)] - results = generate_reads(0, ref, err, 3, qual, frag, get_uniform_gc_model(), cv, - _all_span_targeted(), discard_all, - opts, None, "chr1", 0, 0) + generate_reads(0, ref, err, 3, qual, frag, get_uniform_gc_model(), cv, + _all_span_targeted(), discard_all, + opts, ofw, "chr1", 0, 0) - assert results == [] + assert ofw.bam_records == [] def test_generate_reads_paired_no_discard_produces_read_pairs(): - """Paired-end run without discard produces (Read, Read) pairs (regression guard).""" + """Paired-end run without discard produces both forward and reverse read records.""" ref = _make_reference() err, qual, frag = _make_models() opts = _make_options(paired=True) cv = ContigVariants() + ofw = _CollectingOFW() - results = generate_reads(0, ref, err, 3, qual, frag, get_uniform_gc_model(), cv, - _all_span_targeted(), _nothing_discarded(), - opts, None, "chr1", 0, 0) + generate_reads(0, ref, err, 3, qual, frag, get_uniform_gc_model(), cv, + _all_span_targeted(), _nothing_discarded(), + opts, ofw, "chr1", 0, 0) - assert len(results) > 0 - for read1, read2 in results: - assert isinstance(read1, Read) - assert isinstance(read2, Read) + forwards = [r for r in ofw.bam_records if not r.is_reverse] + reverses = [r for r in ofw.bam_records if r.is_reverse] + assert len(forwards) > 0 + assert len(reverses) > 0 + assert len(forwards) == len(reverses) # --------------------------------------------------------------------------- diff --git a/tests/test_read_simulator/test_options.py b/tests/test_read_simulator/test_options.py index f08a49f7..33e9c17c 100644 --- a/tests/test_read_simulator/test_options.py +++ b/tests/test_read_simulator/test_options.py @@ -158,7 +158,9 @@ def test_default_values(): assert opts.quality_offset == 33 assert opts.threads == 1 assert opts.parallel_mode == "contig" - assert opts.parallel_block_size == 500000 + # Default is 0 (sentinel for auto-tune from genome length and thread count). + # An explicit positive int in YAML overrides; see runner for the auto-tune logic. + assert opts.parallel_block_size == 0 assert opts.cleanup_splits is True assert opts.reuse_splits is False assert opts.overwrite_output is False diff --git a/tests/test_read_simulator/test_stitch_outputs.py b/tests/test_read_simulator/test_stitch_outputs.py index ed830f85..a6dd517b 100644 --- a/tests/test_read_simulator/test_stitch_outputs.py +++ b/tests/test_read_simulator/test_stitch_outputs.py @@ -50,42 +50,47 @@ def _make_ofw(tmp_path: Path, vcf_path: Path = None): return ofw -# concat +# concat (byte-level gzip concatenation — concatenated gzip streams are valid gzip) def test_concat_single_file(tmp_path): src = _write_gz(tmp_path / "a.gz", "hello\n") - dest = io.StringIO() + dest = tmp_path / "out.gz" concat([src], dest) - assert dest.getvalue() == "hello\n" + with gzip.open(dest, "rt") as f: + assert f.read() == "hello\n" def test_concat_multiple_files(tmp_path): a = _write_gz(tmp_path / "a.gz", "line1\n") b = _write_gz(tmp_path / "b.gz", "line2\n") - dest = io.StringIO() + dest = tmp_path / "out.gz" concat([a, b], dest) - assert dest.getvalue() == "line1\nline2\n" + with gzip.open(dest, "rt") as f: + assert f.read() == "line1\nline2\n" def test_concat_empty_list(tmp_path): - dest = io.StringIO() + """Empty input list: concat is a no-op and dest is not created.""" + dest = tmp_path / "out.gz" concat([], dest) - assert dest.getvalue() == "" + assert not dest.exists() def test_concat_preserves_content(tmp_path): content = "ACGT\nACGT\nACGT\n" src = _write_gz(tmp_path / "reads.gz", content) - dest = io.StringIO() + dest = tmp_path / "out.gz" concat([src], dest) - assert dest.getvalue() == content + with gzip.open(dest, "rt") as f: + assert f.read() == content def test_concat_order_is_preserved(tmp_path): files = [_write_gz(tmp_path / f"{i}.gz", f"chunk{i}\n") for i in range(5)] - dest = io.StringIO() + dest = tmp_path / "out.gz" concat(files, dest) - result = dest.getvalue() + with gzip.open(dest, "rt") as f: + result = f.read() positions = [result.index(f"chunk{i}") for i in range(5)] assert positions == sorted(positions) @@ -199,54 +204,46 @@ def test_find_dups_exact_duplicate_rejected(tmp_path): assert cv.add_variant(v2) == 1 -# merge_bam +# merge_bam (now a raw BGZF concatenation via pysam.cat — see stitch_outputs) -def test_merge_bam_calls_pysam_merge_and_sort(tmp_path): +def test_merge_bam_calls_pysam_cat(tmp_path): ofw = _make_ofw(tmp_path) bam_files = [tmp_path / f"{i}.bam" for i in range(3)] with patch("neat.read_simulator.utils.stitch_outputs.pysam") as mock_pysam: merge_bam(bam_files, ofw, threads=4) - # pysam.merge should have been called at least twice (once per chunk + final) - assert mock_pysam.merge.call_count >= 2 - # pysam.sort should have been called once - mock_pysam.sort.assert_called_once() + # Per-worker BAMs are coord-sorted and cover non-overlapping ranges, so a single + # pysam.cat is sufficient — no merge/sort needed. + mock_pysam.cat.assert_called_once() + mock_pysam.merge.assert_not_called() + mock_pysam.sort.assert_not_called() -def test_merge_bam_sort_uses_output_bam_path(tmp_path): +def test_merge_bam_cat_uses_output_bam_path(tmp_path): ofw = _make_ofw(tmp_path) bam_files = [tmp_path / "a.bam"] with patch("neat.read_simulator.utils.stitch_outputs.pysam") as mock_pysam: merge_bam(bam_files, ofw, threads=2) - sort_args = mock_pysam.sort.call_args[0] - assert str(ofw.bam) in sort_args + cat_args = mock_pysam.cat.call_args[0] + assert "-o" in cat_args + assert str(ofw.bam) in cat_args -def test_merge_bam_temp_file_cleaned_up(tmp_path): - ofw = _make_ofw(tmp_path) - bam_files = [tmp_path / "a.bam"] - temp_merged = ofw.tmp_dir / "temp_merged.bam" - - with patch("neat.read_simulator.utils.stitch_outputs.pysam"): - merge_bam(bam_files, ofw, threads=1) - - # temp_merged.bam should have been unlinked (missing_ok=True means no error if absent) - assert not temp_merged.exists() - - -def test_merge_bam_chunks_large_bam_list(tmp_path): - """More than 500 BAMs triggers chunked intermediate merges.""" +def test_merge_bam_passes_all_inputs_to_cat(tmp_path): + """Cat is given every per-worker BAM as a positional argument, in order.""" ofw = _make_ofw(tmp_path) bam_files = [tmp_path / f"{i}.bam" for i in range(600)] with patch("neat.read_simulator.utils.stitch_outputs.pysam") as mock_pysam: merge_bam(bam_files, ofw, threads=1) - # Two chunks (0–499, 500–599) → 2 intermediate merges + 1 final = 3 total - assert mock_pysam.merge.call_count == 3 + cat_args = mock_pysam.cat.call_args[0] + # cat call shape: ("-o", out_path, *input_paths) + passed_inputs = [a for a in cat_args if a not in ("-o", str(ofw.bam))] + assert passed_inputs == [str(p) for p in bam_files] # main @@ -260,7 +257,10 @@ def test_main_fq1_only(tmp_path): src = _write_gz(tmp_path / "chunk.fq1.gz", "@read1\nACGT\n+\nIIII\n") output_files = [(0, _file_dict(fq1=src))] main(ofw, output_files) - assert "read1" in ofw._fq1_buf.getvalue() + # Byte-level concat writes raw bytes to ofw.fq1 (the path), not through the + # StringIO handle. Read the actual file back. + with gzip.open(ofw.fq1, "rt") as fh: + assert "read1" in fh.read() def test_main_fq1_and_fq2(tmp_path): @@ -269,8 +269,10 @@ def test_main_fq1_and_fq2(tmp_path): src2 = _write_gz(tmp_path / "c.fq2.gz", "@r2\nTTGG\n+\nIIII\n") output_files = [(0, _file_dict(fq1=src1, fq2=src2))] main(ofw, output_files) - assert "r1" in ofw._fq1_buf.getvalue() - assert "r2" in ofw._fq2_buf.getvalue() + with gzip.open(ofw.fq1, "rt") as fh: + assert "r1" in fh.read() + with gzip.open(ofw.fq2, "rt") as fh: + assert "r2" in fh.read() def test_main_vcf(tmp_path): @@ -288,8 +290,10 @@ def test_main_none_files_not_concatenated(tmp_path): ofw = _make_ofw(tmp_path) output_files = [(0, _file_dict())] # all None main(ofw, output_files) # should not raise - assert ofw._fq1_buf.getvalue() == "" - assert ofw._fq2_buf.getvalue() == "" + # No inputs → byte-level concat is a no-op and the destination paths are not + # created. The VCF buffer remains untouched (no merge_vcfs run). + assert not ofw.fq1.exists() + assert not ofw.fq2.exists() assert ofw._vcf_buf.getvalue() == "" @@ -300,7 +304,8 @@ def test_main_multiple_threads(tmp_path): for i in range(3) ] main(ofw, chunks) - result = ofw._fq1_buf.getvalue() + with gzip.open(ofw.fq1, "rt") as fh: + result = fh.read() for i in range(3): assert f"chunk{i}" in result