diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a81bec..ae0c052 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ... +## [v0.2.0] +### Fixed +- fixed a bug in the intervals to values cuda kernel that + introduced zeros in places where there should be + "default_value" (see release v0.1.5). +### Added +- custom_position_sampler argument to bigwig_loader.dataset.BigWigDataset + and bigwig_loader.pytorch.PytorchBigWigDataset to optionally overwrite the + default random sampling of genomic coordinates from "regions of interest." +- custom_track_sampler argument to bigwig_loader.dataset.BigWigDataset + and bigwig_loader.pytorch.PytorchBigWigDataset to optionally use a different + track sampling strategy. + ## [v0.1.5] ### Added - set a default value different from 0.0 @@ -47,7 +60,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - release to pypi -[Unreleased]: https://github.com/pfizer-opensource/bigwig-loader/compare/v0.1.5...HEAD +[Unreleased]: https://github.com/pfizer-opensource/bigwig-loader/compare/v0.2.0...HEAD +[v0.1.6]: https://github.com/pfizer-opensource/bigwig-loader/compare/v0.1.5...v0.2.0 [v0.1.5]: https://github.com/pfizer-opensource/bigwig-loader/compare/v0.1.4...v0.1.5 [v0.1.4]: https://github.com/pfizer-opensource/bigwig-loader/compare/v0.1.3...v0.1.4 [v0.1.3]: https://github.com/pfizer-opensource/bigwig-loader/compare/v0.1.2...v0.1.3 diff --git a/bigwig_loader/dataset.py b/bigwig_loader/dataset.py index 787f0bf..fa3753e 100644 --- a/bigwig_loader/dataset.py +++ b/bigwig_loader/dataset.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Any from typing import Callable +from typing import Iterable from typing import Iterator from typing import Literal from typing import Optional @@ -78,6 +79,12 @@ class BigWigDataset: GPU. More threads means that more IO can take place while the GPU is busy doing calculations (decompressing or neural network training for example). More threads also means a higher GPU memory usage. Default: 4 + custom_position_sampler: if set, this sampler will be used instead of the default + position sampler (which samples randomly and uniform from regions of interest) + This should be an iterable of tuples (chromosome, center). + custom_track_sampler: if specified, this sampler will be used to sample tracks. When not + specified, each batch simply contains all tracks, or a randomly sellected subset of + tracks in case sub_sample_tracks is set. Should be Iterable batches of track indices. return_batch_objects: if True, the batches will be returned as instances of bigwig_loader.batch.Batch """ @@ -107,6 +114,8 @@ def __init__( repeat_same_positions: bool = False, sub_sample_tracks: Optional[int] = None, n_threads: int = 4, + custom_position_sampler: Optional[Iterable[tuple[str, int]]] = None, + custom_track_sampler: Optional[Iterable[list[int]]] = None, return_batch_objects: bool = False, ): super().__init__() @@ -152,32 +161,34 @@ def __init__( self._sub_sample_tracks = sub_sample_tracks self._n_threads = n_threads self._return_batch_objects = return_batch_objects - - def _create_dataloader(self) -> StreamedDataloader: - position_sampler = RandomPositionSampler( + self._position_sampler = custom_position_sampler or RandomPositionSampler( regions_of_interest=self.regions_of_interest, buffer_size=self._position_sampler_buffer_size, repeat_same=self._repeat_same_positions, ) + if custom_track_sampler is not None: + self._track_sampler: Optional[Iterable[list[int]]] = custom_track_sampler + elif sub_sample_tracks is not None: + self._track_sampler = TrackSampler( + total_number_of_tracks=len(self.bigwig_collection), + sample_size=sub_sample_tracks, + ) + else: + self._track_sampler = None + def _create_dataloader(self) -> StreamedDataloader: sequence_sampler = GenomicSequenceSampler( reference_genome_path=self.reference_genome_path, sequence_length=self.sequence_length, - position_sampler=position_sampler, + position_sampler=self._position_sampler, maximum_unknown_bases_fraction=self.maximum_unknown_bases_fraction, ) - track_sampler = None - if self._sub_sample_tracks is not None: - track_sampler = TrackSampler( - total_number_of_tracks=len(self.bigwig_collection), - sample_size=self._sub_sample_tracks, - ) query_batch_generator = QueryBatchGenerator( genomic_location_sampler=sequence_sampler, center_bin_to_predict=self.center_bin_to_predict, batch_size=self.super_batch_size, - track_sampler=track_sampler, + track_sampler=self._track_sampler, ) return StreamedDataloader( diff --git a/bigwig_loader/download_example_data.py b/bigwig_loader/download_example_data.py index d74b699..54e5fe1 100644 --- a/bigwig_loader/download_example_data.py +++ b/bigwig_loader/download_example_data.py @@ -19,20 +19,55 @@ def download_example_data() -> None: def get_reference_genome(reference_genome_path: Path = config.reference_genome) -> Path: compressed_file = reference_genome_path.with_suffix(".fasta.gz") - if reference_genome_path.exists(): - return reference_genome_path - elif compressed_file.exists(): - # subprocess.run(["bgzip", "-d", compressed_file]) - unzip_gz_file(compressed_file, reference_genome_path) - else: - LOGGER.info("Need reference genome for tests. Downloading it from ENCODE.") - url = "https://www.encodeproject.org/files/GRCh38_no_alt_analysis_set_GCA_000001405.15/@@download/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta.gz" - urllib.request.urlretrieve(url, compressed_file) + if compressed_file.exists() and not reference_genome_path.exists(): # subprocess.run(["bgzip", "-d", compressed_file]) unzip_gz_file(compressed_file, reference_genome_path) + + if ( + reference_genome_path.exists() + and checksum_md5_for_path(reference_genome_path) + != config.reference_genome_checksum + ): + LOGGER.info( + f"Reference genome checksum mismatch, downloading again from {reference_genome_path}" + ) + _download_genome( + url=config.reference_genome_url, + compressed_file_path=compressed_file, + uncompressed_file_path=reference_genome_path, + md5_checksum=config.reference_genome_checksum, + ) + + if not reference_genome_path.exists(): + LOGGER.info( + f"Reference genome not found, downloading from {config.reference_genome_url}" + ) + _download_genome( + url=config.reference_genome_url, + compressed_file_path=compressed_file, + uncompressed_file_path=reference_genome_path, + md5_checksum=config.reference_genome_checksum, + ) return reference_genome_path +def _download_genome( + url: str, + compressed_file_path: Path, + uncompressed_file_path: Path, + md5_checksum: str, +) -> Path: + urllib.request.urlretrieve(url, compressed_file_path) + # subprocess.run(["bgzip", "-d", compressed_file]) + unzip_gz_file(compressed_file_path, uncompressed_file_path) + this_checksum = checksum_md5_for_path(uncompressed_file_path) + if this_checksum != md5_checksum: + raise RuntimeError( + f"{uncompressed_file_path} has incorrect checksum: {this_checksum} vs. {md5_checksum}" + ) + return uncompressed_file_path + + def unzip_gz_file(compressed_file_path: Path, output_file_path: Path) -> Path: with gzip.open(compressed_file_path, "rb") as gz_file: with open(output_file_path, "wb") as output_file: @@ -52,6 +87,13 @@ def unzip_gz_file(compressed_file_path: Path, output_file_path: Path) -> Path: } +def checksum_md5_for_path(path: Path, chunk_size: int = 10 * 1024 * 1024) -> str: + """return the md5sum""" + with path.open(mode="rb") as f: + checksum = checksum_md5(f, chunk_size=chunk_size) + return checksum + + def checksum_md5(f: BinaryIO, *, chunk_size: int = 10 * 1024 * 1024) -> str: """return the md5sum""" m = hashlib.md5(b"", usedforsecurity=False) @@ -68,7 +110,7 @@ def get_example_bigwigs_files(bigwig_dir: Path = config.bigwig_dir) -> Path: file = bigwig_dir / fn if not file.exists(): urllib.request.urlretrieve(url, file) - with file.open(mode="rb") as f: - if checksum_md5(f) != md5: - raise RuntimeError(f"{fn} has incorrect checksum!") + checksum = checksum_md5_for_path(file) + if checksum != md5: + raise RuntimeError(f"{fn} has incorrect checksum: {checksum} vs. {md5}") return bigwig_dir diff --git a/bigwig_loader/intervals_to_values.py b/bigwig_loader/intervals_to_values.py index bc70ab6..b379747 100644 --- a/bigwig_loader/intervals_to_values.py +++ b/bigwig_loader/intervals_to_values.py @@ -1,5 +1,6 @@ import logging import math +from math import isnan from pathlib import Path import cupy as cp @@ -86,11 +87,15 @@ def intervals_to_values( ) if out is None: + logging.debug(f"Creating new out tensor with default value {default_value}") + out = cp.full( (found_starts.shape[0], len(query_starts), sequence_length // window_size), default_value, dtype=cp.float32, ) + logging.debug(out) + else: logging.debug(f"Setting default value in output tensor to {default_value}") out.fill(default_value) @@ -120,6 +125,7 @@ def intervals_to_values( array_start = cp.ascontiguousarray(array_start) array_end = cp.ascontiguousarray(array_end) array_value = cp.ascontiguousarray(array_value) + default_value_isnan = isnan(default_value) cuda_kernel( (grid_size,), @@ -137,6 +143,8 @@ def intervals_to_values( sequence_length, max_number_intervals, window_size, + cp.float32(default_value), + default_value_isnan, out, ), ) @@ -167,8 +175,10 @@ def kernel_in_python_with_window( int, int, int, - cp.ndarray, int, + float, + bool, + cp.ndarray, ], ) -> cp.ndarray: """Equivalent in python to cuda_kernel_with_window. Just for debugging.""" @@ -186,6 +196,8 @@ def kernel_in_python_with_window( sequence_length, max_number_intervals, window_size, + default_value, + default_value_isnan, out, ) = args @@ -214,7 +226,7 @@ def kernel_in_python_with_window( print("reduced_dim") print(reduced_dim) - out_vector = [0.0] * reduced_dim * batch_size * num_tracks + out_vector = [default_value] * reduced_dim * batch_size * num_tracks for thread in range(n_threads): batch_index = thread % batch_size @@ -235,7 +247,8 @@ def kernel_in_python_with_window( cursor = found_start_index window_index = 0 - summation = 0 + summation = 0.0 + valid_count = 0 # cursor moves through the rows of the bigwig file # window_index moves through the sequence @@ -261,19 +274,31 @@ def kernel_in_python_with_window( print("start index", start_index) if start_index >= window_end: - print("CONTINUE") - out_vector[i * reduced_dim + window_index] = summation / window_size - summation = 0 + if default_value_isnan: + if valid_count > 0: + out_vector[i * reduced_dim + window_index] = ( + summation / valid_count + ) + else: + out_vector[i * reduced_dim + window_index] = default_value + else: + summation = summation + (window_size - valid_count) * default_value + out_vector[i * reduced_dim + window_index] = summation / window_size + summation = 0.0 + valid_count = 0 window_index += 1 + print("CONTINUE") continue number = min(window_end, end_index) - max(window_start, start_index) - print( - f"Add {number} x {track_values[cursor]} = {number * track_values[cursor]} to summation" - ) - summation += number * track_values[cursor] - print(f"Summation = {summation}") + if number > 0: + print( + f"Add {number} x {track_values[cursor]} = {number * track_values[cursor]} to summation" + ) + summation += number * track_values[cursor] + print(f"Summation = {summation}") + valid_count += number print("end_index", "window_end") print(end_index, window_end) @@ -288,8 +313,19 @@ def kernel_in_python_with_window( print( "cursor + 1 >= found_end_index \t\t calculate average, reset summation and move to next window" ) - out_vector[i * reduced_dim + window_index] = summation / window_size - summation = 0 + # out_vector[i * reduced_dim + window_index] = summation / window_size + if default_value_isnan: + if valid_count > 0: + out_vector[i * reduced_dim + window_index] = ( + summation / valid_count + ) + else: + out_vector[i * reduced_dim + window_index] = default_value + else: + summation = summation + (window_size - valid_count) * default_value + out_vector[i * reduced_dim + window_index] = summation / window_size + summation = 0.0 + valid_count = 0 window_index += 1 # move cursor if end_index < window_end: diff --git a/bigwig_loader/pytorch.py b/bigwig_loader/pytorch.py index 474a1cd..08750dc 100644 --- a/bigwig_loader/pytorch.py +++ b/bigwig_loader/pytorch.py @@ -1,6 +1,7 @@ from pathlib import Path from typing import Any from typing import Callable +from typing import Iterable from typing import Iterator from typing import Literal from typing import Optional @@ -165,6 +166,12 @@ class PytorchBigWigDataset(IterableDataset[BATCH_TYPE]): also means a higher GPU memory usage. Default: 4 return_batch_objects: if True, the batches will be returned as instances of bigwig_loader.pytorch.PytorchBatch + custom_position_sampler: if set, this sampler will be used instead of the default + position sampler (which samples randomly and uniform from regions of interest) + This should be an iterable of tuples (chromosome, center). + custom_track_sampler: if specified, this sampler will be used to sample tracks. When not + specified, each batch simply contains all tracks, or a randomly sellected subset of + tracks in case sub_sample_tracks is set. Should be Iterable batches of track indices. """ def __init__( @@ -192,6 +199,8 @@ def __init__( repeat_same_positions: bool = False, sub_sample_tracks: Optional[int] = None, n_threads: int = 4, + custom_position_sampler: Optional[Iterable[tuple[str, int]]] = None, + custom_track_sampler: Optional[Iterable[list[int]]] = None, return_batch_objects: bool = False, ): super().__init__() @@ -217,6 +226,8 @@ def __init__( repeat_same_positions=repeat_same_positions, sub_sample_tracks=sub_sample_tracks, n_threads=n_threads, + custom_position_sampler=custom_position_sampler, + custom_track_sampler=custom_track_sampler, return_batch_objects=True, ) self._return_batch_objects = return_batch_objects diff --git a/bigwig_loader/sampler/genome_sampler.py b/bigwig_loader/sampler/genome_sampler.py index c0cd2e6..1bf5bfe 100644 --- a/bigwig_loader/sampler/genome_sampler.py +++ b/bigwig_loader/sampler/genome_sampler.py @@ -1,6 +1,7 @@ from pathlib import Path from typing import Any from typing import Callable +from typing import Iterable from typing import Iterator from typing import Literal from typing import Optional @@ -21,7 +22,7 @@ def __init__( self, reference_genome_path: Path, sequence_length: int, - position_sampler: Iterator[tuple[str, int]], + position_sampler: Iterable[tuple[str, int]], maximum_unknown_bases_fraction: float = 0.1, ): self.reference_genome_path = reference_genome_path diff --git a/bigwig_loader/sampler/position_sampler.py b/bigwig_loader/sampler/position_sampler.py index 03fc279..4465016 100644 --- a/bigwig_loader/sampler/position_sampler.py +++ b/bigwig_loader/sampler/position_sampler.py @@ -3,7 +3,13 @@ import numpy as np import pandas as pd -from bigwig_loader.util import make_cumulative_index_intervals + +def make_cumulative_index_intervals(intervals: pd.DataFrame) -> pd.DataFrame: + intervals.reset_index(drop=True, inplace=True) + intervals.index = ( + (intervals["end"] - intervals["start"]).cumsum().shift().fillna(0).astype(int) # type: ignore + ) + return intervals class RandomPositionSampler: @@ -22,6 +28,8 @@ def __init__( self._repeat_same = repeat_same def __iter__(self) -> Iterator[tuple[str, int]]: + if self._repeat_same: + self._index = 0 return self def __next__(self) -> tuple[str, int]: @@ -36,6 +44,7 @@ def __next__(self) -> tuple[str, int]: return chromosome, center def _refresh_buffer(self) -> None: + print("refresh buffer called") batch_rand = np.random.randint( low=0, high=self._max_index, size=self.buffer_size ) diff --git a/bigwig_loader/settings.py b/bigwig_loader/settings.py index 9a949e7..8276f5e 100644 --- a/bigwig_loader/settings.py +++ b/bigwig_loader/settings.py @@ -34,6 +34,7 @@ class Settings(BaseSettings): reference_genome: Path = ( example_data_dir / "GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta" ) + reference_genome_checksum: str = "a6da8681616c05eb542f1d91606a7b2f" bigwig_dir: Path = example_data_dir / "bigwig" def __str__(self) -> str: diff --git a/bigwig_loader/util.py b/bigwig_loader/util.py index f267a30..b607ad1 100644 --- a/bigwig_loader/util.py +++ b/bigwig_loader/util.py @@ -28,14 +28,6 @@ def sort_intervals(intervals: pd.DataFrame, inplace: bool = False) -> pd.DataFra ) -def make_cumulative_index_intervals(intervals: pd.DataFrame) -> pd.DataFrame: - intervals.reset_index(drop=True, inplace=True) - intervals.index = ( - (intervals["end"] - intervals["start"]).cumsum().shift().fillna(0).astype(int) # type: ignore - ) - return intervals - - _string_to_encoding = { "A": [1.0, 0.0, 0.0, 0.0], "C": [0.0, 1.0, 0.0, 0.0], diff --git a/cuda_kernels/intervals_to_values.cu b/cuda_kernels/intervals_to_values.cu index c8173d2..2d1299d 100644 --- a/cuda_kernels/intervals_to_values.cu +++ b/cuda_kernels/intervals_to_values.cu @@ -1,3 +1,5 @@ +#include + extern "C" __global__ void intervals_to_values( const unsigned int* query_starts, @@ -12,6 +14,8 @@ void intervals_to_values( const int sequence_length, const int max_number_intervals, const int window_size, + const float default_value, + const bool default_value_isnan, float* out ) { @@ -49,7 +53,7 @@ void intervals_to_values( } } else { - int track_index = i / batch_size; +// int track_index = i / batch_size; int found_start_index = found_starts[i]; int found_end_index = found_ends[i]; @@ -59,6 +63,8 @@ void intervals_to_values( int cursor = found_start_index; int window_index = 0; float summation = 0.0f; + int valid_count = 0; + int reduced_dim = sequence_length / window_size; @@ -73,19 +79,34 @@ void intervals_to_values( int end_index = min(interval_end, query_end) - query_start; if (start_index >= window_end) { - out[i * reduced_dim + window_index] = summation / window_size; + if (default_value_isnan) { + out[i * reduced_dim + window_index] = valid_count > 0 ? summation / valid_count : CUDART_NAN_F; + } else { + summation = summation + (window_size - valid_count) * default_value; + out[i * reduced_dim + window_index] = summation / window_size; + } summation = 0.0f; + valid_count = 0; window_index += 1; continue; } int number = min(window_end, end_index) - max(window_start, start_index); - summation += number * track_values[cursor]; + if (number > 0) { + summation += number * track_values[cursor]; + valid_count += number; + } if (end_index >= window_end || cursor + 1 >= found_end_index) { - out[i * reduced_dim + window_index] = summation / window_size; - summation = 0.0f; + if (default_value_isnan) { + out[i * reduced_dim + window_index] = valid_count > 0 ? summation / valid_count : CUDART_NAN_F; + } else { + summation = summation + (window_size - valid_count) * default_value; + out[i * reduced_dim + window_index] = summation / window_size; + } + summation = 0.0f; + valid_count = 0; window_index += 1; } diff --git a/example_data/some_positions.tsv b/example_data/some_positions.tsv index a813904..52bec21 100644 --- a/example_data/some_positions.tsv +++ b/example_data/some_positions.tsv @@ -1,4 +1,5 @@ chr center +chr1 45298878 chr18 61036865 chr17 12174372 chr3 65857025 diff --git a/tests/conftest.py b/tests/conftest.py index 0de49d7..539d1af 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,11 +5,11 @@ import pytest from bigwig_loader import config +from bigwig_loader.download_example_data import get_example_bigwigs_files +from bigwig_loader.download_example_data import get_reference_genome try: from bigwig_loader.collection import BigWigCollection - from bigwig_loader.download_example_data import get_example_bigwigs_files - from bigwig_loader.download_example_data import get_reference_genome except ImportError: logging.warning( "Can not import from bigwig_loader.collection without cupy installed" diff --git a/tests/test_against_pybigwig.py b/tests/test_against_pybigwig.py index 99e412d..a17307b 100644 --- a/tests/test_against_pybigwig.py +++ b/tests/test_against_pybigwig.py @@ -1,6 +1,7 @@ import numpy as np import pandas as pd import pyBigWig +import pytest from bigwig_loader import config from bigwig_loader.collection import BigWigCollection @@ -26,8 +27,8 @@ def get_batch(self, chromosomes, starts, ends): def test_same_output(bigwig_path): - pybigwig_collection = PyBigWigCollection(bigwig_path, first_n_files=2) - collection = BigWigCollection(bigwig_path, first_n_files=2) + pybigwig_collection = PyBigWigCollection(bigwig_path, first_n_files=3) + collection = BigWigCollection(bigwig_path, first_n_files=3) df = pd.read_csv(config.example_positions, sep="\t") df = df[df["chr"].isin(collection.get_chromosomes_present_in_all_files())] @@ -49,3 +50,74 @@ def test_same_output(bigwig_path): print(this_batch[pybigwig_batch != this_batch]) print(pybigwig_batch[pybigwig_batch != this_batch]) assert (pybigwig_batch == this_batch).all() + + +def test_same_output_with_nans(bigwig_path): + pybigwig_collection = PyBigWigCollection(bigwig_path, first_n_files=3) + collection = BigWigCollection(bigwig_path, first_n_files=3) + + df = pd.read_csv(config.example_positions, sep="\t") + df = df[df["chr"].isin(collection.get_chromosomes_present_in_all_files())] + chromosomes, starts, ends = ( + list(df["chr"]), + list(df["center"] - 1000), + list(df["center"] + 1000), + ) + + pybigwig_batch = pybigwig_collection.get_batch(chromosomes, starts, ends) + + this_batch = collection.get_batch( + chromosomes, starts, ends, default_value=np.nan + ).get() + print("PyBigWig:") + print(pybigwig_batch) + print(type(this_batch), "shape:", pybigwig_batch.shape) + print("This Library:") + print(this_batch) + print(type(this_batch), "shape:", this_batch.shape) + print(this_batch[pybigwig_batch != this_batch]) + print(pybigwig_batch[pybigwig_batch != this_batch]) + assert np.allclose(pybigwig_batch, this_batch, equal_nan=True) + + +@pytest.mark.parametrize("window_size", [2, 11, 128]) +@pytest.mark.parametrize("default_value", [0.0, np.nan, 2.0, 5.6, 10]) +@pytest.mark.parametrize("sequence_length", [1000, 2048]) +def test_windowed_output_against_pybigwig( + bigwig_path, window_size, default_value, sequence_length +): + print("window_size:", window_size) + pybigwig_collection = PyBigWigCollection(bigwig_path, first_n_files=3) + collection = BigWigCollection(bigwig_path, first_n_files=3) + + df = pd.read_csv(config.example_positions, sep="\t") + df = df[df["chr"].isin(collection.get_chromosomes_present_in_all_files())] + + chromosomes = list(df["chr"]) + starts = list(df["center"] - sequence_length // 2) + ends = [position + sequence_length for position in starts] + + pybigwig_batch = pybigwig_collection.get_batch(chromosomes, starts, ends) + + this_batch = collection.get_batch( + chromosomes, starts, ends, default_value=default_value, window_size=window_size + ).get() + + # Reshape the tensor so the last dimension contains + # all the values corresponding to one window + reduced_dim = sequence_length // window_size + pybigwig_batch = pybigwig_batch[:, :, : reduced_dim * window_size] + pybigwig_batch = pybigwig_batch.reshape( + pybigwig_batch.shape[0], pybigwig_batch.shape[1], reduced_dim, window_size + ) + + # fill nan's with the chosen default value + pybigwig_batch = np.nan_to_num(pybigwig_batch, copy=False, nan=default_value) + # And take mean over the window + pybigwig_batch = np.nanmean(pybigwig_batch, axis=-1) + + print("PyBigWig (with window function applied afterwards):") + print(pybigwig_batch) + print("bigwig-loader:") + print(this_batch) + assert np.allclose(pybigwig_batch, this_batch, equal_nan=True) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 1976f6f..dd30fc1 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -73,3 +73,38 @@ def test_batch_return_type(bigwig_path, reference_genome_path, merged_intervals) for i, batch in enumerate(dataset): assert isinstance(batch, Batch) assert batch.track_indices is not None + + +def test_positions_are_reproducible( + bigwig_path, reference_genome_path, merged_intervals +): + batch_size = 16 + + dataset = BigWigDataset( + regions_of_interest=merged_intervals, + collection=bigwig_path, + reference_genome_path=reference_genome_path, + sequence_length=2000, + center_bin_to_predict=1000, + window_size=4, + batch_size=batch_size, + batches_per_epoch=10, + maximum_unknown_bases_fraction=0.1, + first_n_files=2, + repeat_same_positions=True, + n_threads=1, + return_batch_objects=True, + ) + + starts_a = [ + position + for batch in dataset + for position in zip(batch.chromosomes, batch.starts) + ] + starts_b = [ + position + for batch in dataset + for position in zip(batch.chromosomes, batch.starts) + ] + + assert starts_a == starts_b diff --git a/tests/test_intervals_to_values.py b/tests/test_intervals_to_values.py index 9ef5215..8098b33 100644 --- a/tests/test_intervals_to_values.py +++ b/tests/test_intervals_to_values.py @@ -21,7 +21,8 @@ def test_throw_exception_when_queried_intervals_are_of_different_lengths() -> No ) -def test_get_values_from_intervals() -> None: +@pytest.mark.parametrize("default_value", [0.0, cp.nan]) +def test_get_values_from_intervals(default_value) -> None: """Probably most frequent situation.""" track_starts = cp.asarray([1, 3, 10, 12, 16], dtype=cp.int32) track_ends = cp.asarray([3, 10, 12, 16, 20], dtype=cp.int32) @@ -30,15 +31,22 @@ def test_get_values_from_intervals() -> None: query_ends = cp.asarray([17], dtype=cp.int32) reserved = cp.zeros((1, 15), dtype=cp.float32) values = intervals_to_values( - track_starts, track_ends, track_values, query_starts, query_ends, reserved + track_starts, + track_ends, + track_values, + query_starts, + query_ends, + default_value=default_value, + out=reserved, ) - assert ( - values - == cp.asarray([[20, 15, 15, 15, 15, 15, 15, 15, 30, 30, 40, 40, 40, 40, 50]]) - ).all() + expected = cp.asarray( + [[20, 15, 15, 15, 15, 15, 15, 15, 30, 30, 40, 40, 40, 40, 50]] + ) + assert cp.allclose(expected, values, equal_nan=True) -def test_get_values_from_intervals_edge_case_1() -> None: +@pytest.mark.parametrize("default_value", [0.0, cp.nan]) +def test_get_values_from_intervals_edge_case_1(default_value) -> None: """Query start is somewhere in a "gap".""" track_starts = cp.asarray([1, 10, 12, 16], dtype=cp.int32) track_ends = cp.asarray([3, 12, 16, 20], dtype=cp.int32) @@ -47,19 +55,28 @@ def test_get_values_from_intervals_edge_case_1() -> None: query_ends = cp.asarray([18], dtype=cp.int32) reserved = cp.zeros((1, 12), dtype=cp.dtype(" None: +@pytest.mark.parametrize("default_value", [0.0, cp.nan]) +def test_get_values_from_intervals_edge_case_2(default_value) -> None: """Query start is exactly at start index after "gap".""" track_starts = cp.asarray([1, 10, 12, 16], dtype=cp.int32) track_ends = cp.asarray([3, 12, 16, 20], dtype=cp.int32) @@ -68,15 +85,23 @@ def test_get_values_from_intervals_edge_case_2() -> None: query_ends = cp.asarray([18], dtype=cp.int32) reserved = cp.zeros((1, 8), dtype=cp.dtype(" None: +@pytest.mark.parametrize("default_value", [0.0, cp.nan]) +def test_get_values_from_intervals_edge_case_3(default_value) -> None: """Query end is somewhere in a "gap".""" track_starts = cp.asarray([5, 10, 12, 18], dtype=cp.int32) track_ends = cp.asarray([10, 12, 14, 20], dtype=cp.int32) @@ -85,15 +110,24 @@ def test_get_values_from_intervals_edge_case_3() -> None: query_ends = cp.asarray([16], dtype=cp.int32) reserved = cp.zeros((1, 7), dtype=cp.dtype(" None: +@pytest.mark.parametrize("default_value", [0.0, cp.nan]) +def test_get_values_from_intervals_edge_case_4(default_value) -> None: """Query end is exactly at end index before "gap".""" track_starts = cp.asarray([5, 10, 12, 18], dtype=cp.int32) track_ends = cp.asarray([10, 12, 14, 20], dtype=cp.int32) @@ -102,7 +136,13 @@ def test_get_values_from_intervals_edge_case_4() -> None: query_ends = cp.asarray([14], dtype=cp.int32) reserved = cp.zeros((1, 5), dtype=cp.dtype(" None: ] - query_starts[0] -def test_get_values_from_intervals_edge_case_5() -> None: +@pytest.mark.parametrize("default_value", [0.0, cp.nan]) +def test_get_values_from_intervals_edge_case_5(default_value) -> None: """Query end is exactly at end index before "gap".""" track_starts = cp.asarray([5, 10, 12, 18], dtype=cp.uint32) track_ends = cp.asarray([10, 12, 14, 20], dtype=cp.uint32) @@ -123,19 +164,28 @@ def test_get_values_from_intervals_edge_case_5() -> None: query_ends = cp.asarray([20], dtype=cp.uint32) reserved = cp.zeros((1, 11), dtype=cp.dtype(" None: +@pytest.mark.parametrize("default_value", [0.0, cp.nan]) +def test_get_values_from_intervals_batch_of_2(default_value) -> None: """Query end is exactly at end index before "gap".""" track_starts = cp.asarray([5, 10, 12, 18], dtype=cp.int32) track_ends = cp.asarray([10, 12, 14, 20], dtype=cp.int32) @@ -144,20 +194,30 @@ def test_get_values_from_intervals_batch_of_2() -> None: query_ends = cp.asarray([18, 20], dtype=cp.int32) reserved = cp.zeros([2, 11], dtype=cp.dtype(" None: +@pytest.mark.parametrize("default_value", [0.0, cp.nan]) +def test_get_values_from_intervals_batch_multiple_tracks(default_value) -> None: """Query end is exactly at end index before "gap".""" track_starts = cp.asarray( [5, 10, 12, 18, 8, 9, 10, 18, 25, 10, 100, 1000], dtype=cp.int32 @@ -178,27 +238,29 @@ def test_get_values_from_intervals_batch_multiple_tracks() -> None: track_values, query_starts, query_ends, - reserved, + default_value=default_value, + out=reserved, sizes=cp.asarray([4, 5, 3], dtype=cp.int32), ) + x = default_value expected = cp.asarray( [ [ - [20.0, 20.0, 20.0, 30.0, 30.0, 40.0, 40.0, 0.0, 0.0, 0.0, 0.0], - [20.0, 30.0, 30.0, 40.0, 40.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [20.0, 20.0, 20.0, 30.0, 30.0, 40.0, 40.0, x, x, x, x], + [20.0, 30.0, 30.0, 40.0, 40.0, x, x, x, x, 50.0, 50.0], + [x, x, x, x, x, x, x, x, x, x, x], + [x, x, x, x, x, x, x, x, x, x, x], ], [ - [0.0, 60.0, 70.0, 80.0, 80.0, 80.0, 80.0, 0.0, 0.0, 0.0, 0.0], - [70.0, 80.0, 80.0, 80.0, 80.0, 0.0, 0.0, 0.0, 0.0, 90.0, 90.0], - [90.0, 90.0, 0.0, 0.0, 0.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [x, 60.0, 70.0, 80.0, 80.0, 80.0, 80.0, x, x, x, x], + [70.0, 80.0, 80.0, 80.0, 80.0, x, x, x, x, 90.0, 90.0], + [90.0, 90.0, x, x, x, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0], + [x, x, x, x, x, x, x, x, x, x, x], ], [ - [0.0, 0.0, 0.0, 110.0, 110.0, 110.0, 110.0, 110.0, 110.0, 110.0, 110.0], + [x, x, x, 110.0, 110.0, 110.0, 110.0, 110.0, 110.0, 110.0, 110.0], [ - 0.0, + x, 110.0, 110.0, 110.0, @@ -210,9 +272,9 @@ def test_get_values_from_intervals_batch_multiple_tracks() -> None: 110.0, 110.0, ], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [x, x, x, x, x, x, x, x, x, x, x], [ - 0.0, + x, 120.0, 120.0, 120.0, @@ -229,35 +291,4 @@ def test_get_values_from_intervals_batch_multiple_tracks() -> None: ) print(expected) print(values) - assert (values == expected).all() - - -def test_default_nan() -> None: - """Query end is exactly at end index before "gap" - Now instead of zeros, NaN values should be - used. - .""" - track_starts = cp.asarray([5, 10, 12, 18], dtype=cp.int32) - track_ends = cp.asarray([10, 12, 14, 20], dtype=cp.int32) - track_values = cp.asarray([20.0, 30.0, 40.0, 50.0], dtype=cp.dtype("f4")) - query_starts = cp.asarray([7, 9], dtype=cp.int32) - query_ends = cp.asarray([18, 20], dtype=cp.int32) - reserved = cp.zeros([2, 11], dtype=cp.dtype(" None: +@pytest.mark.parametrize("default_value", [0.0, cp.nan]) +def test_get_values_from_intervals_window(default_value) -> None: """.""" track_starts = cp.asarray([1, 3, 10, 12, 16], dtype=cp.int32) track_ends = cp.asarray([3, 10, 12, 16, 20], dtype=cp.int32) @@ -20,8 +23,9 @@ def test_get_values_from_intervals_window() -> None: track_values, query_starts, query_ends, - reserved, window_size=5, + default_value=default_value, + out=reserved, ) expected = cp.asarray([[16.0, 21.0, 42.0]]) @@ -33,35 +37,45 @@ def test_get_values_from_intervals_window() -> None: assert (values == expected).all() -def test_get_values_from_intervals_edge_case_1() -> None: +@pytest.mark.parametrize("default_value", [0.0, cp.nan, 5.6, 10.0, 7565]) +def test_get_values_from_intervals_edge_case_1(default_value) -> None: """Query start is somewhere in a "gap".""" track_starts = cp.asarray([1, 10, 12, 16], dtype=cp.int32) track_ends = cp.asarray([3, 12, 16, 20], dtype=cp.int32) track_values = cp.asarray([20.0, 30.0, 40.0, 50.0], dtype=cp.dtype("f4")) query_starts = cp.asarray([6], dtype=cp.int32) query_ends = cp.asarray([18], dtype=cp.int32) - reserved = cp.zeros((1, 4), dtype=cp.dtype(" None: +@pytest.mark.parametrize("default_value", [0.0, cp.nan]) +def test_get_values_from_intervals_edge_case_2(default_value) -> None: """Query start is exactly at start index after "gap".""" track_starts = cp.asarray([1, 10, 12, 16], dtype=cp.int32) track_ends = cp.asarray([3, 12, 16, 20], dtype=cp.int32) @@ -75,17 +89,21 @@ def test_get_values_from_intervals_edge_case_2() -> None: track_values, query_starts, query_ends, - reserved, window_size=4, + default_value=default_value, + out=reserved, ) expected = cp.asarray([[35.0, 45.0]]) + print(expected) + print(values) assert ( - cp.allclose(values, expected) + cp.allclose(values, expected, equal_nan=True) and expected.shape[-1] == (query_ends[0] - query_starts[0]) / 4 ) -def test_get_values_from_intervals_edge_case_3() -> None: +@pytest.mark.parametrize("default_value", [0.0, cp.nan]) +def test_get_values_from_intervals_edge_case_3(default_value) -> None: """Query end is somewhere in a "gap".""" track_starts = cp.asarray([5, 10, 12, 18], dtype=cp.int32) track_ends = cp.asarray([10, 12, 14, 20], dtype=cp.int32) @@ -99,18 +117,25 @@ def test_get_values_from_intervals_edge_case_3() -> None: track_values, query_starts, query_ends, - reserved, window_size=4, + default_value=default_value, + out=reserved, ) # expected = cp.asarray([[20, 20, 30, 30, 40, 40, 0, 0]]) - expected = cp.asarray([[25.0, 20.0]]) - - assert (values == expected).all() and expected.shape[-1] == ( - query_ends[0] - query_starts[0] - ) / 4 + if isnan(default_value): + expected = cp.asarray([[25.0, 40.0]]) + else: + expected = cp.asarray([[25.0, 20.0]]) + print(expected) + print(values) + assert ( + cp.allclose(expected, values, equal_nan=True) + and expected.shape[-1] == (query_ends[0] - query_starts[0]) / 4 + ) -def test_get_values_from_intervals_edge_case_4() -> None: +@pytest.mark.parametrize("default_value", [0.0, cp.nan]) +def test_get_values_from_intervals_edge_case_4(default_value) -> None: """Query end is exactly at end index before "gap".""" track_starts = cp.asarray([5, 10, 12, 18], dtype=cp.int32) track_ends = cp.asarray([10, 12, 14, 20], dtype=cp.int32) @@ -124,8 +149,9 @@ def test_get_values_from_intervals_edge_case_4() -> None: track_values, query_starts, query_ends, - reserved, window_size=3, + default_value=default_value, + out=reserved, ) # without window function:[[20, 20, 30, 30, 40, 40]] expected = cp.asarray([[23.333334, 36.666668]]) @@ -134,12 +160,13 @@ def test_get_values_from_intervals_edge_case_4() -> None: print(values) assert ( - cp.allclose(values, expected) + cp.allclose(values, expected, equal_nan=True) and expected.shape[-1] == (query_ends[0] - query_starts[0]) / 3 ) -def test_get_values_from_intervals_edge_case_5() -> None: +@pytest.mark.parametrize("default_value", [0.0, cp.nan]) +def test_get_values_from_intervals_edge_case_5(default_value) -> None: """Query end is exactly at end index before "gap".""" track_starts = cp.asarray([5, 10, 12, 18], dtype=cp.uint32) track_ends = cp.asarray([10, 12, 14, 20], dtype=cp.uint32) @@ -153,21 +180,27 @@ def test_get_values_from_intervals_edge_case_5() -> None: track_values, query_starts, query_ends, - reserved, window_size=3, + default_value=default_value, + out=reserved, ) - expected = cp.asarray([[23.333334, 36.666668, 0.0, 33.333332]]) + x = default_value + if isnan(default_value): + expected = cp.asarray([[23.333334, 36.666668, x, 50.0]]) + else: + expected = cp.asarray([[23.333334, 36.666668, x, 33.333332]]) print(expected) print(values) assert ( - cp.allclose(values, expected) + cp.allclose(values, expected, equal_nan=True) and expected.shape[-1] == (query_ends[0] - query_starts[0]) / 3 ) -def test_get_values_from_intervals_batch_of_2() -> None: +@pytest.mark.parametrize("default_value", [0.0, cp.nan]) +def test_get_values_from_intervals_batch_of_2(default_value) -> None: """Query end is exactly at end index before "gap".""" track_starts = cp.asarray([5, 10, 12, 18], dtype=cp.int32) track_ends = cp.asarray([10, 12, 14, 20], dtype=cp.int32) @@ -181,22 +214,31 @@ def test_get_values_from_intervals_batch_of_2() -> None: track_values, query_starts, query_ends, - reserved, window_size=3, + default_value=default_value, + out=reserved, ) # expected without window function: # [20.0, 20.0, 20.0, 20.0, 30.0, 30.0, 40.0, 40.0, 0.0, 0.0, 0.0, 0.0] # [20.0, 20.0, 30.0, 30.0, 40.0, 40.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0] - expected = cp.asarray( - [[20.0, 26.666666, 26.666666, 0.0], [23.333334, 36.666668, 0.0, 33.333332]] - ) + if isnan(default_value): + expected = cp.asarray( + [[20.0, 26.666666, 40.0, cp.nan], [23.333334, 36.666668, cp.nan, 50.0]] + ) + else: + expected = cp.asarray( + [[20.0, 26.666666, 26.666666, 0.0], [23.333334, 36.666668, 0.0, 33.333332]] + ) + print("expected:") print(expected) + print("actual:") print(values) - assert cp.allclose(values, expected) + assert cp.allclose(values, expected, equal_nan=True) -def test_one_track_one_sample() -> None: +@pytest.mark.parametrize("default_value", [0.0, cp.nan]) +def test_one_track_one_sample(default_value) -> None: """ This tests a specific combination of track and batch index of the larger test case below: @@ -220,15 +262,17 @@ def test_one_track_one_sample() -> None: track_values, query_starts, query_ends, - reserved, sizes=cp.asarray([4], dtype=cp.int32), window_size=3, + default_value=default_value, + out=reserved, ) + x = default_value expected = cp.asarray( [ [ - [20.0, 30.0, 30.0, 40.0, 40.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0], + [20.0, 30.0, 30.0, 40.0, 40.0, x, x, x, x, 50.0, 50.0], ], ] ) @@ -236,9 +280,9 @@ def test_one_track_one_sample() -> None: def apply_window(full_matrix): return cp.stack( [ - cp.mean(full_matrix[:, :, :3], axis=2), - cp.mean(full_matrix[:, :, 3:6], axis=2), - cp.mean(full_matrix[:, :, 6:9], axis=2), + cp.nanmean(full_matrix[:, :, :3], axis=2), + cp.nanmean(full_matrix[:, :, 3:6], axis=2), + cp.nanmean(full_matrix[:, :, 6:9], axis=2), ], axis=-1, ) @@ -249,10 +293,11 @@ def apply_window(full_matrix): print(expected) print("actual:") print(values) - assert cp.allclose(values, expected) + assert cp.allclose(values, expected, equal_nan=True) -def test_one_track_one_sample_2() -> None: +@pytest.mark.parametrize("default_value", [0.0, cp.nan]) +def test_one_track_one_sample_2(default_value) -> None: """ This tests a specific combination of track and batch index of the larger test case below: @@ -276,16 +321,17 @@ def test_one_track_one_sample_2() -> None: track_values, query_starts, query_ends, - reserved, sizes=cp.asarray([3], dtype=cp.int32), window_size=3, + default_value=default_value, + out=reserved, ) - + x = default_value expected = cp.asarray( [ [ [ - 0.0, + x, 110.0, 110.0, 110.0, @@ -304,9 +350,9 @@ def test_one_track_one_sample_2() -> None: def apply_window(full_matrix): return cp.stack( [ - cp.mean(full_matrix[:, :, :3], axis=2), - cp.mean(full_matrix[:, :, 3:6], axis=2), - cp.mean(full_matrix[:, :, 6:9], axis=2), + cp.nanmean(full_matrix[:, :, :3], axis=2), + cp.nanmean(full_matrix[:, :, 3:6], axis=2), + cp.nanmean(full_matrix[:, :, 6:9], axis=2), ], axis=-1, ) @@ -319,10 +365,11 @@ def apply_window(full_matrix): print(values) print("difference") print(values - expected) - assert cp.allclose(values, expected, atol=1e-2, rtol=1e-2) + assert cp.allclose(values, expected, atol=1e-2, rtol=1e-2, equal_nan=True) -def test_get_values_from_intervals_multiple_tracks() -> None: +@pytest.mark.parametrize("default_value", [0.0, cp.nan]) +def test_get_values_from_intervals_multiple_tracks(default_value) -> None: """Test interval_to_values with 3 tracks and a batch size of 1.""" track_starts = cp.asarray( [5, 10, 12, 18, 8, 9, 10, 18, 25, 10, 100, 1000], dtype=cp.int32 @@ -343,22 +390,24 @@ def test_get_values_from_intervals_multiple_tracks() -> None: track_values, query_starts, query_ends, - reserved, sizes=cp.asarray([4, 5, 3], dtype=cp.int32), window_size=3, + default_value=default_value, + out=reserved, ) + x = default_value expected = cp.asarray( [ [ - [20.0, 30.0, 30.0, 40.0, 40.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0], + [20.0, 30.0, 30.0, 40.0, 40.0, x, x, x, x, 50.0, 50.0], ], [ - [70.0, 80.0, 80.0, 80.0, 80.0, 0.0, 0.0, 0.0, 0.0, 90.0, 90.0], + [70.0, 80.0, 80.0, 80.0, 80.0, x, x, x, x, 90.0, 90.0], ], [ [ - 0.0, + x, 110.0, 110.0, 110.0, @@ -377,9 +426,9 @@ def test_get_values_from_intervals_multiple_tracks() -> None: def apply_window(full_matrix): return cp.stack( [ - cp.mean(full_matrix[:, :, :3], axis=2), - cp.mean(full_matrix[:, :, 3:6], axis=2), - cp.mean(full_matrix[:, :, 6:9], axis=2), + cp.nanmean(full_matrix[:, :, :3], axis=2), + cp.nanmean(full_matrix[:, :, 3:6], axis=2), + cp.nanmean(full_matrix[:, :, 6:9], axis=2), ], axis=-1, ) @@ -392,13 +441,16 @@ def apply_window(full_matrix): print(values) print("difference") print(values - expected) - assert cp.allclose(values, expected, atol=1e-2, rtol=1e-2) + assert cp.allclose(values, expected, atol=1e-2, rtol=1e-2, equal_nan=True) @pytest.mark.parametrize( "sequence_length, window_size", product([8, 9, 10, 11, 13, 23], [2, 3, 4, 10, 11]) ) -def test_combinations_sequence_length_window_size(sequence_length, window_size) -> None: +@pytest.mark.parametrize("default_value", [0.0, cp.nan]) +def test_combinations_sequence_length_window_size( + sequence_length, window_size, default_value +) -> None: """Test intervals_to_values with 3 tracks and a batch size of 4.""" track_starts = cp.asarray( [5, 10, 12, 18, 8, 9, 10, 18, 25, 10, 100, 1000], dtype=cp.int32 @@ -421,9 +473,10 @@ def test_combinations_sequence_length_window_size(sequence_length, window_size) track_values, query_starts, query_ends, - reserved, sizes=cp.asarray([4, 5, 3], dtype=cp.int32), window_size=window_size, + default_value=default_value, + out=reserved, ) reserved = cp.zeros([3, 4, sequence_length], dtype=cp.dtype(" None: """.""" track_starts = cp.asarray([1, 3, 10, 12, 16] * n_tracks, dtype=cp.int32) - track_ends = cp.asarray([3, 10, 12, 16, 20] * n_tracks, dtype=cp.int32) + track_ends = cp.asarray([3, 8, 12, 16, 20] * n_tracks, dtype=cp.int32) track_values = cp.asarray( [20.0, 15.0, 30.0, 40.0, 50.0] * n_tracks, dtype=cp.dtype("f4") ) @@ -475,6 +530,92 @@ def test_combinations_window_size_batch_size_n_tracks( query_ends, sizes=sizes, window_size=window_size, + default_value=default_value, + ) + + values_with_window_size_1 = intervals_to_values( + track_starts, + track_ends, + track_values, + query_starts, + query_ends, + sizes=sizes, + window_size=1, + default_value=default_value, + ) + + reduced_dim = sequence_length // window_size + full_matrix = values_with_window_size_1[:, :, : reduced_dim * window_size] + full_matrix = full_matrix.reshape( + full_matrix.shape[0], full_matrix.shape[1], reduced_dim, window_size + ) + expected = cp.nanmean(full_matrix, axis=-1) + + print("expected:") + print(expected) + print("actual:") + print(values) + + assert cp.allclose(values, expected, equal_nan=True) + + +def create_random_track_data(n_tracks, min_intervals=10, max_intervals=20): + """Create random track data for testing.""" + track_starts = [] + track_ends = [] + values = [] + sizes = [] + for _ in range(n_tracks): + current_start = 0 + generate_n_intervals = np.random.randint(min_intervals, max_intervals) + for i in range(generate_n_intervals): + start = current_start + np.random.randint( + 1, 50 + ) # Ensure a gap between intervals + end = start + np.random.randint(1, 100) # Random interval length + track_starts.append(start) + track_ends.append(end) + values.append(np.random.random()) + current_start = end # Update the start for the next interval + sizes.append(generate_n_intervals) + + return ( + cp.asarray(track_starts, dtype=cp.int32), + cp.asarray(track_ends, dtype=cp.int32), + cp.asarray(values, dtype=cp.int32), + cp.asarray(sizes, dtype=cp.int32), + ) + + +def create_random_queries(batch_size, sequence_length=200): + start = np.random.randint(1, 50, size=batch_size) + end = start + sequence_length + return cp.asarray(start, dtype=cp.int32), cp.asarray(end, dtype=cp.int32) + + +@pytest.mark.parametrize( + "window_size, batch_size, n_tracks", + product([1, 2, 3, 7], [1, 2, 3, 7], [1, 2, 3, 7]), +) +@pytest.mark.parametrize("default_value", [0.0, cp.nan]) +def test_combinations_window_size_batch_size_n_tracks_on_random_data( + window_size, batch_size, n_tracks, default_value +) -> None: + sequence_length = 200 + track_starts, track_ends, track_values, sizes = create_random_track_data(n_tracks) + query_starts, query_ends = create_random_queries( + batch_size, sequence_length=sequence_length + ) + + values = intervals_to_values( + track_starts, + track_ends, + track_values, + query_starts, + query_ends, + sizes=sizes, + window_size=window_size, + default_value=default_value, ) values_with_window_size_1 = intervals_to_values( @@ -485,18 +626,25 @@ def test_combinations_window_size_batch_size_n_tracks( query_ends, sizes=sizes, window_size=1, + default_value=cp.nan, ) + cp.nan_to_num(values_with_window_size_1, copy=False, nan=default_value) + reduced_dim = sequence_length // window_size full_matrix = values_with_window_size_1[:, :, : reduced_dim * window_size] full_matrix = full_matrix.reshape( full_matrix.shape[0], full_matrix.shape[1], reduced_dim, window_size ) - expected = cp.mean(full_matrix, axis=-1) + expected = cp.nanmean(full_matrix, axis=-1) print("expected:") print(expected) print("actual:") print(values) - assert cp.allclose(values, expected) + assert cp.allclose(values, expected, equal_nan=True) + + +if __name__ == "__main__": + print(create_random_track_data(3)) diff --git a/tests/test_position_sampler.py b/tests/test_position_sampler.py new file mode 100644 index 0000000..9ef1a15 --- /dev/null +++ b/tests/test_position_sampler.py @@ -0,0 +1,39 @@ +from bigwig_loader.sampler.position_sampler import RandomPositionSampler + + +def test_repeat_same_positions(merged_intervals): + sampler = RandomPositionSampler( + regions_of_interest=merged_intervals, repeat_same=True + ) + + first_samples = [] + for i, sample in enumerate(sampler): + first_samples.append(sample) + if i == 5: + break + second_samples = [] + for i, sample in enumerate(sampler): + second_samples.append(sample) + if i == 5: + break + + assert first_samples == second_samples + + +def test_not_repeat_same_positions(merged_intervals): + sampler = RandomPositionSampler( + regions_of_interest=merged_intervals, repeat_same=False + ) + + first_samples = [] + for i, sample in enumerate(sampler): + first_samples.append(sample) + if i == 5: + break + second_samples = [] + for i, sample in enumerate(sampler): + second_samples.append(sample) + if i == 5: + break + + assert first_samples != second_samples diff --git a/tests/test_pytorch_dataset.py b/tests/test_pytorch_dataset.py index 4b6745e..d7ca95d 100644 --- a/tests/test_pytorch_dataset.py +++ b/tests/test_pytorch_dataset.py @@ -1,5 +1,10 @@ +from math import isnan + +import pandas as pd import pytest +from bigwig_loader import config + torch = pytest.importorskip("torch") @@ -30,3 +35,81 @@ def test_input_and_target_is_torch_tensor(pytorch_dataset): sequence, target = next(iter(pytorch_dataset)) assert isinstance(sequence, torch.Tensor) assert isinstance(target, torch.Tensor) + + +@pytest.mark.parametrize("default_value", [0.0, torch.nan, 4.0, 5.6]) +def test_pytorch_dataset_with_window_function( + default_value, bigwig_path, reference_genome_path, merged_intervals +): + from bigwig_loader.pytorch import PytorchBigWigDataset + + center_bin_to_predict = 2048 + window_size = 128 + reduced_dim = center_bin_to_predict // window_size + + batch_size = 16 + + df = pd.read_csv(config.example_positions, sep="\t") + df = df[df["chr"].isin({"chr1", "chr3", "chr5"})] + chromosomes = list(df["chr"])[:batch_size] + centers = list(df["center"])[:batch_size] + + position_sampler = [(chrom, center) for chrom, center in zip(chromosomes, centers)] + + dataset = PytorchBigWigDataset( + regions_of_interest=merged_intervals, + collection=bigwig_path, + reference_genome_path=reference_genome_path, + sequence_length=center_bin_to_predict * 2, + center_bin_to_predict=center_bin_to_predict, + window_size=1, + batch_size=batch_size, + batches_per_epoch=1, + maximum_unknown_bases_fraction=0.1, + first_n_files=3, + custom_position_sampler=position_sampler, + default_value=default_value, + return_batch_objects=True, + ) + + dataset_with_window = PytorchBigWigDataset( + regions_of_interest=merged_intervals, + collection=bigwig_path, + reference_genome_path=reference_genome_path, + sequence_length=center_bin_to_predict * 2, + center_bin_to_predict=center_bin_to_predict, + window_size=window_size, + batch_size=batch_size, + batches_per_epoch=1, + maximum_unknown_bases_fraction=0.1, + first_n_files=3, + custom_position_sampler=position_sampler, + default_value=default_value, + return_batch_objects=True, + ) + + print(dataset_with_window._dataset.bigwig_collection.bigwig_paths) + + for batch, batch_with_window in zip(dataset, dataset_with_window): + print(batch) + print(batch_with_window) + print(batch.chromosomes) + print(batch_with_window.chromosomes) + print(batch.starts) + print(batch_with_window.starts) + print(batch.ends) + print(batch_with_window.ends) + expected = batch.values.reshape( + batch.values.shape[0], batch.values.shape[1], reduced_dim, window_size + ) + if not isnan(default_value) or default_value == 0: + expected = torch.nan_to_num(expected, nan=default_value) + expected = torch.nanmean(expected, axis=-1) + print("---") + print("expected") + print(expected) + print("batch_with_window") + print(batch_with_window.values) + assert torch.allclose(expected, batch_with_window.values, equal_nan=True) + if isnan(default_value): + assert torch.isnan(batch_with_window.values).any() diff --git a/tests/test_window_function.py b/tests/test_window_function.py new file mode 100644 index 0000000..7c828a3 --- /dev/null +++ b/tests/test_window_function.py @@ -0,0 +1,120 @@ +from math import isnan + +import cupy as cp +import pandas as pd +import pytest + +from bigwig_loader import config +from bigwig_loader.collection import BigWigCollection +from bigwig_loader.dataset import BigWigDataset + + +@pytest.mark.parametrize("window_size", [2, 11, 32, 128]) +@pytest.mark.parametrize("default_value", [0.0, cp.nan]) +def test_same_output(bigwig_path, window_size, default_value): + collection = BigWigCollection(bigwig_path, first_n_files=3) + print(collection.bigwig_paths) + + df = pd.read_csv(config.example_positions, sep="\t") + df = df[df["chr"].isin(collection.get_chromosomes_present_in_all_files())] + chromosomes, starts, ends = ( + list(df["chr"]), + list(df["center"] - 1024), + list(df["center"] + 1024), + ) + full_batch = collection.get_batch( + chromosomes, starts, ends, window_size=1, default_value=default_value + ) + batch_with_window = collection.get_batch( + chromosomes, starts, ends, window_size=window_size, default_value=default_value + ) + sequence_length = 2048 + reduced_dim = sequence_length // window_size + + if isnan(default_value): + assert cp.isnan(full_batch).any() + else: + assert not cp.isnan(full_batch).any() + + full_matrix = full_batch[:, :, : reduced_dim * window_size] + full_matrix = full_matrix.reshape( + full_matrix.shape[0], full_matrix.shape[1], reduced_dim, window_size + ) + expected = cp.nanmean(full_matrix, axis=-1) + print(batch_with_window) + print(expected) + assert cp.allclose(expected, batch_with_window, equal_nan=True) + + +@pytest.mark.parametrize("default_value", [0.0, cp.nan, 4.0, 5.6]) +def test_dataset_with_window_function( + default_value, bigwig_path, reference_genome_path, merged_intervals +): + center_bin_to_predict = 2048 + window_size = 128 + reduced_dim = center_bin_to_predict // window_size + + batch_size = 16 + + df = pd.read_csv(config.example_positions, sep="\t") + df = df[df["chr"].isin({"chr1", "chr3", "chr5"})] + chromosomes = list(df["chr"])[:batch_size] + centers = list(df["center"])[:batch_size] + + position_sampler = [(chrom, center) for chrom, center in zip(chromosomes, centers)] + + dataset = BigWigDataset( + regions_of_interest=merged_intervals, + collection=bigwig_path, + reference_genome_path=reference_genome_path, + sequence_length=center_bin_to_predict * 2, + center_bin_to_predict=center_bin_to_predict, + window_size=1, + batch_size=batch_size, + batches_per_epoch=1, + maximum_unknown_bases_fraction=0.1, + first_n_files=3, + custom_position_sampler=position_sampler, + default_value=default_value, + return_batch_objects=True, + ) + + dataset_with_window = BigWigDataset( + regions_of_interest=merged_intervals, + collection=bigwig_path, + reference_genome_path=reference_genome_path, + sequence_length=center_bin_to_predict * 2, + center_bin_to_predict=center_bin_to_predict, + window_size=window_size, + batch_size=batch_size, + batches_per_epoch=1, + maximum_unknown_bases_fraction=0.1, + first_n_files=3, + custom_position_sampler=position_sampler, + default_value=default_value, + return_batch_objects=True, + ) + + for batch, batch_with_window in zip(dataset, dataset_with_window): + print(batch) + print(batch_with_window) + print(batch.chromosomes) + print(batch_with_window.chromosomes) + print(batch.starts) + print(batch_with_window.starts) + print(batch.ends) + print(batch_with_window.ends) + expected = batch.values.reshape( + batch.values.shape[0], batch.values.shape[1], reduced_dim, window_size + ) + if not isnan(default_value) or default_value == 0: + cp.nan_to_num(expected, copy=False, nan=default_value) + expected = cp.nanmean(expected, axis=-1) + print("---") + print("expected") + print(expected) + print("batch_with_window") + print(batch_with_window.values) + assert cp.allclose(expected, batch_with_window.values, equal_nan=True) + if isnan(default_value): + assert cp.isnan(batch_with_window.values).any()