Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions tests/unit/test_get_charging_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,38 @@ def test_extracts_ml_tag(self, temp_dir):
assert "read_id\ttRNA\tcharging_likelihood\n" == lines[0]
assert "read1\ttRNA-Ala-AGC-1-1\t220\n" == lines[1]

def test_extracts_scalar_int_tag(self, temp_dir):
"""Should handle CL tag stored as a scalar integer (not array)."""
input_bam = temp_dir / "input.bam"
output_tsv = temp_dir / "output.tsv"

header = {
"HD": {"VN": "1.0"},
"SQ": [{"SN": "tRNA-Ala-AGC-1-1", "LN": 100}],
}

with pysam.AlignmentFile(str(input_bam), "wb", header=header) as outf:
read = pysam.AlignedSegment()
read.query_name = "read1"
read.query_sequence = "A" * 100
read.flag = 0
read.reference_id = 0
read.reference_start = 0
read.cigartuples = [(0, 100)]
read.query_qualities = pysam.qualitystring_to_array("I" * 100)
read.set_tag("CL", 220) # Scalar int, not array
outf.write(read)

pysam.index(str(input_bam))

extract_tag(str(input_bam), str(output_tsv), "CL")

with open(output_tsv) as f:
lines = f.readlines()

assert len(lines) == 2
assert "read1\ttRNA-Ala-AGC-1-1\t220\n" == lines[1]

def test_handles_gzip_output(self, temp_dir):
"""Should write gzipped output when filename ends in .gz."""
input_bam = temp_dir / "input.bam"
Expand Down
54 changes: 54 additions & 0 deletions tests/unit/test_transfer_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,60 @@ def test_only_outputs_reads_with_tags(self, temp_dir):
assert len(reads) == 1
assert reads[0].query_name == "read1"

def test_to_scalar_converts_single_element_array(self, temp_dir):
"""Single-element array tags listed in to_scalar should become scalar ints."""
source_bam = temp_dir / "source.bam"
target_bam = temp_dir / "target.bam"
output_bam = temp_dir / "output.bam"

header = {
"HD": {"VN": "1.0"},
"SQ": [{"SN": "ref", "LN": 100}],
}

# Create source with ML as byte array
with pysam.AlignmentFile(str(source_bam), "wb", header=header) as outf:
read = pysam.AlignedSegment()
read.query_name = "read1"
read.query_sequence = "A" * 100
read.flag = 0
read.reference_id = 0
read.reference_start = 0
read.cigartuples = [(0, 100)]
read.query_qualities = pysam.qualitystring_to_array("I" * 100)
read.set_tag("ML", array("B", [220]))
outf.write(read)

# Create target
with pysam.AlignmentFile(str(target_bam), "wb", header=header) as outf:
read = pysam.AlignedSegment()
read.query_name = "read1"
read.query_sequence = "A" * 100
read.flag = 0
read.reference_id = 0
read.reference_start = 0
read.cigartuples = [(0, 100)]
read.query_qualities = pysam.qualitystring_to_array("I" * 100)
outf.write(read)

# Transfer ML -> CL with to_scalar=["CL"]
transfer_tags(
tags=["ML"],
rename=["ML=CL"],
source_bam=str(source_bam),
target_bam=str(target_bam),
output_bam=str(output_bam),
to_scalar=["CL"],
)

with pysam.AlignmentFile(str(output_bam), "rb") as bam:
read = next(bam)
assert read.has_tag("CL")
cl_val = read.get_tag("CL")
# Should be a scalar int, not an array
assert isinstance(cl_val, int)
assert cl_val == 220

def test_multiple_tags(self, temp_dir):
"""Multiple tags should all be transferred."""
source_bam = temp_dir / "source.bam"
Expand Down
1 change: 1 addition & 0 deletions workflow/rules/aatrnaseq-process.smk
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ rule transfer_bam_tags:
python {params.src}/transfer_tags.py \
--tags ML MM \
--rename ML=CL MM=CM \
--to-scalar CL \
--source {input.source_bam} \
--target {input.target_bam} \
--output {output.classified_bam}
Expand Down
16 changes: 9 additions & 7 deletions workflow/scripts/get_charging_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,17 @@ def extract_tag(bam_file, output_tsv, tag):
for read in bam.fetch():
read_id = read.query_name
reference = read.reference_name if read.reference_name else "*"
tag_array = dict(read.tags).get(tag, None)

# XXX: handle case where there are more than 1 tag value
# not clear why this is, but we skip for now as it's a small
# number of reads affected
if len(tag_array) > 1:
tag_raw = dict(read.tags).get(tag, None)
if tag_raw is None:
continue

tag_value = tag_array[0]
# Handle both scalar int (new CL) and array (legacy ML/CL)
if isinstance(tag_raw, int):
tag_value = tag_raw
elif hasattr(tag_raw, "__len__") and len(tag_raw) == 1:
tag_value = tag_raw[0]
else:
continue

if tag_value and reference != "*":
writer.writerow([read_id, reference, tag_value])
Expand Down
19 changes: 14 additions & 5 deletions workflow/scripts/transfer_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@


def transfer_tags(
tags, rename, source_bam, target_bam, output_bam, all_tags=False, threads=1
tags, rename, source_bam, target_bam, output_bam, all_tags=False, to_scalar=None, threads=1
):
renamed_tags = parse_tag_items(rename)
scalar_tags = set(to_scalar) if to_scalar else set()

# Collect target read names (primary only) so we only cache matching source reads
target_names = set()
Expand Down Expand Up @@ -63,10 +64,10 @@ def transfer_tags(

if read_tags:
for tag, tag_val in read_tags.items():
if tag in renamed_tags:
read.set_tag(renamed_tags[tag], tag_val)
else:
read.set_tag(tag, tag_val)
out_tag = renamed_tags.get(tag, tag)
if out_tag in scalar_tags and hasattr(tag_val, "__len__") and len(tag_val) == 1:
tag_val = int(tag_val[0])
read.set_tag(out_tag, tag_val)

if all_tags or read_tags:
output.write(read)
Expand Down Expand Up @@ -103,6 +104,13 @@ def parse_tag_items(rename):
help="tags to rename during transfer",
)

parser.add_argument(
"--to-scalar",
nargs="+",
metavar="TAG",
help="convert single-element array tags to scalar integers",
)

parser.add_argument("--source", required=True, help="Source BAM file (with tags)")

parser.add_argument(
Expand Down Expand Up @@ -130,5 +138,6 @@ def parse_tag_items(rename):
args.target,
args.output,
all_tags=args.all_tags,
to_scalar=args.to_scalar,
threads=args.threads,
)
Loading