Skip to content

Commit 7e5b397

Browse files
Merge pull request #176 from ncsa/172-fix-vcf-headers
Fixed and cleaned up tests
2 parents 0c4dedb + 5e9586c commit 7e5b397

9 files changed

Lines changed: 50 additions & 133 deletions

File tree

neat/common/io.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,9 @@ def open_input(path: str | Path) -> Iterator[TextIO]:
6363
# - https://github.com/python/mypy/issues/12053
6464
open_: Callable[..., TextIO]
6565
if is_compressed(path):
66-
open_ = bgzf.open
66+
handle = bgzf.BgzfReader(path, 'r')
6767
else:
68-
open_ = open
69-
handle = open_(path, "rt", encoding="utf-8")
68+
handle = open(path, 'rt', encoding='utf-8')
7069
try:
7170
yield handle
7271
finally:

neat/read_simulator/utils/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from .options import *
66
from .output_file_writer import *
77
from .read import *
8-
from .t_sam_record import *
98
from .vcf_func import *
109
from .generate_reads import *
1110
from .generate_variants import *

neat/read_simulator/utils/stitch_outputs.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,28 +21,27 @@
2121
__all__ = ["main"]
2222

2323
from Bio import SeqIO, bgzf
24+
from Bio.bgzf import BgzfWriter
2425

2526
from neat.common import open_output, open_input
2627
from neat.read_simulator.utils import Options, OutputFileWriter
2728

2829
_LOG = logging.getLogger(__name__)
2930

30-
def concat(files_to_join: List[Path], ofw: OutputFileWriter, file: Path) -> None:
31+
def concat(files_to_join: List[Path], dest_file: BgzfWriter) -> None:
3132
if not files_to_join:
3233
# Nothing to do, and no error to throw
3334
return
3435

35-
out_handle = ofw.files_to_write[file]
3636
for f in files_to_join:
3737
with bgzf.BgzfReader(f) as in_f:
38-
shutil.copyfileobj(in_f, out_handle)
38+
shutil.copyfileobj(in_f, dest_file)
3939

4040
def merge_bam(bam_files: List[Path], ofw: OutputFileWriter, threads: int) -> None:
4141
if not bam_files:
4242
return
4343

4444
unsorted = ofw.bam.with_suffix(".unsorted.bam")
45-
sorted_bam_files = []
4645
pysam.merge("--no-PG", "-@", str(threads), "-f", str(unsorted), *map(str, bam_files))
4746
pysam.sort("-@", str(threads), "-o", str(ofw.bam), str(unsorted))
4847
unsorted.unlink(missing_ok=True)
@@ -65,8 +64,8 @@ def main(
6564
if file_dict["bam"]:
6665
bam.append(file_dict["bam"])
6766
# concatenate all files of each type. An empty list will result in no action
68-
concat(fq1_list, ofw, ofw.fq1)
69-
concat(fq2_list, ofw, ofw.fq2)
67+
concat(fq1_list, ofw.files_to_write[ofw.fq1])
68+
concat(fq2_list, ofw.files_to_write[ofw.fq2])
7069
merge_bam(bam, ofw, threads)
7170
# Final success message via logging
7271
_LOG.info("Stitching complete!")

neat/read_simulator/utils/t_sam_record.py

Lines changed: 0 additions & 58 deletions
This file was deleted.

tests/test_common/test_io.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pathlib import Path
55

66
import pytest
7+
from Bio import bgzf
78

89
from neat.common.io import is_compressed, open_input, open_output, validate_input_path, validate_output_path
910

@@ -12,7 +13,7 @@ def test_is_compressed_plain_and_gz(tmp_path: Path):
1213
plain = tmp_path / "file.txt"
1314
plain.write_text("hello", encoding="utf-8")
1415
gz = tmp_path / "file.txt.gz"
15-
with gzip.open(gz, "wt", encoding="utf-8") as f:
16+
with bgzf.BgzfWriter(gz, "wt") as f:
1617
f.write("hello")
1718

1819
assert is_compressed(plain) is False
@@ -28,10 +29,13 @@ def test_open_input_reads_plain_and_gz(tmp_path: Path):
2829

2930
# Gz
3031
gz = tmp_path / "b.txt.gz"
31-
with gzip.open(gz, "wt", encoding="utf-8") as fh:
32+
with bgzf.BgzfWriter(gz, "wt") as fh:
3233
fh.write("xyz\n")
34+
out = ""
3335
with open_input(gz) as fh:
34-
assert fh.read() == "xyz\n"
36+
for line in fh:
37+
out += line
38+
assert out == "xyz\n"
3539

3640

3741
def test_open_output_creates_dirs_and_writes_plain(tmp_path: Path):

tests/test_models/test_parallelize.py

Lines changed: 0 additions & 37 deletions
This file was deleted.
Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
"""
22
Unit tests for the split_inputs module of the parallel read simulator
33
"""
4+
from Bio.Seq import Seq
45

5-
from neat.read_simulator.utils.split_inputs import SimpleRecord, chunk_record
6-
6+
from neat.read_simulator.utils.split_inputs import chunk_record
77

88
def test_chunk_record_overlaps() -> None:
99
"""Ensure that chunk_record yields overlapping chunks with the correct ids."""
1010
# A simple sequence of 12 bases to chunk into length 5 with overlap 2
11-
rec = SimpleRecord("contig", "ACGTACGTACGT")
11+
rec = Seq("ACGTACGTACGT")
1212
chunks = list(chunk_record(rec, 5, 2))
1313
# Should produce four chunks: positions [0:5], [3:8], [6:11], [9:12]
1414
assert len(chunks) == 4
1515
# Check that chunk ids are sequential and lengths match expected slices
16-
lengths = [len(r.seq) for r, _ in chunks]
16+
lengths = [len(seq) for _, seq in chunks]
1717
assert lengths == [5, 5, 5, 3]
18-
ids = [cid for _, cid in chunks]
19-
assert ids == [1, 2, 3, 4]
18+
ids = [cid for cid, _ in chunks]
19+
seqs = [seq for _, seq in chunks]
20+
assert ids == [0, 3, 6, 9]
21+
assert seqs == [Seq('ACGTA'), Seq('TACGT'), Seq('GTACG'), Seq('CGT')]

tests/test_models/test_stitch_outputs.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,35 @@
33
"""
44

55
from pathlib import Path
6+
7+
from Bio import bgzf
8+
69
from neat.read_simulator.utils.stitch_outputs import concat
710

811

912
def test_concat_joins_files_in_order(tmp_path: Path) -> None:
1013
"""Verify that concat writes the exact bytewise concatenation of inputs."""
1114
# Prepare small input files
1215
f1 = tmp_path / "a.bin"
13-
f1.write_bytes(b"hello\n")
16+
with bgzf.BgzfWriter(f1, 'w')as f1_in:
17+
f1_in.write("hello\n")
1418
f2 = tmp_path / "b.bin"
15-
f2.write_bytes(b"world\n")
19+
with bgzf.BgzfWriter(f2, 'w') as f2_in:
20+
f2_in.write("world\n")
1621
f3 = tmp_path / "c.bin"
17-
f3.write_bytes(b"!!!")
22+
with bgzf.BgzfWriter(f3, 'w') as f3_in:
23+
f3_in.write("!!!")
1824

1925
dest = tmp_path / "out.bin"
20-
concat([f1, f2, f3], dest)
21-
26+
dest_write = bgzf.BgzfWriter(dest, 'w')
27+
concat([f1, f2, f3], dest_write)
28+
dest_write.close()
2229
assert dest.exists()
23-
assert dest.read_bytes() == b"hello\nworld\n!!!"
30+
with bgzf.BgzfReader(dest) as read_dest:
31+
text = ""
32+
for line in read_dest:
33+
text += line
34+
assert text == "hello\nworld\n!!!"
2435

2536

2637
def test_concat_noop_on_empty_list(tmp_path: Path) -> None:

tests/test_read_simulator/test_generate_reads.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -83,27 +83,25 @@ def test_various_read_lengths():
8383

8484
def test_fragment_mean_st_dev_combinations():
8585
"""Test cover_dataset with combinations of fragment mean and standard deviation to ensure no errors"""
86-
span_length = 10000
86+
span_length = 5000
8787
options = Options(rng_seed=0)
88-
options.paired_ended = True
89-
options.read_len = 100
90-
options.coverage = 5
88+
options.paired_ended = False
89+
options.read_len = 101
90+
options.coverage = 2
9191
options.overwrite_output = True
9292

93-
fragment_means = [100, 150, 200, 250, 300, 350, 400, 450, 500, 750, 1000]
94-
fragment_st_devs = [1, 2, 5, 10, 25, 50, 100, 150, 200, 250, 300, 350, 400, 450, 500, 750, 1000]
93+
fragment_means = [100, 150, 200, 250,]
94+
fragment_st_devs = [1, 5, 25, 50]
9595

9696
for mean in fragment_means:
9797
for st_dev in fragment_st_devs:
9898
options.fragment_mean = mean
9999
options.fragment_st_dev = st_dev
100100
fragment_model = FragmentLengthModel(mean, st_dev)
101-
try:
102-
reads = cover_dataset(span_length, options, fragment_model)
103-
assert isinstance(reads, list)
104-
except Exception as e:
105-
pytest.fail(f"Test failed for mean={mean}, st_dev={st_dev} with exception: {e}")
106-
101+
frags = fragment_model.generate_fragments(20, options.rng)
102+
assert len(frags) == 20
103+
assert fragment_model.fragment_mean == mean
104+
assert fragment_model.fragment_st_dev == st_dev
107105

108106
def test_coverage_ploidy_combinations():
109107
"""Test cover_dataset with various combinations of coverage and ploidy values to ensure no errors"""
@@ -116,7 +114,7 @@ def test_coverage_ploidy_combinations():
116114
options.overwrite_output = True
117115
fragment_model = FragmentLengthModel(250, 100)
118116

119-
coverage_values = [1, 2, 5, 10, 25, 50]
117+
coverage_values = [1, 2, 5, 10]
120118
ploidy_values = [1, 2, 4]
121119

122120
for coverage in coverage_values:

0 commit comments

Comments (0)