import glob
import os
import platform
import shutil
import subprocess

import numpy as np
import pytest
import tiledb
import tiledbvcf

# Directory containing this file
CONTAINING_DIR = os.path.abspath(os.path.dirname(__file__))

# Test inputs directory
TESTS_INPUT_DIR = os.path.abspath(
    os.path.join(CONTAINING_DIR, "../../../libtiledbvcf/test/inputs")
)


# Skip marker for tests that require bcftools, which may be absent on Windows CI.
skip_if_no_bcftools = pytest.mark.skipif(
    os.environ.get("CI") == "true"
    and platform.system() == "Windows"
    and shutil.which("bcftools") is None,
    reason="no bcftools",
)


def assert_dfs_equal(expected, actual):
    """Assert that two DataFrames are equal, with type-aware column comparison.

    Floating-point columns are compared with np.isclose (NaN-safe).
    Integer columns are cast to int64 before comparison.
    All other columns use pandas Series.equals.

    Args:
        expected: DataFrame containing the expected values.
        actual: DataFrame containing the values under test.

    Raises:
        AssertionError: If the column sets differ or any column's values
            differ between expected and actual.
    """
    # Fix: compare the column sets up front. The original iterated both frames
    # and indexed the other, so a missing/extra column surfaced as a KeyError
    # rather than a clear AssertionError.
    assert set(expected) == set(actual)

    def assert_series(s1, s2):
        if np.issubdtype(s2.dtype, np.floating):
            assert np.isclose(s1, s2, equal_nan=True).all()
        elif np.issubdtype(s2.dtype, np.integer):
            assert s1.astype("int64").equals(s2.astype("int64"))
        else:
            assert s1.equals(s2)

    for k in expected:
        assert_series(expected[k], actual[k])


def skip_if_incompatible(uri):
    """Skip the current test if the TileDB array at uri is incompatible with the current environment.

    Attempts to open the array; if TileDB raises a format-version mismatch or
    any other TileDBError the test is skipped rather than failed, because the
    error indicates an environment incompatibility rather than a code defect.

    Args:
        uri: Path to the TileDB array to check.

    Returns:
        True if the array opened successfully.

    Raises:
        pytest.skip.Exception: If the array has an incompatible format version
            or any other TileDBError occurs.
    """
    try:
        with tiledb.open(uri):
            return True
    except tiledb.libtiledb.TileDBError as e:
        # Fix: pytest.skip() is the supported API for skipping from inside a
        # test; raising pytest.skip.Exception directly bypasses pytest's
        # location bookkeeping for the skip report.
        if "incompatible format version" in str(e).lower():
            pytest.skip("Test skipped due to incompatible format version")
        pytest.skip(f"Test skipped due to TileDB error: {str(e)}")


@pytest.fixture
def bgzip_and_index_vcfs():
    """Fixture that provides a helper for bgzipping and indexing VCF files.

    The returned callable compresses every ``*.vcf`` file in ``input_dir`` with
    ``bcftools view -Oz`` and then indexes each resulting ``.gz`` file with
    ``bcftools index``.

    Usage::

        vcf_files = bgzip_and_index_vcfs(input_dir)
        vcf_files = bgzip_and_index_vcfs(input_dir, output_dir=tmp_path)

    Args:
        input_dir: Directory containing the ``.vcf`` files to compress.
        output_dir: Directory where the ``.gz`` files will be written.
            Defaults to ``input_dir`` when omitted.

    Returns:
        List of absolute paths to the produced ``.gz`` files.
    """

    def _bgzip_and_index(input_dir, output_dir=None):
        if output_dir is None:
            output_dir = input_dir
        raw_inputs = glob.glob(os.path.join(input_dir, "*.vcf"))
        for vcf_file in raw_inputs:
            out = os.path.join(output_dir, os.path.basename(vcf_file)) + ".gz"
            # Fix: pass an argument list with shell=False so paths containing
            # spaces or shell metacharacters cannot break or inject into the
            # command line.
            subprocess.run(
                ["bcftools", "view", "--no-version", "-Oz", "-o", out, vcf_file],
                check=True,
            )
        bgzipped = glob.glob(os.path.join(output_dir, "*.gz"))
        for vcf_file in bgzipped:
            # Fix: check=True raises CalledProcessError on failure instead of
            # relying on a bare assert on the return code (stripped under -O).
            subprocess.run(["bcftools", "index", vcf_file], check=True)
        return bgzipped

    return _bgzip_and_index


@pytest.fixture
def v4_dataset():
    """Open the pre-ingested v4 2-sample dataset in read mode.

    Returns:
        tiledbvcf.Dataset: Read-mode dataset backed by arrays/v4/ingested_2samples.
    """
    return tiledbvcf.Dataset(
        os.path.join(TESTS_INPUT_DIR, "arrays/v4/ingested_2samples")
    )


@pytest.fixture
def v3_dataset():
    """Open the pre-ingested v3 2-sample dataset in read mode.

    Returns:
        tiledbvcf.Dataset: Read-mode dataset backed by arrays/v3/ingested_2samples.
    """
    return tiledbvcf.Dataset(
        os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples")
    )


@pytest.fixture
def v3_dataset_with_attrs():
    """Open the pre-ingested v3 2-sample dataset that includes GT, DP, and PL attributes.

    Returns:
        tiledbvcf.Dataset: Read-mode dataset backed by arrays/v3/ingested_2samples_GT_DP_PL.
    """
    return tiledbvcf.Dataset(
        os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples_GT_DP_PL")
    )


@pytest.fixture
def stats_bgzipped_vcfs(tmp_path, bgzip_and_index_vcfs):
    """Copy the stats VCF test inputs to tmp_path, bgzip and index them.

    Args:
        tmp_path: Pytest-provided temporary directory for this test.
        bgzip_and_index_vcfs: Fixture-provided helper that compresses and indexes VCF files.

    Returns:
        List[str]: Paths to the bgzipped and indexed ``.gz`` files inside tmp_path.
    """
    shutil.copytree(
        os.path.join(TESTS_INPUT_DIR, "stats"), os.path.join(tmp_path, "stats")
    )
    return bgzip_and_index_vcfs(os.path.join(tmp_path, "stats"))


@pytest.fixture
def stats_sample_names(stats_bgzipped_vcfs):
    """Return the sample names for the 8 bgzipped stats inputs.

    Sample names are extracted from the file names: each file is named
    ``<sample>.vcf.gz``, so splitting on ``"."`` and taking the first
    part yields the sample name.

    Args:
        stats_bgzipped_vcfs: Fixture-provided list of bgzipped VCF file paths.

    Returns:
        List[str]: One sample name per bgzipped input file.
    """
    assert len(stats_bgzipped_vcfs) == 8
    # Fix: the original nested double comprehension was an obfuscated way to
    # take the first "."-separated component of each basename.
    return [os.path.basename(f).split(".")[0] for f in stats_bgzipped_vcfs]


@pytest.fixture
def stats_v3_dataset(tmp_path, stats_bgzipped_vcfs):
    """Create and return a v3 dataset with variant stats and allele counting enabled.

    All 8 stats samples are ingested before the dataset is returned in read mode.

    Args:
        tmp_path: Pytest-provided temporary directory for this test.
        stats_bgzipped_vcfs: Fixture-provided list of bgzipped VCF file paths to ingest.

    Returns:
        tiledbvcf.Dataset: Read-mode dataset with variant_stats_version=3,
            enable_variant_stats=True, and enable_allele_count=True.
    """
    assert len(stats_bgzipped_vcfs) == 8
    ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="w")
    ds.create_dataset(
        enable_variant_stats=True, enable_allele_count=True, variant_stats_version=3
    )
    ds.ingest_samples(stats_bgzipped_vcfs)
    return tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="r")
import os

import pytest
import tiledb
import tiledbvcf

from .conftest import skip_if_incompatible, TESTS_INPUT_DIR


@pytest.mark.parametrize("compress", [True, False])
def test_sample_compression(tmp_path, compress):
    """Verify that compress_sample_dim controls whether the sample dimension has a Zstd filter."""
    # Create the dataset
    dataset_uri = os.path.join(tmp_path, "sample_compression")
    array_uri = os.path.join(dataset_uri, "data")
    ds = tiledbvcf.Dataset(dataset_uri, mode="w")
    ds.create_dataset(compress_sample_dim=compress)

    skip_if_incompatible(array_uri)

    # Check for the presence of the Zstd filter. Fix: loop variable renamed
    # from `filter` (shadowed the builtin) and the boolean-accumulator loop
    # replaced with any().
    with tiledb.open(array_uri) as A:
        found_zstd = any(
            "Zstd" in str(filt) for filt in A.domain.dim("sample").filters
        )

    assert found_zstd == compress


@pytest.mark.parametrize("level", [1, 4, 16, 22])
def test_compression_level(tmp_path, level):
    """Verify that compression_level sets the Zstd level on all attributes."""
    # Create the dataset
    dataset_uri = os.path.join(tmp_path, "compression_level")
    array_uri = os.path.join(dataset_uri, "data")
    ds = tiledbvcf.Dataset(dataset_uri, mode="w")
    ds.create_dataset(compression_level=level)

    skip_if_incompatible(array_uri)

    # Check for the expected compression level on every attribute's Zstd
    # filter. Fix: loop variable renamed from `filter` (shadowed the builtin).
    with tiledb.open(array_uri) as A:
        for i in range(A.schema.nattr):
            attr = A.schema.attr(i)
            for filt in attr.filters:
                if "Zstd" in str(filt):
                    assert filt.level == level
import os

import pytest
import tiledb
import tiledbvcf

from .conftest import TESTS_INPUT_DIR

# Pre-ingested v3 2-sample dataset used by every test in this module.
V3_URI = os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples")


@pytest.mark.parametrize("level", ["fatal", "error", "warn", "info", "debug", "trace"])
def test_config_logging_valid_levels(level):
    """Smoke Test: Verify all documented log levels are accepted."""
    tiledbvcf.config_logging(level)


def test_config_logging_invalid_level_raises():
    """Verify an unrecognized log level raises an exception."""
    with pytest.raises(Exception, match="Unsupported log level"):
        tiledbvcf.config_logging("verbose")


def test_config_logging_log_file(tmp_path):
    """Smoke Test: Verify a log_file path is accepted."""
    tiledbvcf.config_logging("fatal", log_file=str(tmp_path / "tiledbvcf.log"))


def test_read_config():
    """Verify that ReadConfig parameters are accepted and that invalid parameters raise."""
    cfg = tiledbvcf.ReadConfig()
    tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg)

    cfg = tiledbvcf.ReadConfig(
        memory_budget_mb=512,
        region_partition=(0, 3),
        tiledb_config=["sm.tile_cache_size=0", "sm.compute_concurrency_level=1"],
    )
    tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg)

    with pytest.raises(TypeError):
        tiledbvcf.ReadConfig(abc=123)

    # Expect an exception when passing both cfg and tiledb_config.
    # Fix: only the Dataset construction belongs inside pytest.raises; the
    # original also wrapped the ReadConfig/dict setup, so an unrelated failure
    # in the setup would have falsely satisfied the raises check.
    cfg = tiledbvcf.ReadConfig()
    tiledb_config = {"foo": "bar"}
    with pytest.raises(Exception):
        tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg, tiledb_config=tiledb_config)


def test_read_limit():
    """Verify that ReadConfig limit truncates results to the specified number of rows."""
    cfg = tiledbvcf.ReadConfig(limit=3)
    ds = tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg)
    df = ds.read(
        attrs=["sample_name", "pos_start", "pos_end", "fmt_DP", "fmt_PL"],
        regions=["1:12100-13360", "1:13500-17350"],
    )
    assert len(df) == 3


def test_region_partitioned_read():
    """Verify that region_partition splits reads across partitions correctly."""
    cfg = tiledbvcf.ReadConfig(region_partition=(0, 2))
    ds = tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg)
    df = ds.read(
        attrs=["sample_name", "pos_start", "pos_end"],
        regions=["1:12000-13000", "1:17000-18000"],
    )
    assert len(df) == 4

    cfg = tiledbvcf.ReadConfig(region_partition=(1, 2))
    ds = tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg)
    df = ds.read(
        attrs=["sample_name", "pos_start", "pos_end"],
        regions=["1:12000-13000", "1:17000-18000"],
    )
    assert len(df) == 2

    # Too many partitions still produces results
    cfg = tiledbvcf.ReadConfig(region_partition=(1, 3))
    ds = tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg)
    df = ds.read(
        attrs=["sample_name", "pos_start", "pos_end"],
        regions=["1:12000-13000", "1:17000-18000"],
    )
    assert len(df) == 2

    # Error: index >= num partitions
    cfg = tiledbvcf.ReadConfig(region_partition=(2, 2))
    with pytest.raises(RuntimeError):
        tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg)


def test_sample_partitioned_read():
    """Verify that sample_partition splits reads by sample correctly."""
    cfg = tiledbvcf.ReadConfig(sample_partition=(0, 2))
    ds = tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg)
    df = ds.read(
        attrs=["sample_name", "pos_start", "pos_end"], regions=["1:12000-18000"]
    )
    assert len(df) == 11
    assert (df.sample_name == "HG00280").all()

    cfg = tiledbvcf.ReadConfig(sample_partition=(1, 2))
    ds = tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg)
    df = ds.read(
        attrs=["sample_name", "pos_start", "pos_end"], regions=["1:12000-18000"]
    )
    assert len(df) == 3
    assert (df.sample_name == "HG01762").all()

    # Error: too many partitions
    cfg = tiledbvcf.ReadConfig(sample_partition=(1, 3))
    ds = tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg)
    with pytest.raises(RuntimeError):
        ds.read(
            attrs=["sample_name", "pos_start", "pos_end"], regions=["1:12000-18000"]
        )

    # Error: index >= num partitions
    cfg = tiledbvcf.ReadConfig(sample_partition=(2, 2))
    with pytest.raises(RuntimeError):
        tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg)


def test_sample_and_region_partitioned_read():
    """Verify combined sample and region partitioning."""
    # (region partition index, sample partition index, expected sample)
    cfg = tiledbvcf.ReadConfig(region_partition=(0, 2), sample_partition=(0, 2))
    ds = tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg)
    df = ds.read(
        attrs=["sample_name", "pos_start", "pos_end"],
        regions=["1:12000-13000", "1:17000-18000"],
    )
    assert len(df) == 2
    assert (df.sample_name == "HG00280").all()

    cfg = tiledbvcf.ReadConfig(region_partition=(0, 2), sample_partition=(1, 2))
    ds = tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg)
    df = ds.read(
        attrs=["sample_name", "pos_start", "pos_end"],
        regions=["1:12000-13000", "1:17000-18000"],
    )
    assert len(df) == 2
    assert (df.sample_name == "HG01762").all()

    cfg = tiledbvcf.ReadConfig(region_partition=(1, 2), sample_partition=(0, 2))
    ds = tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg)
    df = ds.read(
        attrs=["sample_name", "pos_start", "pos_end"],
        regions=["1:12000-13000", "1:17000-18000"],
    )
    assert len(df) == 2
    assert (df.sample_name == "HG00280").all()

    cfg = tiledbvcf.ReadConfig(region_partition=(1, 2), sample_partition=(1, 2))
    ds = tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg)
    df = ds.read(
        attrs=["sample_name", "pos_start", "pos_end"],
        regions=["1:12000-13000", "1:17000-18000"],
    )
    assert len(df) == 0


def test_sort_regions():
    """Verify disabling region sorting returns the same records as sorted."""
    regions = ["1:17000-18000", "1:12000-13000"]  # intentionally out of order

    cfg_sorted = tiledbvcf.ReadConfig(sort_regions=True)
    ds = tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg_sorted)
    df_sorted = ds.read(attrs=["sample_name", "pos_start"], regions=regions)

    cfg_unsorted = tiledbvcf.ReadConfig(sort_regions=False)
    ds = tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg_unsorted)
    df_unsorted = ds.read(attrs=["sample_name", "pos_start"], regions=regions)

    assert len(df_sorted) == len(df_unsorted)
    # Fix: the docstring promises "same records", but the original only
    # compared lengths. Compare the record values order-insensitively too.
    key = ["sample_name", "pos_start"]
    assert sorted(map(tuple, df_sorted[key].values.tolist())) == sorted(
        map(tuple, df_unsorted[key].values.tolist())
    )


def test_buffer_percentage_and_tile_cache_percentage():
    """Smoke Test: Verify non-default buffer and tile cache percentages are accepted."""
    cfg = tiledbvcf.ReadConfig(buffer_percentage=30, tiledb_tile_cache_percentage=5)
    ds = tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg)
    df = ds.read(attrs=["sample_name", "pos_start"], regions=["1:12000-13000"])
    assert len(df) > 0


def test_tiledb_config_as_dict():
    """Smoke Test: Verify tiledb_config accepts a dict."""
    cfg = tiledbvcf.ReadConfig(
        tiledb_config={"sm.tile_cache_size": "0", "sm.compute_concurrency_level": "1"}
    )
    ds = tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg)
    df = ds.read(attrs=["sample_name", "pos_start"], regions=["1:12000-13000"])
    assert len(df) > 0


def test_tiledb_config_as_tiledb_config_object():
    """Smoke Test: Verify tiledb_config accepts a tiledb.Config object."""
    tiledb_cfg = tiledb.Config({"sm.tile_cache_size": "0"})
    cfg = tiledbvcf.ReadConfig(tiledb_config=tiledb_cfg)
    ds = tiledbvcf.Dataset(V3_URI, mode="r", cfg=cfg)
    df = ds.read(attrs=["sample_name", "pos_start"], regions=["1:12000-13000"])
    assert len(df) > 0


@pytest.mark.skipif(os.environ.get("CI") != "true", reason="CI only")
def test_large_export_correctness():
    """Verify large export from S3 produces the expected total and unique record counts."""
    uri = "s3://tiledb-inc-demo-data/tiledbvcf-arrays/v4/vcf-samples-20"

    ds = tiledbvcf.Dataset(uri)
    df = ds.read(
        attrs=[
            "sample_name",
            "contig",
            "pos_start",
            "pos_end",
            "query_bed_start",
            "query_bed_end",
        ],
        samples=["v2-DjrIAzkP", "v2-YMaDHIoW", "v2-usVwJUmo", "v2-ZVudhauk"],
        bed_file=os.path.join(
            TESTS_INPUT_DIR, "E001_15_coreMarks_dense_filtered.bed.gz"
        ),
    )

    # total number of exported records
    assert df.shape[0] == 1172081

    # number of unique exported records
    record_index = ["sample_name", "contig", "pos_start"]
    assert df[record_index].drop_duplicates().shape[0] == 1168430
import json
import os

import pytest
import tiledbvcf

from .conftest import TESTS_INPUT_DIR

# Pre-ingested v3 2-sample dataset referenced by many tests in this module.
V3_URI = os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples")


def test_invalid_mode_raises():
    """Verify an unrecognized mode string raises at construction time."""
    with pytest.raises(Exception, match="Unsupported dataset mode"):
        tiledbvcf.Dataset(V3_URI, mode="x")


def test_version(v3_dataset, v4_dataset):
    """Verify version() reports TileDB-VCF, TileDB, and htslib versions."""
    for ds in (v3_dataset, v4_dataset):
        version_string = ds.version()
        for component in ("TileDB-VCF version", "TileDB version", "htslib version"):
            assert component in version_string


def test_schema_version(v3_dataset, v4_dataset):
    """Verify schema_version() returns the correct version for v3 and v4 datasets."""
    assert v3_dataset.schema_version() == 3
    assert v4_dataset.schema_version() == 4


def test_basic_count(v3_dataset):
    """Verify count() returns the expected total record count."""
    assert v3_dataset.count() == 14


def test_retrieve_attributes(v3_dataset):
    """Verify attributes() returns the correct builtin, info, and fmt attribute lists."""
    builtin_attrs = [
        "sample_name",
        "contig",
        "pos_start",
        "pos_end",
        "alleles",
        "id",
        "fmt",
        "info",
        "filters",
        "qual",
        "query_bed_end",
        "query_bed_start",
        "query_bed_line",
    ]
    # Builtin attribute order is not guaranteed, so compare sorted.
    assert sorted(v3_dataset.attributes(attr_type="builtin")) == sorted(builtin_attrs)

    info_attrs = [
        "info_BaseQRankSum",
        "info_ClippingRankSum",
        "info_DP",
        "info_DS",
        "info_END",
        "info_HaplotypeScore",
        "info_InbreedingCoeff",
        "info_MLEAC",
        "info_MLEAF",
        "info_MQ",
        "info_MQ0",
        "info_MQRankSum",
        "info_ReadPosRankSum",
    ]
    assert v3_dataset.attributes(attr_type="info") == info_attrs

    fmt_attrs = [
        "fmt_AD",
        "fmt_DP",
        "fmt_GQ",
        "fmt_GT",
        "fmt_MIN_DP",
        "fmt_PL",
        "fmt_SB",
    ]
    assert v3_dataset.attributes(attr_type="fmt") == fmt_attrs


def test_retrieve_attributes_invalid_type_raises(v3_dataset):
    """Verify attributes() raises for an unrecognized attr_type."""
    with pytest.raises(TypeError):
        v3_dataset.attributes(attr_type="unknown")


def test_retrieve_samples(v3_dataset):
    """Verify samples() returns the expected sample names."""
    assert v3_dataset.samples() == ["HG00280", "HG01762"]


def test_sample_count(v3_dataset, v4_dataset):
    """Verify sample_count() is consistent with len(samples())."""
    for ds in (v3_dataset, v4_dataset):
        assert ds.sample_count() == 2
        assert ds.sample_count() == len(ds.samples())


def test_sample_count_write_mode_raises(tmp_path):
    """Verify sample_count() raises in write mode."""
    ds = tiledbvcf.Dataset(os.path.join(tmp_path, "dataset"), mode="w")
    ds.create_dataset()
    with pytest.raises(Exception, match="Samples can only be retrieved for reader"):
        ds.sample_count()


def test_multiple_counts(v3_dataset):
    """Verify count() with various region and sample filters returns correct counts."""
    # Repeated unfiltered counts interleaved with filtered ones exercise
    # back-to-back queries on the same open dataset.
    assert v3_dataset.count() == 14
    assert v3_dataset.count() == 14
    assert v3_dataset.count(regions=["1:12700-13400"]) == 6
    assert v3_dataset.count(samples=["HG00280"], regions=["1:12700-13400"]) == 4
    assert v3_dataset.count() == 14
    assert v3_dataset.count(samples=["HG01762"]) == 3
    assert v3_dataset.count(samples=["HG00280"]) == 11


def test_empty_region(v3_dataset):
    """Verify count() returns 0 for a region with no data."""
    assert v3_dataset.count(regions=["12:1-1000000"]) == 0


def test_missing_sample_raises_exception(v3_dataset):
    """Verify count() raises RuntimeError for a nonexistent sample name."""
    with pytest.raises(RuntimeError):
        v3_dataset.count(samples=["abcde"])


# TODO remove skip
@pytest.mark.skip
def test_bad_contig_raises_exception(v3_dataset):
    """Verify count() raises RuntimeError for invalid contig or region formats."""
    for bad_region in ("chr1:1-1000000", "1", "1:100-", "1:-100"):
        with pytest.raises(RuntimeError):
            v3_dataset.count(regions=[bad_region])


def test_read_write_mode_exceptions():
    """Verify that read operations fail in write mode and write operations fail in read mode."""
    reader = tiledbvcf.Dataset(V3_URI)
    sample_paths = [
        os.path.join(TESTS_INPUT_DIR, name) for name in ("small.bcf", "small2.bcf")
    ]

    # Write operations must fail on a read-mode dataset.
    with pytest.raises(Exception):
        reader.create_dataset()
    with pytest.raises(Exception):
        reader.ingest_samples(sample_paths)

    # Read operations must fail on a write-mode dataset.
    writer = tiledbvcf.Dataset(V3_URI, mode="w")
    with pytest.raises(Exception):
        writer.count()


def test_context_manager():
    """Verify that Dataset works as a context manager and raises after close()."""
    ds1_uri = os.path.join(TESTS_INPUT_DIR, "arrays/v4/ingested_2samples")
    expected_count1 = 14
    ds2_uri = os.path.join(TESTS_INPUT_DIR, "arrays/v3/synth-array")
    expected_count2 = 19565

    # Test the context manager
    with tiledbvcf.Dataset(ds1_uri) as ds:
        assert ds.count() == expected_count1
    with tiledbvcf.Dataset(ds2_uri) as ds:
        assert ds.count() == expected_count2

    # Open the datasets outside the context manager
    ds1 = tiledbvcf.Dataset(ds1_uri)
    assert ds1.count() == expected_count1
    ds2 = tiledbvcf.Dataset(ds2_uri)
    assert ds2.count() == expected_count2

    # A closed dataset must raise on access; the other stays usable.
    ds1.close()
    with pytest.raises(Exception):
        assert ds1.count() == expected_count1

    assert ds2.count() == expected_count2

    ds2.close()
    with pytest.raises(Exception):
        assert ds2.count() == expected_count2


# get_tiledb_stats_enabled is referenced without () in the guard condition, so it
# always evaluates to the method object (truthy) and the check never fires.
# Once that bug is fixed, this test should pass and the skip can be removed.
@pytest.mark.skip(reason="bug: get_tiledb_stats_enabled called without () so the guard never raises")
def test_tiledb_stats_raises_when_not_enabled():
    """Verify tiledb_stats() raises when stats were not enabled at open time."""
    ds = tiledbvcf.Dataset(V3_URI, mode="r")  # stats=False by default
    ds.count()
    with pytest.raises(Exception, match="TileDB read stats not enabled"):
        ds.tiledb_stats()


def test_deprecated_tiledbvcfdataset_warns(v3_dataset):
    """Verify the deprecated TileDBVCFDataset constructor emits a DeprecationWarning."""
    with pytest.warns(DeprecationWarning, match="TileDBVCFDataset is deprecated"):
        tiledbvcf.TileDBVCFDataset(V3_URI, mode="r")


def test_deprecated_tiledbvcfdataset_is_functional(v3_dataset):
    """Verify the deprecated TileDBVCFDataset is still functional."""
    with pytest.warns(DeprecationWarning):
        legacy_ds = tiledbvcf.TileDBVCFDataset(V3_URI, mode="r")
    assert legacy_ds.count() == 14


def test_tiledb_stats_read_mode(v3_dataset):
    """Verify tiledb_stats() returns valid JSON after a read operation."""
    v3_dataset.count()
    stats = v3_dataset.tiledb_stats()
    assert len(stats) > 0
    json.loads(stats)  # raises if not valid JSON


def test_tiledb_stats_write_mode(tmp_path):
    """Verify tiledb_stats() returns valid JSON after an ingest operation."""
    ds = tiledbvcf.Dataset(os.path.join(tmp_path, "dataset"), mode="w", stats=True)
    ds.create_dataset()
    ds.ingest_samples([os.path.join(TESTS_INPUT_DIR, "small.bcf")])
    stats = ds.tiledb_stats()
    assert len(stats) > 0
    json.loads(stats)  # raises if not valid JSON
import os

import pytest
import tiledb
import tiledbvcf

from .conftest import skip_if_no_bcftools, TESTS_INPUT_DIR


def test_delete_dataset(tmp_path):
    """Verify that Dataset.delete() removes the dataset from disk."""
    dataset_uri = os.path.join(tmp_path, "delete_dataset")

    with tiledbvcf.Dataset(dataset_uri, mode="w") as ds:
        ds.create_dataset()

    assert os.path.exists(dataset_uri)
    tiledbvcf.Dataset.delete(dataset_uri)
    assert not os.path.exists(dataset_uri)


def test_delete_dataset_with_config(tmp_path):
    """Smoke Test: Verify Dataset.delete() accepts a config parameter."""
    dataset_uri = os.path.join(tmp_path, "delete_dataset")

    with tiledbvcf.Dataset(dataset_uri, mode="w") as ds:
        ds.create_dataset()

    assert os.path.exists(dataset_uri)
    tiledbvcf.Dataset.delete(dataset_uri, config={"sm.tile_cache_size": "0"})
    assert not os.path.exists(dataset_uri)


def test_delete_dataset_nonexistent_uri_raises(tmp_path):
    """Verify deleting a nonexistent URI raises TileDBError."""
    with pytest.raises(tiledb.TileDBError):
        tiledbvcf.Dataset.delete(os.path.join(tmp_path, "nonexistent"))


@skip_if_no_bcftools
def test_delete_samples(tmp_path, stats_v3_dataset, stats_sample_names):
    """Verify that delete_samples() removes the specified samples from the dataset."""
    # Confirm the victims and a survivor are present before deleting.
    for name in ("second", "fifth", "third"):
        assert name in stats_sample_names

    stats_uri = os.path.join(tmp_path, "stats_test")
    writer = tiledbvcf.Dataset(uri=stats_uri, mode="w")
    writer.delete_samples(["second", "fifth"])

    reader = tiledbvcf.Dataset(uri=stats_uri, mode="r")
    remaining = reader.samples()
    assert "second" not in remaining
    assert "fifth" not in remaining
    assert "third" in remaining


def test_delete_samples_empty_list_is_noop(tmp_path):
    """Verify delete_samples with an empty list is a no-op."""
    dataset_uri = os.path.join(tmp_path, "dataset")
    writer = tiledbvcf.Dataset(dataset_uri, mode="w")
    writer.create_dataset()
    writer.ingest_samples(
        [os.path.join(TESTS_INPUT_DIR, name) for name in ("small.bcf", "small2.bcf")]
    )

    writer = tiledbvcf.Dataset(dataset_uri, mode="w")
    writer.delete_samples([])

    reader = tiledbvcf.Dataset(dataset_uri, mode="r")
    assert set(reader.samples()) == {"HG00280", "HG01762"}


def test_delete_samples_none_raises(tmp_path):
    """Verify delete_samples(None) raises TypeError."""
    writer = tiledbvcf.Dataset(os.path.join(tmp_path, "dataset"), mode="w")
    writer.create_dataset()

    with pytest.raises(TypeError):
        writer.delete_samples(None)


def test_delete_samples_nonexistent_raises(tmp_path):
    """Verify deleting a nonexistent sample raises RuntimeError."""
    dataset_uri = os.path.join(tmp_path, "dataset")
    writer = tiledbvcf.Dataset(dataset_uri, mode="w")
    writer.create_dataset()
    writer.ingest_samples([os.path.join(TESTS_INPUT_DIR, "small.bcf")])

    writer = tiledbvcf.Dataset(dataset_uri, mode="w")
    with pytest.raises(RuntimeError, match="Sample not found in dataset"):
        writer.delete_samples(["NONEXISTENT"])


def test_delete_samples_read_mode_raises(tmp_path):
    """Verify delete_samples() raises in read mode."""
    dataset_uri = os.path.join(tmp_path, "dataset")
    writer = tiledbvcf.Dataset(dataset_uri, mode="w")
    writer.create_dataset()
    writer.ingest_samples([os.path.join(TESTS_INPUT_DIR, "small.bcf")])

    reader = tiledbvcf.Dataset(dataset_uri, mode="r")
    with pytest.raises(Exception, match="Dataset not open in write mode"):
        reader.delete_samples(["HG00280"])
import os

import pytest
import tiledbvcf

from .conftest import TESTS_INPUT_DIR


def test_export_default(tmp_path, v4_dataset):
    """Verify default export produces one compressed VCF per sample."""
    v4_dataset.export(output_dir=str(tmp_path))
    assert set(os.listdir(tmp_path)) == {"HG00280.vcf.gz", "HG01762.vcf.gz"}


def test_export_samples_filter(tmp_path, v4_dataset):
    """Verify export can be filtered to specific samples."""
    v4_dataset.export(samples=["HG00280"], output_dir=str(tmp_path))
    assert os.listdir(tmp_path) == ["HG00280.vcf.gz"]


def test_export_regions_filter(tmp_path, v4_dataset):
    """Verify export can be filtered to a specific genomic region."""
    v4_dataset.export(regions=["1:12000-13000"], output_dir=str(tmp_path))
    assert set(os.listdir(tmp_path)) == {"HG00280.vcf.gz", "HG01762.vcf.gz"}


@pytest.mark.parametrize(
    "output_format, expected_files",
    [
        ("z", {"HG00280.vcf.gz", "HG01762.vcf.gz"}),
        ("v", {"HG00280.vcf", "HG01762.vcf"}),
        ("b", {"HG00280.bcf", "HG01762.bcf"}),
        ("u", {"HG00280.bcf", "HG01762.bcf"}),
    ],
)
def test_export_output_format(tmp_path, output_format, expected_files):
    """Verify each output format produces files with the correct extension."""
    dataset = tiledbvcf.Dataset(
        os.path.join(TESTS_INPUT_DIR, "arrays/v4/ingested_2samples"), mode="r"
    )
    dataset.export(output_format=output_format, output_dir=str(tmp_path))
    assert set(os.listdir(tmp_path)) == expected_files


def test_export_merge(tmp_path, v4_dataset):
    """Verify merged export produces a single combined output file."""
    merged_path = str(tmp_path / "merged.vcf.gz")
    v4_dataset.export(merge=True, output_path=merged_path, output_dir=str(tmp_path))
    assert os.path.exists(merged_path)
    # The merged file is the only artifact in the output directory.
    assert os.listdir(tmp_path) == ["merged.vcf.gz"]


def test_export_merge_without_output_path_raises(tmp_path, v4_dataset):
    """Verify merged export requires an output_path."""
    with pytest.raises(Exception, match="output_path required when merge=True"):
        v4_dataset.export(merge=True, output_dir=str(tmp_path))


def test_export_samples_file(tmp_path, v4_dataset):
    """Verify export can be filtered by a samples file."""
    samples_path = str(tmp_path / "samples.txt")
    out_dir = str(tmp_path / "out")
    os.makedirs(out_dir)
    with open(samples_path, "w") as f:
        f.write("HG00280\n")
    v4_dataset.export(samples_file=samples_path, output_dir=out_dir)
    assert os.listdir(out_dir) == ["HG00280.vcf.gz"]


def test_export_bed_file(tmp_path, v4_dataset):
    """Verify export can be filtered by a BED file."""
    bed_path = str(tmp_path / "regions.bed")
    out_dir = str(tmp_path / "out")
    os.makedirs(out_dir)
    with open(bed_path, "w") as f:
        f.write("1\t12000\t13000\n")
    v4_dataset.export(bed_file=bed_path, output_dir=out_dir)
    assert set(os.listdir(out_dir)) == {"HG00280.vcf.gz", "HG01762.vcf.gz"}


def test_export_skip_check_samples(tmp_path, v4_dataset):
    """Verify skipping sample existence checks silently produces no output for unknown samples."""
    v4_dataset.export(
        samples=["NOSUCHSAMPLE"], skip_check_samples=True, output_dir=str(tmp_path)
    )
    assert os.listdir(tmp_path) == []


def test_export_write_mode_raises(tmp_path):
    """Verify export raises when the dataset is open in write mode."""
    writer = tiledbvcf.Dataset(str(tmp_path / "dataset"), mode="w")
    writer.create_dataset()
    with pytest.raises(Exception, match="Dataset not open in read mode"):
        writer.export(output_dir=str(tmp_path))
+ tests = [ + {"region": "chr1:100-120", "samples": ["s0", "s1", "s2"]}, + {"region": "chr1:110-120", "samples": ["s0", "s1"]}, + {"region": "chr1:149-149", "samples": ["s0", "s1", "s3"]}, + {"region": "chr1:150-150", "samples": ["s0", "s1", "s3", "s4"]}, + ] + + # No IAF filtering or reporting + for test in tests: + df = ds.read(regions=test["region"]) + assert set(df["sample_name"].unique()) == set(test["samples"]) + + attrs = [ + "sample_name", + "contig", + "pos_start", + "alleles", + "fmt_GT", + "info_TILEDB_IAF", + ] + + # IAF reporting + for test in tests: + df = ds.read(attrs=attrs, regions=test["region"]) + assert set(df["sample_name"].unique()) == set(test["samples"]) + + # IAF filtering and reporting + for test in tests: + df = ds.read(attrs=attrs, regions=test["region"], set_af_filter="<=1.0") + assert set(df["sample_name"].unique()) == set(test["samples"]) + + +def test_flag_export(tmp_path): + """Verify that INFO flag attributes (DB, DS) are read correctly from an ingested VCF.""" + # Create the dataset + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + samples = [os.path.join(TESTS_INPUT_DIR, s) for s in ["small.vcf.gz"]] + ds.create_dataset() + ds.ingest_samples(samples) + + # Read info flags + ds = tiledbvcf.Dataset(uri, mode="r") + df = ds.read(attrs=["pos_start", "info_DB", "info_DS"]) + df = df.sort_values(by=["pos_start"]) + + # Check if flags match the expected values + expected_db = [1, 1, 1, 0, 0, 1] + assert df["info_DB"].tolist() == expected_db + + expected_ds = [1, 1, 0, 0, 1, 1] + assert df["info_DS"].tolist() == expected_ds + + +@pytest.mark.parametrize("use_arrow", [False, True], ids=["pandas", "arrow"]) +def test_bed_filestore(tmp_path, v4_dataset, use_arrow): + """Verify reading with a BED file stored as a TileDB Filestore.""" + # Expected DataFrame + expected_df = pd.DataFrame( + { + "sample_name": pd.Series( + [ + "HG00280", + "HG01762", + "HG00280", + "HG01762", + "HG00280", + ] + ), + "pos_start": 
pd.Series( + [ + 12141, + 12141, + 12546, + 12546, + 17319, + ], + dtype=np.int32, + ), + "pos_end": pd.Series( + [ + 12277, + 12277, + 12771, + 12771, + 17479, + ], + dtype=np.int32, + ), + } + ).sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + + # Create BED file + bed_file = os.path.join(tmp_path, "test.bed") + + regions = [ + (1, 12000, 13000), + (1, 17000, 17479), + ] + + with open(bed_file, "w") as f: + for region in regions: + f.write(f"{region[0]}\t{region[1]}\t{region[2]}\n") + + # Create BED filestore from BED file + bed_filestore = os.path.join(tmp_path, "test.bed.filestore") + tiledb.Array.create(bed_filestore, tiledb.ArraySchema.from_file(bed_file)) + tiledb.Filestore.copy_from(bed_filestore, bed_file) + + func = v4_dataset.read_arrow if use_arrow else v4_dataset.read + df = func(attrs=["sample_name", "pos_start", "pos_end"], bed_file=bed_filestore) + if use_arrow: + df = df.to_pandas() + assert_dfs_equal( + expected_df, + df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]), + ) + + +@pytest.mark.parametrize("use_arrow", [False, True], ids=["pandas", "arrow"]) +def test_bed_array(tmp_path, v4_dataset, use_arrow): + """Verify reading with a BED file stored as a TileDB sparse array with metadata aliases.""" + expected_df = pd.DataFrame( + { + "sample_name": pd.Series( + [ + "HG00280", + "HG01762", + "HG00280", + "HG01762", + "HG00280", + ] + ), + "pos_start": pd.Series( + [ + 12141, + 12141, + 12546, + 12546, + 17319, + ], + dtype=np.int32, + ), + "pos_end": pd.Series( + [ + 12277, + 12277, + 12771, + 12771, + 17479, + ], + dtype=np.int32, + ), + } + ).sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + + # Create bed array + bed_array = os.path.join(tmp_path, "bed_array") + tiledb.from_pandas( + bed_array, + pd.DataFrame( + { + "chrom": ["1", "1"], + "chromStart": [12000, 17000], + "chromEnd": [13000, 17479], + } + ), + sparse=True, + index_col=["chrom", "chromStart"], + ) + + # Add aliases to the array 
metadata + with tiledb.Array(bed_array, "w") as A: + A.meta["alias contig"] = "chrom" + A.meta["alias start"] = "chromStart" + A.meta["alias end"] = "chromEnd" + + func = v4_dataset.read_arrow if use_arrow else v4_dataset.read + df = func(attrs=["sample_name", "pos_start", "pos_end"], bed_file=bed_array) + if use_arrow: + df = df.to_pandas() + + assert_dfs_equal( + expected_df, + df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]), + ) + +def test_info_end(tmp_path): + """Verify info_END is handled correctly even when the VCF header defines END as a string.""" + + expected_end = pd.DataFrame( + { + "pos_end": pd.Series( + [ + 12277, + 12771, + 13374, + 13395, + 13413, + 13451, + 13519, + 13544, + 13689, + 17479, + 17486, + 30553, + 35224, + 35531, + 35786, + 69096, + 69103, + 69104, + 69109, + 69110, + 69111, + 69112, + 69114, + 69115, + 69122, + 69123, + 69128, + 69129, + 69130, + 69192, + 69195, + 69196, + 69215, + 69222, + 69227, + 69228, + 69261, + 69262, + 69269, + 69270, + 69346, + 69349, + 69352, + 69353, + 69370, + 69510, + 69511, + 69760, + 69761, + 69770, + 69834, + 69835, + 69838, + 69861, + 69863, + 69866, + 69896, + 69897, + 69912, + 69938, + 69939, + 69941, + 69946, + 69947, + 69948, + 69949, + 69953, + 70012, + 866511, + 1289369, + ], + dtype=np.int32, + ), + # Expected values are strings because the small3.vcf.gz defines END as a string + "info_END": pd.Series( + [ + "12277", + "12771", + "13374", + "13395", + "13413", + "13451", + "13519", + "13544", + "13689", + "17479", + "17486", + "30553", + "35224", + "35531", + "35786", + "69096", + "69103", + "69104", + "69109", + "69110", + "69111", + "69112", + "69114", + "69115", + "69122", + "69123", + "69128", + "69129", + "69130", + "69192", + "69195", + "69196", + "69215", + "69222", + "69227", + "69228", + "69261", + "69262", + "69269", + None, + "69346", + "69349", + "69352", + "69353", + "69370", + "69510", + None, + "69760", + None, + "69770", + "69834", + "69835", + "69838", + 
"69861", + "69863", + "69866", + "69896", + None, + "69912", + "69938", + "69939", + "69941", + "69946", + "69947", + "69948", + "69949", + "69953", + "70012", + None, + None, + ], + dtype=object, + ), + } + ) + + # Ingest the data + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + samples = [os.path.join(TESTS_INPUT_DIR, s) for s in ["small3.vcf.gz"]] + ds.create_dataset() + ds.ingest_samples(samples) + + # Read the data + ds = tiledbvcf.Dataset(uri) + df = ds.read(attrs=["sample_name", "pos_start", "pos_end", "info_END"]) + + # Sort the results because VCF uses an unordered reader + df.sort_values(ignore_index=True, by=["sample_name", "pos_start"], inplace=True) + + # Drop the columns that are not used for comparison + df.drop(columns=["sample_name", "pos_start"], inplace=True) + + # Check the results + assert_dfs_equal(df, expected_end) + +def test_equality_old_new_format(): + """Verify that old and new format arrays produce identical counts, samples, and reads.""" + old_ds = tiledbvcf.Dataset(os.path.join(TESTS_INPUT_DIR, "arrays/old_format")) + new_ds = tiledbvcf.Dataset(os.path.join(TESTS_INPUT_DIR, "arrays/new_format")) + + assert old_ds.count() == new_ds.count() + assert old_ds.samples() == new_ds.samples() + assert old_ds.read().equals(new_ds.read()) diff --git a/apis/python/tests/test_ingest.py b/apis/python/tests/test_ingest.py new file mode 100644 index 000000000..8a9c62dc9 --- /dev/null +++ b/apis/python/tests/test_ingest.py @@ -0,0 +1,754 @@ +import os +import platform +import shutil + +import numpy as np +import pandas as pd +import pytest +import tiledb +import tiledbvcf + +from .conftest import ( + assert_dfs_equal, + skip_if_incompatible, + skip_if_no_bcftools, + TESTS_INPUT_DIR, +) + +def test_basic_ingest(tmp_path): + """Verify basic two-sample BCF ingestion and query counts.""" + # Create the dataset + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + samples = 
[os.path.join(TESTS_INPUT_DIR, s) for s in ["small.bcf", "small2.bcf"]] + ds.create_dataset() + ds.ingest_samples(samples) + + # Open it back in read mode and check some queries + ds = tiledbvcf.Dataset(uri, mode="r") + assert ds.count() == 14 + assert ds.count(regions=["1:12700-13400"]) == 6 + assert ds.count(samples=["HG00280"], regions=["1:12700-13400"]) == 4 + + +def test_disable_ingestion_tasks(tmp_path): + """Verify that disabling stats tasks prevents creation of stats arrays.""" + # Create the dataset + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + samples = [os.path.join(TESTS_INPUT_DIR, s) for s in ["small.bcf", "small3.bcf"]] + ds.create_dataset( + enable_allele_count=False, enable_variant_stats=False, enable_sample_stats=False + ) + ds.ingest_samples(samples) + + # TODO: remove this workaround when sc-19721 is resolved + if platform.system() != "Linux": + return + + # Validate that stats arrays were not created + ac_uri = os.path.join(tmp_path, "dataset", "allele_count") + vs_uri = os.path.join(tmp_path, "dataset", "variant_stats") + ss_uri = os.path.join(tmp_path, "dataset", "sample_stats") + + assert not os.path.exists(ac_uri) + assert not os.path.exists(vs_uri) + assert not os.path.exists(ss_uri) + + +def test_ingestion_tasks(tmp_path): + """Verify that allele_count, variant_stats, and sample_stats arrays are created and populated.""" + # Create the dataset + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + samples = [os.path.join(TESTS_INPUT_DIR, s) for s in ["small.bcf", "small3.bcf"]] + ds.create_dataset(enable_allele_count=True, enable_variant_stats=True) + ds.ingest_samples(samples) + + # TODO: remove this workaround when sc-19721 is resolved + if platform.system() != "Linux": + return + + # Query allele_count array with TileDB + ac_uri = tiledb.Group(uri)["allele_count"].uri + + skip_if_incompatible(ac_uri) + + contig = "1" + region = slice(69896) + with tiledb.open(ac_uri) as 
A: + df = A.query(attrs=["alt", "count"], dims=["pos"]).df[contig, region] + + assert df["pos"].array == 69896 + assert df["alt"].array == "C" + assert df["count"].array == 1 + + # Query variant_stats array with TileDB + vs_uri = tiledb.Group(uri)["variant_stats"].uri + + contig = "1" + region = slice(12140) + with tiledb.open(vs_uri) as A: + df = A.query(attrs=["allele", "ac"], dims=["pos"]).df[contig, region] + + assert df["pos"].array == 12140 + assert df["allele"].array == "C" + assert df["ac"].array == 4 + + # Test raw sample_stats + + expected_df = pd.DataFrame( + { + "sample": ["HG00280", "HG01762"], + "dp_sum": [879, 64], + "dp_sum2": [56375, 4096], + "dp_count": [68, 2], + "dp_min": [0, 0], + "dp_max": [180, 64], + "gq_sum": [1489, 99], + "gq_sum2": [79129, 9801], + "gq_count": [68, 2], + "gq_min": [0, 0], + "gq_max": [99, 99], + "n_records": [70, 3], + "n_called": [70, 3], + "n_not_called": [0, 0], + "n_hom_ref": [64, 3], + "n_het": [3, 0], + "n_singleton": [4, 0], + "n_snp": [7, 0], + "n_insertion": [2, 0], + "n_deletion": [1, 0], + "n_transition": [6, 0], + "n_transversion": [1, 0], + "n_star": [0, 0], + "n_multiallelic": [5, 0], + } + ).astype("uint64", errors="ignore") + + ss_uri = tiledb.Group(uri)["sample_stats"].uri + with tiledb.open(ss_uri) as A: + df = A.df[:] + + # Convert to uint64 for comparison to expected_df + df = df.astype("uint64", errors="ignore") + + assert df.equals(expected_df) + + # Test sample_qc + expected_qc = pd.DataFrame( + { + "sample": ["HG00280", "HG01762"], + "dp_mean": [12.92647, 32.0], + "dp_stddev": [25.728399, 32.0], + "dp_min": [0, 0], + "dp_max": [180, 64], + "gq_mean": [21.897058, 49.5], + "gq_stddev": [26.156845, 49.5], + "gq_min": [0, 0], + "gq_max": [99, 99], + "call_rate": [1.0, 1.0], + "n_called": [70, 3], + "n_not_called": [0, 0], + "n_hom_ref": [64, 3], + "n_het": [3, 0], + "n_hom_var": [3, 0], + "n_non_ref": [6, 0], + "n_singleton": [4, 0], + "n_snp": [7, 0], + "n_insertion": [2, 0], + "n_deletion": [1, 0], + 
"n_transition": [6, 0], + "n_transversion": [1, 0], + "n_star": [0, 0], + "r_ti_tv": [6.0, np.nan], + "r_het_hom_var": [1.0, np.nan], + "r_insertion_deletion": [2.0, np.nan], + "n_records": [70, 3], + "n_multiallelic": [5, 0], + } + ) + + qc = tiledbvcf.sample_qc(uri) + assert_dfs_equal(expected_qc, qc) + + +def test_incremental_ingest(tmp_path): + """Verify that samples can be ingested incrementally with the same result.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + ds.create_dataset() + ds.ingest_samples([os.path.join(TESTS_INPUT_DIR, "small.bcf")]) + ds.ingest_samples([os.path.join(TESTS_INPUT_DIR, "small2.bcf")]) + + # Open it back in read mode and check some queries + ds = tiledbvcf.Dataset(uri, mode="r") + assert ds.count() == 14 + assert ds.count(regions=["1:12700-13400"]) == 6 + assert ds.count(samples=["HG00280"], regions=["1:12700-13400"]) == 4 + + +def test_ingest_disable_merging(tmp_path): + """Verify contig_fragment_merging=False produces identical results to contigs_to_keep_separate.""" + # Create the dataset + uri = os.path.join(tmp_path, "dataset_disable_merging") + + cfg = tiledbvcf.ReadConfig(memory_budget_mb=1024) + attrs = ["sample_name", "contig", "pos_start", "pos_end"] + + ds = tiledbvcf.Dataset(uri, mode="w") + samples = [ + os.path.join(TESTS_INPUT_DIR, s) for s in ["v2-DjrIAzkP-downsampled.vcf.gz"] + ] + ds.create_dataset() + ds.ingest_samples(samples, contig_fragment_merging=False) + + # Open it back in read mode and check some queries + ds = tiledbvcf.Dataset(uri, cfg=cfg, mode="r", verbose=False) + df = ds.read(attrs=attrs) + assert ds.count() == 246 + assert ds.count(regions=["chrX:9032893-9032893"]) == 1 + + # Create the dataset + uri = os.path.join(tmp_path, "dataset_merging_separate") + ds2 = tiledbvcf.Dataset(uri, mode="w", verbose=False) + samples = [ + os.path.join(TESTS_INPUT_DIR, s) for s in ["v2-DjrIAzkP-downsampled.vcf.gz"] + ] + ds2.create_dataset() + ds2.ingest_samples(samples, 
contigs_to_keep_separate=["chr1"]) + + # Open it back in read mode and check some queries + ds2 = tiledbvcf.Dataset(uri, cfg=cfg, mode="r", verbose=False) + df2 = ds2.read(attrs=attrs) + assert df.equals(df2) + + assert ds.count() == 246 + assert ds.count(regions=["chrX:9032893-9032893"]) == 1 + + +def test_ingest_merging_separate(tmp_path): + """Verify ingestion with contigs_to_keep_separate produces correct counts.""" + # Create the dataset + uri = os.path.join(tmp_path, "dataset_merging_separate") + ds = tiledbvcf.Dataset(uri, mode="w") + samples = [ + os.path.join(TESTS_INPUT_DIR, s) for s in ["v2-DjrIAzkP-downsampled.vcf.gz"] + ] + ds.create_dataset() + ds.ingest_samples(samples, contigs_to_keep_separate=["chr1"]) + + # Open it back in read mode and check some queries + ds = tiledbvcf.Dataset(uri, mode="r") + assert ds.count() == 246 + assert ds.count(regions=["chrX:9032893-9032893"]) == 1 + + +def test_ingest_merging(tmp_path): + """Verify ingestion with contigs_to_allow_merging produces correct counts.""" + # Create the dataset + uri = os.path.join(tmp_path, "dataset_merging") + ds = tiledbvcf.Dataset(uri, mode="w") + samples = [ + os.path.join(TESTS_INPUT_DIR, s) for s in ["v2-DjrIAzkP-downsampled.vcf.gz"] + ] + ds.create_dataset() + ds.ingest_samples(samples, contigs_to_allow_merging=["chr1", "chr2"]) + + # Open it back in read mode and check some queries + ds = tiledbvcf.Dataset(uri, mode="r") + assert ds.count() == 246 + assert ds.count(regions=["chrX:9032893-9032893"]) == 1 + + +def test_ingest_mode_merged(tmp_path): + """Verify contig_mode='merged' ingests only pseudo-contigs.""" + # Create the dataset + uri = os.path.join(tmp_path, "dataset_merging") + ds = tiledbvcf.Dataset(uri, mode="w") + samples = [ + os.path.join(TESTS_INPUT_DIR, s) for s in ["v2-DjrIAzkP-downsampled.vcf.gz"] + ] + ds.create_dataset() + # Ingest only merged contigs (pseudo-contigs) + ds.ingest_samples(samples, contig_mode="merged") + + # Open it back in read mode and check some 
queries + ds = tiledbvcf.Dataset(uri, mode="r") + assert ds.count() == 19 + assert ds.count(regions=["chrX:9032893-9032893"]) == 0 + + +@skip_if_no_bcftools +def test_ingest_with_stats_v2(tmp_path, bgzip_and_index_vcfs): + """Verify ingestion with v2 stats, AF filtering, scan_all_samples, and allele counts.""" + shutil.copytree( + os.path.join(TESTS_INPUT_DIR, "stats"), os.path.join(tmp_path, "stats") + ) + bgzipped_inputs = bgzip_and_index_vcfs(os.path.join(tmp_path, "stats")) + ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="w") + ds.create_dataset(enable_variant_stats=True, enable_allele_count=True) + ds.ingest_samples(bgzipped_inputs) + ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="r") + sample_names = [os.path.basename(file).split(".")[0] for file in bgzipped_inputs] + data_frame = ds.read( + samples=sample_names, + attrs=["contig", "pos_start", "id", "qual", "info_TILEDB_IAF", "sample_name"], + set_af_filter="<0.2", + ) + assert data_frame.shape == (1, 8) + assert data_frame.query("sample_name == 'second'")["qual"].iloc[0] == pytest.approx( + 343.73 + ) + assert ( + data_frame[data_frame["sample_name"] == "second"]["info_TILEDB_IAF"].iloc[0][0] + == 0.9375 + ) + data_frame = ds.read( + samples=sample_names, + attrs=["contig", "pos_start", "id", "qual", "info_TILEDB_IAF", "sample_name"], + scan_all_samples=True, + ) + assert ( + data_frame[ + (data_frame["sample_name"] == "second") & (data_frame["pos_start"] == 4) + ]["info_TILEDB_IAF"].iloc[0][0] + == 0.9375 + ) + ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="r") + df = ds.read_variant_stats(regions=["chr1:1-10000"]) + assert df.shape == (13, 6) + # read_allele_frequency internally uses the deprecated `region` parameter. 
+ with pytest.warns(DeprecationWarning, match='"region" parameter is deprecated'): + df = tiledbvcf.allele_frequency.read_allele_frequency( + os.path.join(tmp_path, "stats_test"), "chr1:1-10000" + ) + assert df.pos.is_monotonic_increasing + df["an_check"] = (df.ac / df.af).round(0).astype("int32") + assert df.an_check.equals(df.an) + df = ds.read_variant_stats(regions=["chr1:1-10000"]) + assert df.shape == (13, 6) + df = ds.read_allele_count(regions=["chr1:1-10000"]) + assert df.shape == (7, 7) + assert sum(df["pos"] == (0, 1, 1, 2, 2, 2, 3)) == 7 + assert sum(df["count"] == (8, 5, 3, 4, 2, 2, 1)) == 7 + + +@skip_if_no_bcftools +def test_ingest_polyploid(tmp_path, bgzip_and_index_vcfs): + """Smoke Test: Verify ingestion and AF filtering on polyploid VCF data.""" + shutil.copytree( + os.path.join(TESTS_INPUT_DIR, "polyploid"), os.path.join(tmp_path, "polyploid") + ) + bgzipped_inputs = bgzip_and_index_vcfs(os.path.join(tmp_path, "polyploid")) + ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "polyploid_test"), mode="w") + ds.create_dataset(enable_variant_stats=True) + ds.ingest_samples(bgzipped_inputs) + ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "polyploid_test"), mode="r") + sample_names = [os.path.basename(file).split(".")[0] for file in bgzipped_inputs] + data_frame = ds.read( + samples=sample_names, + attrs=["contig", "pos_start", "id", "qual", "info_TILEDB_IAF", "sample_name"], + set_af_filter="<0.8", + ) + + +def test_ingest_mode_separate(tmp_path): + """Verify contig_mode='separate' ingests only non-merged contigs.""" + # Create the dataset + uri = os.path.join(tmp_path, "dataset_merging") + ds = tiledbvcf.Dataset(uri, mode="w") + samples = [ + os.path.join(TESTS_INPUT_DIR, s) for s in ["v2-DjrIAzkP-downsampled.vcf.gz"] + ] + ds.create_dataset() + # Ingest only separate contigs (exclude merged pseudo-contigs) + ds.ingest_samples( + samples, contigs_to_keep_separate=["chr1"], contig_mode="separate" + ) + + # Open it back in read mode and check some queries + ds = 
tiledbvcf.Dataset(uri, mode="r") + assert ds.count() == 17 + assert ds.count(regions=["chrX:9032893-9032893"]) == 0 + + +def test_vcf_attrs(tmp_path): + """Verify create_dataset with vcf_attrs populates queryable attributes from a VCF header.""" + # Create the dataset with vcf info and fmt attributes + uri = os.path.join(tmp_path, "vcf_attrs_dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + vcf_uri = os.path.join(TESTS_INPUT_DIR, "v2-DjrIAzkP-downsampled.vcf.gz") + ds.create_dataset(vcf_attrs=vcf_uri) + + # Open it back in read mode and check attributes + ds = tiledbvcf.Dataset(uri, mode="r") + + queryable_attrs = [ + "alleles", + "contig", + "filters", + "fmt", + "fmt_DP", + "fmt_GQ", + "fmt_GT", + "fmt_MIN_DP", + "fmt_PS", + "fmt_SB", + "fmt_STR_MAX_LEN", + "fmt_STR_PERIOD", + "fmt_STR_TIMES", + "fmt_VAR_CONTEXT", + "fmt_VAR_TYPE", + "id", + "info", + "info_AC", + "info_AC_AFR", + "info_AC_AMR", + "info_AC_Adj", + "info_AC_CONSANGUINEOUS", + "info_AC_EAS", + "info_AC_FEMALE", + "info_AC_FIN", + "info_AC_Hemi", + "info_AC_Het", + "info_AC_Hom", + "info_AC_MALE", + "info_AC_NFE", + "info_AC_OTH", + "info_AC_POPMAX", + "info_AC_SAS", + "info_AF", + "info_AF_AFR", + "info_AF_AMR", + "info_AF_Adj", + "info_AF_EAS", + "info_AF_FIN", + "info_AF_NFE", + "info_AF_OTH", + "info_AF_SAS", + "info_AGE_HISTOGRAM_HET", + "info_AGE_HISTOGRAM_HOM", + "info_AN", + "info_AN_AFR", + "info_AN_AMR", + "info_AN_Adj", + "info_AN_CONSANGUINEOUS", + "info_AN_EAS", + "info_AN_FEMALE", + "info_AN_FIN", + "info_AN_MALE", + "info_AN_NFE", + "info_AN_OTH", + "info_AN_POPMAX", + "info_AN_SAS", + "info_BaseQRankSum", + "info_CCC", + "info_CSQ", + "info_ClippingRankSum", + "info_DB", + "info_DOUBLETON_DIST", + "info_DP", + "info_DP_HIST", + "info_DS", + "info_END", + "info_ESP_AC", + "info_ESP_AF_GLOBAL", + "info_ESP_AF_POPMAX", + "info_FS", + "info_GQ_HIST", + "info_GQ_MEAN", + "info_GQ_STDDEV", + "info_HWP", + "info_HaplotypeScore", + "info_Hemi_AFR", + "info_Hemi_AMR", + "info_Hemi_EAS", + 
"info_Hemi_FIN", + "info_Hemi_NFE", + "info_Hemi_OTH", + "info_Hemi_SAS", + "info_Het_AFR", + "info_Het_AMR", + "info_Het_EAS", + "info_Het_FIN", + "info_Het_NFE", + "info_Het_OTH", + "info_Het_SAS", + "info_Hom_AFR", + "info_Hom_AMR", + "info_Hom_CONSANGUINEOUS", + "info_Hom_EAS", + "info_Hom_FIN", + "info_Hom_NFE", + "info_Hom_OTH", + "info_Hom_SAS", + "info_InbreedingCoeff", + "info_K1_RUN", + "info_K2_RUN", + "info_K3_RUN", + "info_KG_AC", + "info_KG_AF_GLOBAL", + "info_KG_AF_POPMAX", + "info_MLEAC", + "info_MLEAF", + "info_MQ", + "info_MQ0", + "info_MQRankSum", + "info_NCC", + "info_NEGATIVE_TRAIN_SITE", + "info_OLD_VARIANT", + "info_POPMAX", + "info_POSITIVE_TRAIN_SITE", + "info_QD", + "info_ReadPosRankSum", + "info_VQSLOD", + "info_clinvar_conflicted", + "info_clinvar_measureset_id", + "info_clinvar_mut", + "info_clinvar_pathogenic", + "info_culprit", + "pos_end", + "pos_start", + "qual", + "query_bed_end", + "query_bed_line", + "query_bed_start", + "sample_name", + ] + + assert ds.attributes(attr_type="info") == [] + assert ds.attributes(attr_type="fmt") == [] + assert sorted(ds.attributes()) == sorted(queryable_attrs) + + +def test_create_dataset_extra_attrs(tmp_path): + """Verify extra_attrs adds the specified fmt fields as queryable attributes.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + ds.create_dataset(extra_attrs=["fmt_GT", "fmt_DP"]) + ds.ingest_samples([os.path.join(TESTS_INPUT_DIR, "small.bcf")]) + + ds = tiledbvcf.Dataset(uri, mode="r") + attrs = ds.attributes() + assert "fmt_GT" in attrs + assert "fmt_DP" in attrs + + +def test_create_dataset_extra_attrs_and_vcf_attrs_raises(tmp_path): + """Verify extra_attrs and vcf_attrs cannot be combined.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + vcf_uri = os.path.join(TESTS_INPUT_DIR, "v2-DjrIAzkP-downsampled.vcf.gz") + with pytest.raises(Exception, match="Cannot provide both extra_attrs and vcf_attrs"): + 
ds.create_dataset(extra_attrs=["fmt_GT"], vcf_attrs=vcf_uri) + + +def test_create_dataset_invalid_checksum_type_raises(tmp_path): + """Verify an unrecognized checksum_type raises before creating the dataset.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + with pytest.raises(Exception, match="Invalid checksum_type"): + ds.create_dataset(checksum_type="crc32") + + +def test_create_dataset_checksum_md5(tmp_path): + """Smoke Test: Verify checksum_type='md5' creates a functional dataset.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + ds.create_dataset(checksum_type="md5") + ds.ingest_samples([os.path.join(TESTS_INPUT_DIR, "small.bcf")]) + + ds = tiledbvcf.Dataset(uri, mode="r") + assert ds.count() == 3 + + +def test_create_dataset_already_exists_raises(tmp_path): + """Verify create_dataset raises when the dataset already exists.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + ds.create_dataset() + + ds2 = tiledbvcf.Dataset(uri, mode="w") + with pytest.raises(Exception): + ds2.create_dataset() + + +def test_create_dataset_tile_capacity(tmp_path): + """Smoke Test: Verify a custom tile_capacity creates a functional dataset.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + ds.create_dataset(tile_capacity=100) + ds.ingest_samples([os.path.join(TESTS_INPUT_DIR, "small.bcf")]) + + ds = tiledbvcf.Dataset(uri, mode="r") + assert ds.count() == 3 + + +def test_create_dataset_anchor_gap(tmp_path): + """Smoke Test: Verify a custom anchor_gap creates a functional dataset.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + ds.create_dataset(anchor_gap=500) + ds.ingest_samples([os.path.join(TESTS_INPUT_DIR, "small.bcf")]) + + ds = tiledbvcf.Dataset(uri, mode="r") + assert ds.count() == 3 + + +def test_create_dataset_allow_duplicates_false(tmp_path): + """Smoke Test: Verify allow_duplicates=False 
creates a functional dataset.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + ds.create_dataset(allow_duplicates=False) + ds.ingest_samples([os.path.join(TESTS_INPUT_DIR, "small.bcf")]) + + ds = tiledbvcf.Dataset(uri, mode="r") + assert ds.count() == 3 + + +def test_create_dataset_variant_stats_version2(tmp_path): + """Smoke Test: Verify variant_stats_version=2 creates a functional dataset with readable stats.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + ds.create_dataset(enable_variant_stats=True, variant_stats_version=2) + ds.ingest_samples( + [os.path.join(TESTS_INPUT_DIR, s) for s in ["small.bcf", "small3.bcf"]] + ) + + ds = tiledbvcf.Dataset(uri, mode="r") + df = ds.read_variant_stats(regions=["1:1-200000"]) + assert len(df) > 0 + + +def test_ingest_samples_none_is_noop(tmp_path): + """Verify ingest_samples with no samples is a no-op.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + ds.create_dataset() + ds.ingest_samples() + ds.ingest_samples(sample_uris=None) + + ds = tiledbvcf.Dataset(uri, mode="r") + assert ds.count() == 0 + + +def test_ingest_samples_scratch_space_path_only_raises(tmp_path): + """Verify scratch_space_path requires scratch_space_size.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + ds.create_dataset() + with pytest.raises(Exception, match="Must set both scratch_space_path and scratch_space_size"): + ds.ingest_samples( + [os.path.join(TESTS_INPUT_DIR, "small.bcf")], + scratch_space_path=str(tmp_path), + ) + + +def test_ingest_samples_scratch_space_size_only_raises(tmp_path): + """Verify scratch_space_size requires scratch_space_path.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + ds.create_dataset() + with pytest.raises(Exception, match="Must set both scratch_space_path and scratch_space_size"): + ds.ingest_samples( + 
[os.path.join(TESTS_INPUT_DIR, "small.bcf")], + scratch_space_size=1024, + ) + + +def test_ingest_samples_invalid_contig_mode_raises(tmp_path): + """Verify an unrecognized contig_mode raises before ingestion.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + ds.create_dataset() + with pytest.raises(Exception, match="contig_mode must be"): + ds.ingest_samples( + [os.path.join(TESTS_INPUT_DIR, "small.bcf")], + contig_mode="invalid", + ) + + +def test_ingest_samples_contigs_to_keep_separate_not_list_raises(tmp_path): + """Verify contigs_to_keep_separate must be a list.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + ds.create_dataset() + with pytest.raises(Exception, match="contigs_to_keep_separate must be a list"): + ds.ingest_samples( + [os.path.join(TESTS_INPUT_DIR, "small.bcf")], + contigs_to_keep_separate="1", + ) + + +def test_ingest_samples_contigs_to_allow_merging_not_list_raises(tmp_path): + """Verify contigs_to_allow_merging must be a list.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + ds.create_dataset() + with pytest.raises(Exception, match="contigs_to_allow_merging must be a list"): + ds.ingest_samples( + [os.path.join(TESTS_INPUT_DIR, "small.bcf")], + contigs_to_allow_merging="1", + ) + + +def test_ingest_samples_resume(tmp_path): + """Smoke Test: Verify resume=True produces the same result as a normal ingest.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + ds.create_dataset() + ds.ingest_samples( + [os.path.join(TESTS_INPUT_DIR, "small.bcf")], + resume=True, + ) + + ds = tiledbvcf.Dataset(uri, mode="r") + assert ds.count() == 3 + + +def test_ingest_samples_sample_batch_size(tmp_path): + """Verify sample_batch_size controls the number of ingestion fragments.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + samples = [os.path.join(TESTS_INPUT_DIR, s) for s in 
["small.bcf", "small2.bcf"]] + ds.create_dataset() + ds.ingest_samples(samples, sample_batch_size=1) + + data_uri = tiledb.Group(uri)["data"].uri + assert len(tiledb.array_fragments(data_uri)) == 2 + + +def test_ingest_samples_memory_and_thread_params(tmp_path): + """Smoke Test: Verify memory and thread tuning parameters are accepted.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + ds.create_dataset() + ds.ingest_samples( + [os.path.join(TESTS_INPUT_DIR, "small.bcf")], + threads=2, + total_memory_budget_mb=512, + ratio_tiledb_memory=0.5, + max_tiledb_memory_mb=256, + input_record_buffer_mb=2, + avg_vcf_record_size=512, + ratio_task_size=0.5, + ratio_output_flush=0.5, + ) + + ds = tiledbvcf.Dataset(uri, mode="r") + assert ds.count() == 3 + + +def test_ingest_samples_total_memory_percentage(tmp_path): + """Smoke Test: Verify total_memory_percentage is accepted.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + ds.create_dataset() + ds.ingest_samples( + [os.path.join(TESTS_INPUT_DIR, "small.bcf")], + total_memory_percentage=0.5, + ) + + ds = tiledbvcf.Dataset(uri, mode="r") + assert ds.count() == 3 diff --git a/apis/python/tests/test_read.py b/apis/python/tests/test_read.py new file mode 100644 index 000000000..b59b2f3f2 --- /dev/null +++ b/apis/python/tests/test_read.py @@ -0,0 +1,806 @@ +import os + +import numpy as np +import pandas as pd +import pyarrow as pa +import pytest +import tiledbvcf + +from .conftest import assert_dfs_equal, skip_if_incompatible, TESTS_INPUT_DIR + +def test_read_unsupported_regions_type(v3_dataset): + """Verify that unsupported or wrong-dimension regions types raise appropriate errors.""" + unsupported_region = 3.14 + unsupported_type_error = f'"regions" parameter cannot have type: {type(unsupported_region)}' + wrong_dimension_region = np.array([["1:12700-13400"], ["1:12700-13400"]]) + ndarray_wrong_dimension_error = f'"regions" parameter of type 
{type(wrong_dimension_region)} must be 1-dimensional' + with pytest.raises(Exception, match=unsupported_type_error): + v3_dataset.read(regions=unsupported_region) + with pytest.raises(Exception, match=ndarray_wrong_dimension_error): + v3_dataset.read(regions=wrong_dimension_region) + with pytest.raises(Exception, match=unsupported_type_error): + v3_dataset.read_arrow(regions=unsupported_region) + with pytest.raises(Exception, match=ndarray_wrong_dimension_error): + v3_dataset.read_arrow(regions=wrong_dimension_region) + with pytest.raises(Exception, match=unsupported_type_error): + for variant in v3_dataset.read_iter(regions=unsupported_region): + print(variant) + with pytest.raises(Exception, match=ndarray_wrong_dimension_error): + for variant in v3_dataset.read_iter(regions=wrong_dimension_region): + print(variant) + + +def test_read_attrs(v3_dataset_with_attrs): + """Verify that read() returns only the requested attributes as columns.""" + attrs = ["sample_name"] + df = v3_dataset_with_attrs.read(attrs=attrs) + assert df.columns.values.tolist() == attrs + + attrs = ["sample_name", "fmt_GT"] + df = v3_dataset_with_attrs.read(attrs=attrs) + assert df.columns.values.tolist() == attrs + + attrs = ["sample_name"] + df = v3_dataset_with_attrs.read(attrs=attrs) + assert df.columns.values.tolist() == attrs + + +@pytest.mark.parametrize("use_arrow", [False, True], ids=["pandas", "arrow"]) +def test_basic_reads(v3_dataset, use_arrow): + """Verify basic reads with region, sample, and format filters via both pandas and Arrow.""" + expected_df = pd.DataFrame( + { + "sample_name": pd.Series( + [ + "HG00280", + "HG01762", + "HG00280", + "HG01762", + "HG00280", + "HG01762", + "HG00280", + "HG00280", + "HG00280", + "HG00280", + "HG00280", + "HG00280", + "HG00280", + "HG00280", + ] + ), + "pos_start": pd.Series( + [ + 12141, + 12141, + 12546, + 12546, + 13354, + 13354, + 13375, + 13396, + 13414, + 13452, + 13520, + 13545, + 17319, + 17480, + ], + dtype=np.int32, + ), + "pos_end": 
pd.Series( + [ + 12277, + 12277, + 12771, + 12771, + 13374, + 13389, + 13395, + 13413, + 13451, + 13519, + 13544, + 13689, + 17479, + 17486, + ], + dtype=np.int32, + ), + } + ).sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + + func = v3_dataset.read_arrow if use_arrow else v3_dataset.read + df = func(attrs=["sample_name", "pos_start", "pos_end"]) + if use_arrow: + df = df.to_pandas() + assert_dfs_equal( + expected_df, + df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]), + ) + + # Region intersection + df = v3_dataset.read( + attrs=["sample_name", "pos_start", "pos_end"], regions=["1:12700-13400"] + ) + expected_df = pd.DataFrame( + { + "sample_name": pd.Series( + ["HG00280", "HG01762", "HG00280", "HG01762", "HG00280", "HG00280"] + ), + "pos_start": pd.Series( + [12546, 12546, 13354, 13354, 13375, 13396], dtype=np.int32 + ), + "pos_end": pd.Series( + [12771, 12771, 13374, 13389, 13395, 13413], dtype=np.int32 + ), + } + ).sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + assert_dfs_equal( + expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + ) + df = v3_dataset.read_arrow( + attrs=["sample_name", "pos_start", "pos_end"], regions=["1:12700-13400"] + ).to_pandas() + assert_dfs_equal( + expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + ) + + # Regions as string + df = v3_dataset.read( + attrs=["sample_name", "pos_start", "pos_end"], regions="1:12700-13400" + ) + assert_dfs_equal( + expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + ) + df = v3_dataset.read_arrow( + attrs=["sample_name", "pos_start", "pos_end"], regions="1:12700-13400" + ).to_pandas() + assert_dfs_equal( + expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + ) + + # Regions as numpy.ndarray + df = v3_dataset.read( + attrs=["sample_name", "pos_start", "pos_end"], regions=np.array(["1:12700-13400"]) + ) + assert_dfs_equal( + 
expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + ) + df = v3_dataset.read_arrow( + attrs=["sample_name", "pos_start", "pos_end"], regions=np.array(["1:12700-13400"]) + ).to_pandas() + assert_dfs_equal( + expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + ) + + # Region and sample intersection + df = v3_dataset.read( + attrs=["sample_name", "pos_start", "pos_end"], + regions=["1:12700-13400"], + samples=["HG01762"], + ) + expected_df = pd.DataFrame( + { + "sample_name": pd.Series(["HG01762", "HG01762"]), + "pos_start": pd.Series([12546, 13354], dtype=np.int32), + "pos_end": pd.Series([12771, 13389], dtype=np.int32), + } + ).sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + assert_dfs_equal( + expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + ) + + # Sample only + df = v3_dataset.read( + attrs=["sample_name", "pos_start", "pos_end"], samples=["HG01762"] + ) + expected_df = pd.DataFrame( + { + "sample_name": pd.Series(["HG01762", "HG01762", "HG01762"]), + "pos_start": pd.Series([12141, 12546, 13354], dtype=np.int32), + "pos_end": pd.Series([12277, 12771, 13389], dtype=np.int32), + } + ).sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + assert_dfs_equal( + expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + ) + + +def test_bad_attr_raises_exception(v3_dataset): + """Verify that read() raises RuntimeError for an unknown attribute name.""" + with pytest.raises(RuntimeError): + v3_dataset.read(attrs=["abcde"], regions=["1:12700-13400"]) + + +def test_incomplete_reads(): + """Verify incomplete reads with low memory budget and continue_read for pandas and Arrow.""" + # Using undocumented "0 MB" budget to test incomplete reads. 
+ uri = os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples") + cfg = tiledbvcf.ReadConfig(memory_budget_mb=0) + v3_dataset = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) + + df = v3_dataset.read(attrs=["pos_end"], regions=["1:12700-13400"]) + assert not v3_dataset.read_completed() + assert len(df) == 2 + assert_dfs_equal( + pd.DataFrame.from_dict({"pos_end": np.array([12771, 12771], dtype=np.int32)}), + df, + ) + + df = v3_dataset.continue_read() + assert not v3_dataset.read_completed() + assert len(df) == 2 + assert_dfs_equal( + pd.DataFrame.from_dict({"pos_end": np.array([13374, 13389], dtype=np.int32)}), + df, + ) + + df = v3_dataset.continue_read() + assert v3_dataset.read_completed() + assert len(df) == 2 + assert_dfs_equal( + pd.DataFrame.from_dict({"pos_end": np.array([13395, 13413], dtype=np.int32)}), + df, + ) + + # test incomplete via read_arrow + table = v3_dataset.read_arrow(attrs=["pos_end"], regions=["1:12700-13400"]) + assert not v3_dataset.read_completed() + assert len(table) == 2 + assert_dfs_equal( + pd.DataFrame.from_dict({"pos_end": np.array([12771, 12771], dtype=np.int32)}), + table.to_pandas(), + ) + + table = v3_dataset.continue_read_arrow() + assert not v3_dataset.read_completed() + assert len(table) == 2 + assert_dfs_equal( + pd.DataFrame.from_dict({"pos_end": np.array([13374, 13389], dtype=np.int32)}), + table.to_pandas(), + ) + + table = v3_dataset.continue_read_arrow() + assert v3_dataset.read_completed() + assert len(table) == 2 + assert_dfs_equal( + pd.DataFrame.from_dict({"pos_end": np.array([13395, 13413], dtype=np.int32)}), + table.to_pandas(), + ) + + +def test_continue_read_release_buffers_false(): + """Verify continue_read(release_buffers=False) accumulates previous batches.""" + # Using undocumented "0 MB" budget to force batched reads. 
+ uri = os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples") + cfg = tiledbvcf.ReadConfig(memory_budget_mb=0) + ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) + + df = ds.read(attrs=["pos_end"], regions=["1:12700-13400"]) + assert not ds.read_completed() + assert list(df["pos_end"]) == [12771, 12771] + + # With release_buffers=False the previous buffer is not cleared. + # The result contains the unreleased batch alongside the new batch + # as two columns both named "pos_end". + df = ds.continue_read(release_buffers=False) + assert not ds.read_completed() + assert df.columns.tolist() == ["pos_end", "pos_end"] + assert list(df.iloc[:, 0]) == [12771, 12771] # previous batch (unreleased) + assert list(df.iloc[:, 1]) == [13374, 13389] # new batch + + df = ds.continue_read(release_buffers=False) + assert ds.read_completed() + # Both previous unreleased batches accumulate alongside the new one. + assert df.columns.tolist() == ["pos_end", "pos_end", "pos_end"] + assert list(df.iloc[:, 0]) == [12771, 12771] # batch 1 (still unreleased) + assert list(df.iloc[:, 1]) == [13374, 13389] # batch 2 (unreleased) + assert list(df.iloc[:, 2]) == [13395, 13413] # new batch + + +def test_continue_read_arrow_release_buffers_false(): + """Verify continue_read_arrow(release_buffers=False) accumulates previous batches.""" + # Using undocumented "0 MB" budget to force batched reads. + uri = os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples") + cfg = tiledbvcf.ReadConfig(memory_budget_mb=0) + ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) + + table = ds.read_arrow(attrs=["pos_end"], regions=["1:12700-13400"]) + assert not ds.read_completed() + assert table.column("pos_end").to_pylist() == [12771, 12771] + + # With release_buffers=False the previous buffer is not cleared. + # The result contains the unreleased batch alongside the new batch + # as two columns both named "pos_end". 
+ table = ds.continue_read_arrow(release_buffers=False) + assert not ds.read_completed() + assert table.schema.names == ["pos_end", "pos_end"] + assert table.column(0).to_pylist() == [12771, 12771] # previous batch (unreleased) + assert table.column(1).to_pylist() == [13374, 13389] # new batch + + table = ds.continue_read_arrow(release_buffers=False) + assert ds.read_completed() + # Both previous unreleased batches accumulate alongside the new one. + assert table.schema.names == ["pos_end", "pos_end", "pos_end"] + assert table.column(0).to_pylist() == [12771, 12771] # batch 1 (still unreleased) + assert table.column(1).to_pylist() == [13374, 13389] # batch 2 (unreleased) + assert table.column(2).to_pylist() == [13395, 13413] # new batch + + +def test_incomplete_read_generator(): + """Verify read_iter() yields all batches across string, list, and ndarray region types.""" + # Using undocumented "0 MB" budget to test incomplete reads. + uri = os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples") + cfg = tiledbvcf.ReadConfig(memory_budget_mb=0) + v3_dataset = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) + expected_df = pd.DataFrame.from_dict( + { + "pos_end": np.array( + [12771, 12771, 13374, 13389, 13395, 13413], dtype=np.int32 + ) + } + ) + + # NOTE: Running multiple test shows that the iterator can be reused + + # Regions as string + dfs = [] + for df in v3_dataset.read_iter(attrs=["pos_end"], regions="1:12700-13400"): + dfs.append(df) + overall_df = pd.concat(dfs, ignore_index=True) + assert len(overall_df) == 6 + assert_dfs_equal(expected_df, overall_df) + + # Regions as list + dfs = [] + for df in v3_dataset.read_iter(attrs=["pos_end"], regions=["1:12700-13400"]): + dfs.append(df) + overall_df = pd.concat(dfs, ignore_index=True) + assert len(overall_df) == 6 + assert_dfs_equal(expected_df, overall_df) + + # Regions as numpy.ndarray + dfs = [] + for df in v3_dataset.read_iter(attrs=["pos_end"], regions=np.array(["1:12700-13400"])): + dfs.append(df) + 
overall_df = pd.concat(dfs, ignore_index=True) + assert len(overall_df) == 6 + assert_dfs_equal(expected_df, overall_df) + + +def test_read_iter_samples_file(tmp_path, v3_dataset): + """Verify read_iter can be filtered by a samples file.""" + samples_file = str(tmp_path / "samples.txt") + with open(samples_file, "w") as f: + f.write("HG00280\n") + + dfs = [] + for df in v3_dataset.read_iter(attrs=["sample_name"], samples_file=samples_file): + dfs.append(df) + result = pd.concat(dfs, ignore_index=True) + assert set(result["sample_name"]) == {"HG00280"} + + +def test_read_iter_bed_file(tmp_path, v3_dataset): + """Verify read_iter can be filtered by a BED file.""" + bed_file = str(tmp_path / "regions.bed") + with open(bed_file, "w") as f: + f.write("1\t12700\t13400\n") + + dfs = [] + for df in v3_dataset.read_iter(attrs=["pos_end"], bed_file=bed_file): + dfs.append(df) + result = pd.concat(dfs, ignore_index=True) + assert len(result) == 6 + + +def test_read_iter_samples(v3_dataset): + """Verify read_iter can be filtered to specific samples.""" + dfs = [] + for df in v3_dataset.read_iter(attrs=["sample_name"], samples=["HG01762"]): + dfs.append(df) + result = pd.concat(dfs, ignore_index=True) + assert set(result["sample_name"]) == {"HG01762"} + + +def test_read_filters(v3_dataset): + """Verify that the filters attribute is read correctly, including LowQual entries.""" + df = v3_dataset.read( + attrs=["sample_name", "pos_start", "pos_end", "filters"], + regions=["1:12700-13400"], + ) + expected_df = pd.DataFrame( + { + "sample_name": pd.Series( + ["HG00280", "HG01762", "HG00280", "HG01762", "HG00280", "HG00280"] + ), + "pos_start": pd.Series( + [12546, 12546, 13354, 13354, 13375, 13396], dtype=np.int32 + ), + "pos_end": pd.Series( + [12771, 12771, 13374, 13389, 13395, 13413], dtype=np.int32 + ), + "filters": pd.Series( + map( + lambda lst: np.array(lst, dtype=object), + [None, None, ["LowQual"], None, None, None], + ) + ), + } + ).sort_values(ignore_index=True, 
by=["sample_name", "pos_start"]) + assert_dfs_equal( + expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + ) + + +def test_read_var_length_filters(tmp_path): + """Verify reading variable-length filter arrays with multiple filter values per record.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + samples = [os.path.join(TESTS_INPUT_DIR, s) for s in ["varLenFilter.vcf.gz"]] + ds.create_dataset() + ds.ingest_samples(samples) + + ds = tiledbvcf.Dataset(uri, mode="r") + df = ds.read(["pos_start", "filters"]) + + expected_df = pd.DataFrame( + { + "pos_start": pd.Series( + [ + 12141, + 12546, + 13354, + 13375, + 13396, + 13414, + 13452, + 13520, + 13545, + 17319, + 17480, + ], + dtype=np.int32, + ), + "filters": pd.Series( + map( + lambda lst: np.array(lst, dtype=object), + [ + ["PASS"], + ["PASS"], + ["ANEUPLOID", "LowQual"], + ["PASS"], + ["PASS"], + ["ANEUPLOID", "LOWQ", "LowQual"], + ["PASS"], + ["PASS"], + ["PASS"], + ["LowQual"], + ["PASS"], + ], + ) + ), + } + ).sort_values(ignore_index=True, by=["pos_start"]) + + assert_dfs_equal(expected_df, df.sort_values(ignore_index=True, by=["pos_start"])) + + +def test_read_alleles(v3_dataset): + """Verify that the alleles attribute returns correct ref/alt arrays for each record.""" + df = v3_dataset.read( + attrs=["sample_name", "pos_start", "pos_end", "alleles"], + regions=["1:12100-13360", "1:13500-17350"], + ) + expected_df = pd.DataFrame( + { + "sample_name": pd.Series( + [ + "HG00280", + "HG01762", + "HG00280", + "HG01762", + "HG00280", + "HG01762", + "HG00280", + "HG00280", + "HG00280", + "HG00280", + ] + ), + "pos_start": pd.Series( + [12141, 12141, 12546, 12546, 13354, 13354, 13452, 13520, 13545, 17319], + dtype=np.int32, + ), + "pos_end": pd.Series( + [12277, 12277, 12771, 12771, 13374, 13389, 13519, 13544, 13689, 17479], + dtype=np.int32, + ), + "alleles": pd.Series( + map( + lambda lst: np.array(lst, dtype=object), + [ + ["C", ""], + ["C", ""], + 
["G", ""], + ["G", ""], + ["T", ""], + ["T", ""], + ["G", ""], + ["G", ""], + ["G", ""], + ["T", ""], + ], + ) + ), + } + ).sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + assert_dfs_equal( + expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + ) + + +def test_read_multiple_alleles(tmp_path): + """Verify reading records with multiple alternate alleles from a multi-sample dataset.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + samples = [os.path.join(TESTS_INPUT_DIR, s) for s in ["small3.bcf", "small.bcf"]] + ds.create_dataset() + ds.ingest_samples(samples) + + ds = tiledbvcf.Dataset(uri, mode="r") + df = ds.read( + attrs=["sample_name", "pos_start", "alleles", "id", "filters"], + regions=["1:70100-1300000"], + ) + expected_df = pd.DataFrame( + { + "sample_name": pd.Series(["HG00280", "HG00280"]), + "pos_start": pd.Series([866511, 1289367], dtype=np.int32), + "alleles": pd.Series( + map( + lambda lst: np.array(lst, dtype=object), + [["T", "CCCCTCCCT", "C", "CCCCTCCCTCCCT", "CCCCT"], ["CTG", "C"]], + ) + ), + "id": pd.Series([".", "rs1497816"]), + "filters": pd.Series( + map( + lambda lst: np.array(lst, dtype=object), + [["LowQual"], ["LowQual"]], + ) + ), + } + ).sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + assert_dfs_equal( + expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + ) + + +def test_read_var_len_attrs(v3_dataset): + """Verify reading variable-length format attributes (DP, PL) with region filtering.""" + df = v3_dataset.read( + attrs=["sample_name", "pos_start", "pos_end", "fmt_DP", "fmt_PL"], + regions=["1:12100-13360", "1:13500-17350"], + ) + expected_df = pd.DataFrame( + { + "sample_name": pd.Series( + [ + "HG00280", + "HG01762", + "HG00280", + "HG01762", + "HG00280", + "HG01762", + "HG00280", + "HG00280", + "HG00280", + "HG00280", + ] + ), + "pos_start": pd.Series( + [12141, 12141, 12546, 12546, 13354, 13354, 13452, 
13520, 13545, 17319], + dtype=np.int32, + ), + "pos_end": pd.Series( + [12277, 12277, 12771, 12771, 13374, 13389, 13519, 13544, 13689, 17479], + dtype=np.int32, + ), + "fmt_DP": pd.Series([0, 0, 0, 0, 15, 64, 10, 6, 0, 0], dtype=np.int32), + "fmt_PL": pd.Series( + map( + lambda lst: np.array(lst, dtype=np.int32), + [ + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 24, 360], + [0, 66, 990], + [0, 21, 210], + [0, 6, 90], + [0, 0, 0], + [0, 0, 0], + ], + ) + ), + } + ).sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + + assert_dfs_equal( + expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + ) + + +def test_sample_args(v3_dataset, tmp_path): + """Verify that samples= and samples_file= produce equivalent results and cannot be combined.""" + sample_file = os.path.join(tmp_path, "1_sample.txt") + with open(sample_file, "w") as file: + file.write("HG00280") + + region = ["1:12141-12141"] + df1 = v3_dataset.read(["sample_name"], regions=region, samples=["HG00280"]) + df2 = v3_dataset.read(["sample_name"], regions=region, samples_file=sample_file) + assert_dfs_equal(df1, df2) + + with pytest.raises(TypeError): + v3_dataset.read( + attrs=["sample_name"], + regions=region, + samples=["HG00280"], + samples_file=sample_file, + ) + + +def test_read_arrow_samples(v3_dataset): + """Verify read_arrow can be filtered to specific samples.""" + tbl = v3_dataset.read_arrow( + attrs=["sample_name", "pos_start", "pos_end"], + regions=["1:12700-13400"], + samples=["HG01762"], + ) + df = tbl.to_pandas() + assert set(df["sample_name"]) == {"HG01762"} + assert len(df) == 2 + + +def test_read_arrow_samples_file(tmp_path, v3_dataset): + """Verify read_arrow can be filtered by a samples file.""" + samples_file = str(tmp_path / "samples.txt") + with open(samples_file, "w") as f: + f.write("HG00280\n") + + tbl = v3_dataset.read_arrow(attrs=["sample_name"], samples_file=samples_file) + assert set(tbl.column("sample_name").to_pylist()) == 
{"HG00280"} + + +def test_read_bed_file(tmp_path, v3_dataset): + """Verify read and read_arrow can be filtered by a BED file.""" + bed_file = str(tmp_path / "regions.bed") + with open(bed_file, "w") as f: + f.write("1\t12700\t13400\n") + + df = v3_dataset.read(attrs=["pos_end"], bed_file=bed_file) + assert len(df) == 6 + + tbl = v3_dataset.read_arrow(attrs=["pos_end"], bed_file=bed_file) + assert tbl.num_rows == 6 + + +def test_read_null_attrs(tmp_path): + """Verify that nullable info and fmt attributes return None for missing values.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + samples = [os.path.join(TESTS_INPUT_DIR, s) for s in ["small3.bcf", "small.bcf"]] + ds.create_dataset() + ds.ingest_samples(samples) + + ds = tiledbvcf.Dataset(uri, mode="r") + df = ds.read( + attrs=[ + "sample_name", + "pos_start", + "pos_end", + "info_BaseQRankSum", + "info_DP", + "fmt_DP", + "fmt_MIN_DP", + ], + regions=["1:12700-13400", "1:69500-69800"], + ) + expected_df = pd.DataFrame( + { + "sample_name": pd.Series( + [ + "HG00280", + "HG00280", + "HG00280", + "HG00280", + "HG01762", + "HG01762", + "HG00280", + "HG00280", + "HG00280", + "HG00280", + "HG00280", + "HG00280", + ] + ), + "pos_start": pd.Series( + [ + 12546, + 13354, + 13375, + 13396, + 12546, + 13354, + 69371, + 69511, + 69512, + 69761, + 69762, + 69771, + ], + dtype=np.int32, + ), + "pos_end": pd.Series( + [ + 12771, + 13374, + 13395, + 13413, + 12771, + 13389, + 69510, + 69511, + 69760, + 69761, + 69770, + 69834, + ], + dtype=np.int32, + ), + "info_BaseQRankSum": pd.Series( + [ + None, + None, + None, + None, + None, + None, + None, + np.array([-0.787], dtype=np.float32), + None, + np.array([1.97], dtype=np.float32), + None, + None, + ] + ), + "info_DP": pd.Series( + [ + None, + None, + None, + None, + None, + None, + None, + np.array([89], dtype=np.int32), + None, + np.array([24], dtype=np.int32), + None, + None, + ] + ), + "fmt_DP": pd.Series( + [0, 15, 6, 2, 0, 64, 180, 88, 
97, 24, 23, 21], dtype=np.int32 + ), + "fmt_MIN_DP": pd.Series([0, 14, 3, 1, 0, 30, 20, None, 24, None, 23, 19]), + } + ).sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + assert_dfs_equal( + expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) + ) + + diff --git a/apis/python/tests/test_stats.py b/apis/python/tests/test_stats.py new file mode 100644 index 000000000..110d776be --- /dev/null +++ b/apis/python/tests/test_stats.py @@ -0,0 +1,405 @@ +import os +import platform +import shutil + +import pandas as pd +import pyarrow as pa +import pytest +import tiledbvcf + +from .conftest import skip_if_no_bcftools, TESTS_INPUT_DIR, assert_dfs_equal + +@skip_if_no_bcftools +def test_read_with_af_filter(stats_v3_dataset, stats_sample_names): + """Verify that set_af_filter restricts results by allele frequency for both pandas and Arrow.""" + attrs = ["contig", "pos_start", "id", "qual", "info_TILEDB_IAF", "sample_name"] + df = stats_v3_dataset.read( + samples=stats_sample_names, + attrs=attrs, + set_af_filter="<0.2", + ) + assert df.shape == (1, 8) + assert df.query("sample_name == 'second'")["qual"].iloc[0] == pytest.approx(343.73) + assert df[df["sample_name"] == "second"]["info_TILEDB_IAF"].iloc[0][0] == 0.9375 + + tbl = stats_v3_dataset.read_arrow( + samples=stats_sample_names, + attrs=attrs, + set_af_filter="<0.2", + ) + assert tbl.num_rows == 1 + assert tbl.to_pandas().equals(df) + + +@skip_if_no_bcftools +def test_read_with_scan_all_samples(stats_v3_dataset, stats_sample_names): + """Verify scan_all_samples normalizes IAF across all samples for both pandas and Arrow.""" + attrs = ["contig", "pos_start", "id", "qual", "info_TILEDB_IAF", "sample_name"] + df = stats_v3_dataset.read( + samples=stats_sample_names, + attrs=attrs, + scan_all_samples=True, + ) + assert ( + df[(df["sample_name"] == "second") & (df["pos_start"] == 4)][ + "info_TILEDB_IAF" + ].iloc[0][0] + == 0.9375 + ) + + tbl = stats_v3_dataset.read_arrow( + 
samples=stats_sample_names, + attrs=attrs, + scan_all_samples=True, + ) + assert tbl.num_rows == len(df) + assert tbl.to_pandas().equals(df) + + +@skip_if_no_bcftools +def test_read_with_af_filter_and_scan_all_samples(stats_v3_dataset, stats_sample_names): + """Verify set_af_filter and scan_all_samples can be combined to widen the result set.""" + attrs = ["contig", "pos_start", "sample_name"] + + df_filter_only = stats_v3_dataset.read( + samples=stats_sample_names, + attrs=attrs, + set_af_filter="<0.2", + ) + + df = stats_v3_dataset.read( + samples=stats_sample_names, + attrs=attrs, + set_af_filter="<0.2", + scan_all_samples=True, + ) + assert len(df) > len(df_filter_only) + + tbl = stats_v3_dataset.read_arrow( + samples=stats_sample_names, + attrs=attrs, + set_af_filter="<0.2", + scan_all_samples=True, + ) + assert tbl.to_pandas().equals(df) + + +@skip_if_no_bcftools +def test_variant_stats_parameter_errors(stats_v3_dataset): + """Verify that read_variant_stats and read_variant_stats_arrow reject invalid parameters.""" + no_region = '"region" or "regions" parameter is required' + exclusive = '"region" and "regions" parameters are mutually exclusive' + bad_format = '"region" parameter must have format ":-"' + empty_contig = "Region contig cannot be empty" + base_1 = "Regions must be 1-based" + bad_interval = '"100-1" is not a valid region interval' + + for fn in [stats_v3_dataset.read_variant_stats, stats_v3_dataset.read_variant_stats_arrow]: + with pytest.raises(Exception, match=no_region): + fn() + with pytest.raises(Exception, match=exclusive): + fn("chr1:1-100", regions=["chr1:1-100"]) + with pytest.raises(Exception, match=bad_format): + fn(regions=[""]) + with pytest.raises(Exception, match=bad_format): + fn(regions=["chr1"]) + with pytest.raises(Exception, match=bad_format): + fn(regions=["chr1:-"]) + with pytest.raises(Exception, match=empty_contig): + fn(regions=[":1-100"]) + with pytest.raises(Exception, match=base_1): + fn(regions=["chr1:0-100"]) + with 
pytest.raises(Exception, match=bad_interval): + fn(regions=["chr1:100-1"]) + + +@skip_if_no_bcftools +def test_variant_stats_empty_region(stats_v3_dataset): + """Verify read_variant_stats returns an empty DataFrame for a region with no variants.""" + assert stats_v3_dataset.read_variant_stats(regions=["chr3:1-10000"]).empty + + +@skip_if_no_bcftools +def test_variant_stats_return_types(stats_v3_dataset): + """Verify read_variant_stats returns a DataFrame and read_variant_stats_arrow returns an Arrow Table.""" + # Both the deprecated positional `region` parameter and the `regions` list + # should return a DataFrame / Arrow Table of the same shape and content. + region = "chr1:1-10000" + with pytest.warns(DeprecationWarning, match='"region" parameter is deprecated'): + for kwargs in [{"region": region}, {"regions": [region]}]: + # Workaround: read_variant_stats takes region as positional-or-keyword + if "region" in kwargs: + df = stats_v3_dataset.read_variant_stats(kwargs["region"]) + tbl = stats_v3_dataset.read_variant_stats_arrow(kwargs["region"]) + else: + df = stats_v3_dataset.read_variant_stats(**kwargs) + tbl = stats_v3_dataset.read_variant_stats_arrow(**kwargs) + assert isinstance(df, pd.DataFrame) + assert isinstance(tbl, pa.Table) + assert df.shape == (13, 6) + assert df.equals(tbl.to_pandas()) + + +@skip_if_no_bcftools +def test_variant_stats_multi_contig_regions(stats_v3_dataset): + """Verify read_variant_stats handles multiple contig regions and sorts results by contig.""" + # Results are always returned in contig-sorted order regardless of input order. 
+ region_chr1 = "chr1:1-10000" + region_chr2 = "chr2:1-10000" + expected_contigs = ["chr1"] * 13 + ["chr2"] * 2 + + df = stats_v3_dataset.read_variant_stats(regions=[region_chr1, region_chr2]) + assert df.shape == (15, 6) + assert expected_contigs == list(df["contig"].values) + + df_reversed = stats_v3_dataset.read_variant_stats(regions=[region_chr2, region_chr1]) + assert df.equals(df_reversed) + + tbl = stats_v3_dataset.read_variant_stats_arrow(regions=[region_chr1, region_chr2]) + tbl_reversed = stats_v3_dataset.read_variant_stats_arrow(regions=[region_chr2, region_chr1]) + assert tbl.equals(tbl_reversed) + assert df.equals(tbl.to_pandas()) + + +@skip_if_no_bcftools +def test_variant_stats_overlapping_regions(stats_v3_dataset): + """Verify read_variant_stats deduplicates and merges overlapping regions on the same contig.""" + # Overlapping regions on the same contig are merged; results are deduped and sorted. + expected_contigs = ["chr1"] * 13 + ["chr2"] * 2 + + assert stats_v3_dataset.read_variant_stats(regions=["chr1:1-1"]).shape == (2, 6) + assert stats_v3_dataset.read_variant_stats(regions=["chr1:1-2"]).shape == (5, 6) + assert stats_v3_dataset.read_variant_stats(regions=["chr1:3-4"]).shape == (6, 6) + assert stats_v3_dataset.read_variant_stats(regions=["chr1:2-5"]).shape == (11, 6) + + regions_chr1 = ["chr1:1-1", "chr1:1-2", "chr1:3-4", "chr1:2-5"] + df = stats_v3_dataset.read_variant_stats(regions=regions_chr1) + assert df.shape == (13, 6) + assert df.equals(stats_v3_dataset.read_variant_stats(regions=reversed(regions_chr1))) + + assert stats_v3_dataset.read_variant_stats(regions=["chr2:1-1"]).shape == (1, 6) + assert stats_v3_dataset.read_variant_stats(regions=["chr2:3-3"]).shape == (1, 6) + + regions_chr2 = ["chr2:1-1", "chr2:3-3"] + df = stats_v3_dataset.read_variant_stats(regions=regions_chr2) + assert df.shape == (2, 6) + assert df.equals(stats_v3_dataset.read_variant_stats(regions=reversed(regions_chr2))) + + for regions in [regions_chr1 + 
regions_chr2, regions_chr1]:
+        df = stats_v3_dataset.read_variant_stats(regions=regions)
+        assert df.shape == (15, 6)
+        assert expected_contigs == list(df["contig"].values)
+        assert df.equals(stats_v3_dataset.read_variant_stats(regions=reversed(regions)))
+
+
+@skip_if_no_bcftools
+def test_variant_stats_scan_all_samples(stats_v3_dataset):
+    """Verify scan_all_samples normalizes allele number (an) across all samples in variant stats."""
+    # Without scan_all_samples, an reflects only the queried samples' allele number.
+    # With scan_all_samples=True, an is normalised across all samples in the dataset.
+    regions = ["chr2:1-1", "chr2:3-3", "chr1:1-1", "chr1:1-2", "chr1:3-4", "chr1:2-5"]
+    ac = [8, 8, 5, 6, 5, 4, 4, 4, 4, 1, 15, 1, 2, 2, 2]
+
+    df = stats_v3_dataset.read_variant_stats(regions=regions)
+    assert ac == list(df["ac"].values)
+    assert [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 3, 3, 2, 2] == list(df["an"].values)
+    assert list(df["af"].values) == pytest.approx([0.5, 0.5, 0.3125, 0.375, 0.3125,
+        0.25, 0.25, 0.25, 0.25, 0.0625, 0.9375, 0.33333334, 0.6666667, 1.0, 1.0])
+
+    df = stats_v3_dataset.read_variant_stats(regions=regions, scan_all_samples=True)
+    assert ac == list(df["ac"].values)
+    assert [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16] == list(df["an"].values)
+    assert list(df["af"].values) == pytest.approx([0.5, 0.5, 0.3125, 0.375, 0.3125,
+        0.25, 0.25, 0.25, 0.25, 0.0625, 0.9375, 0.0625, 0.125, 0.125, 0.125])
+
+
+@skip_if_no_bcftools
+def test_variant_stats_drop_ref(stats_v3_dataset):
+    """Verify drop_ref=True filters out reference allele rows from variant stats."""
+    # drop_ref=True filters out rows where the alternate allele is "ref". 
+ regions = ["chr2:1-1", "chr2:3-3", "chr1:1-1", "chr1:1-2", "chr1:3-4", "chr1:2-5"] + + df = stats_v3_dataset.read_variant_stats(regions=regions) + assert ["T,C", "ref", "G,GTTTA", "G,T", "ref", "C,A", "C,G", "C,T", "ref", + "G,GTTTA", "ref", "C,T", "ref", "G,GTTTA", "G,GTTTA"] == list(df["alleles"].values) + + df = stats_v3_dataset.read_variant_stats(regions=regions, drop_ref=True) + assert ["T,C", "G,GTTTA", "G,T", "C,A", "C,G", "C,T", "G,GTTTA", + "C,T", "G,GTTTA", "G,GTTTA"] == list(df["alleles"].values) + + +@skip_if_no_bcftools +def test_allele_count_parameter_errors(stats_v3_dataset): + """Verify that read_allele_count and read_allele_count_arrow reject invalid parameters.""" + no_region = '"region" or "regions" parameter is required' + exclusive = '"region" and "regions" parameters are mutually exclusive' + bad_format = '"region" parameter must have format ":-"' + empty_contig = "Region contig cannot be empty" + base_1 = "Regions must be 1-based" + bad_interval = '"100-1" is not a valid region interval' + + for fn in [stats_v3_dataset.read_allele_count, stats_v3_dataset.read_allele_count_arrow]: + with pytest.raises(Exception, match=no_region): + fn() + with pytest.raises(Exception, match=exclusive): + fn("chr1:1-100", regions=["chr1:1-100"]) + with pytest.raises(Exception, match=bad_format): + fn(regions=[""]) + with pytest.raises(Exception, match=bad_format): + fn(regions=["chr1"]) + with pytest.raises(Exception, match=bad_format): + fn(regions=["chr1:-"]) + with pytest.raises(Exception, match=empty_contig): + fn(regions=[":1-100"]) + with pytest.raises(Exception, match=base_1): + fn(regions=["chr1:0-100"]) + with pytest.raises(Exception, match=bad_interval): + fn(regions=["chr1:100-1"]) + + +@skip_if_no_bcftools +def test_allele_count_empty_region(stats_v3_dataset): + """Verify read_allele_count returns an empty DataFrame for a region with no data.""" + assert stats_v3_dataset.read_allele_count(regions=["chr3:1-10000"]).empty + + +@skip_if_no_bcftools 
+def test_allele_count_return_types(stats_v3_dataset): + """Verify read_allele_count returns a DataFrame and read_allele_count_arrow returns an Arrow Table.""" + # Both the deprecated positional `region` parameter and the `regions` list + # should return a DataFrame / Arrow Table of the same shape and content. + region = "chr1:1-10000" + expected_pos = (0, 1, 1, 2, 2, 2, 3) + expected_count = (8, 5, 3, 4, 2, 2, 1) + + with pytest.warns(DeprecationWarning, match='"region" parameter is deprecated'): + for kwargs in [{"region": region}, {"regions": [region]}]: + if "region" in kwargs: + df = stats_v3_dataset.read_allele_count(kwargs["region"]) + tbl = stats_v3_dataset.read_allele_count_arrow(kwargs["region"]) + else: + df = stats_v3_dataset.read_allele_count(**kwargs) + tbl = stats_v3_dataset.read_allele_count_arrow(**kwargs) + assert isinstance(df, pd.DataFrame) + assert isinstance(tbl, pa.Table) + assert df.shape == (7, 7) + assert df.equals(tbl.to_pandas()) + assert sum(df["pos"] == expected_pos) == 7 + assert sum(df["count"] == expected_count) == 7 + + +@skip_if_no_bcftools +def test_allele_count_multi_contig_regions(stats_v3_dataset): + """Verify read_allele_count handles multiple contig regions and sorts results by contig.""" + # Results are always returned in contig-sorted order regardless of input order. 
+ region_chr1 = "chr1:1-10000" + region_chr2 = "chr2:1-10000" + expected_contigs = ["chr1"] * 7 + ["chr2"] * 2 + + df = stats_v3_dataset.read_allele_count(regions=[region_chr1, region_chr2]) + assert df.shape == (9, 7) + assert expected_contigs == list(df["contig"].values) + + df_reversed = stats_v3_dataset.read_allele_count(regions=[region_chr2, region_chr1]) + assert df.equals(df_reversed) + + tbl = stats_v3_dataset.read_allele_count_arrow(regions=[region_chr1, region_chr2]) + tbl_reversed = stats_v3_dataset.read_allele_count_arrow(regions=[region_chr2, region_chr1]) + assert tbl.equals(tbl_reversed) + assert df.equals(tbl.to_pandas()) + + +@skip_if_no_bcftools +def test_allele_count_overlapping_regions(stats_v3_dataset): + """Verify read_allele_count deduplicates and merges overlapping regions on the same contig.""" + # Overlapping regions on the same contig are merged; results are deduped and sorted. + expected_contigs = ["chr1"] * 7 + ["chr2"] * 2 + + assert stats_v3_dataset.read_allele_count(regions=["chr1:1-1"]).shape == (1, 7) + assert stats_v3_dataset.read_allele_count(regions=["chr1:1-2"]).shape == (3, 7) + assert stats_v3_dataset.read_allele_count(regions=["chr1:3-4"]).shape == (4, 7) + assert stats_v3_dataset.read_allele_count(regions=["chr1:2-5"]).shape == (6, 7) + + regions_chr1 = ["chr1:1-1", "chr1:1-2", "chr1:3-4", "chr1:2-5"] + df = stats_v3_dataset.read_allele_count(regions=regions_chr1) + assert df.shape == (7, 7) + assert df.equals(stats_v3_dataset.read_allele_count(regions=reversed(regions_chr1))) + + assert stats_v3_dataset.read_allele_count(regions=["chr2:1-1"]).shape == (1, 7) + assert stats_v3_dataset.read_allele_count(regions=["chr2:3-3"]).shape == (1, 7) + + regions_chr2 = ["chr2:1-1", "chr2:3-3"] + df = stats_v3_dataset.read_allele_count(regions=regions_chr2) + assert df.shape == (2, 7) + assert df.equals(stats_v3_dataset.read_allele_count(regions=reversed(regions_chr2))) + + for regions in [regions_chr1 + regions_chr2, regions_chr2 + 
regions_chr1]: + df = stats_v3_dataset.read_allele_count(regions=regions) + assert df.shape == (9, 7) + assert expected_contigs == list(df["contig"].values) + assert df.equals(stats_v3_dataset.read_allele_count(regions=reversed(regions))) + + +@skip_if_no_bcftools +def test_allele_frequency(stats_v3_dataset, tmp_path): + """Verify allele frequency consistency: ac / af rounds to an.""" + # Verify that ac / af ≈ an (i.e. allele frequency is consistent with counts). + region = "chr1:1-10000" + # read_allele_frequency internally uses the deprecated `region` parameter. + with pytest.warns(DeprecationWarning, match='"region" parameter is deprecated'): + df = tiledbvcf.allele_frequency.read_allele_frequency( + os.path.join(tmp_path, "stats_test"), region + ) + assert df.pos.is_monotonic_increasing + df["an_check"] = (df.ac / df.af).round(0).astype("int32") + assert df.an_check.equals(df.an) + assert stats_v3_dataset.read_variant_stats(regions=[region]).shape == (13, 6) + + +@skip_if_no_bcftools +def test_allele_frequency_invalid_region_format(stats_v3_dataset, tmp_path): + """Verify read_allele_frequency rejects a badly-formatted region string.""" + uri = os.path.join(tmp_path, "stats_test") + with pytest.warns(DeprecationWarning, match='"region" parameter is deprecated'): + with pytest.raises(Exception, match='"region" parameter must have format'): + tiledbvcf.allele_frequency.read_allele_frequency(uri, "chr1") + + +@skip_if_no_bcftools +def test_allele_frequency_empty_region(stats_v3_dataset, tmp_path): + """Verify read_allele_frequency returns an empty DataFrame for a region with no data.""" + uri = os.path.join(tmp_path, "stats_test") + with pytest.warns(DeprecationWarning, match='"region" parameter is deprecated'): + df = tiledbvcf.allele_frequency.read_allele_frequency(uri, "chr3:1-10000") + assert df.empty + + +def test_sample_qc_samples_parameter(tmp_path): + """Verify sample_qc can be filtered to specific samples.""" + uri = os.path.join(tmp_path, "dataset") + ds 
= tiledbvcf.Dataset(uri, mode="w") + ds.create_dataset(enable_variant_stats=True, enable_allele_count=True) + ds.ingest_samples( + [os.path.join(TESTS_INPUT_DIR, s) for s in ["small.bcf", "small3.bcf"]] + ) + + qc_all = tiledbvcf.sample_qc(uri) + assert set(qc_all["sample"]) == {"HG00280", "HG01762"} + + qc_one = tiledbvcf.sample_qc(uri, samples=["HG00280"]) + assert list(qc_one["sample"]) == ["HG00280"] + assert len(qc_one) == 1 + + +def test_sample_qc_config_parameter(tmp_path): + """Smoke Test: Verify sample_qc accepts a config parameter.""" + uri = os.path.join(tmp_path, "dataset") + ds = tiledbvcf.Dataset(uri, mode="w") + ds.create_dataset(enable_variant_stats=True, enable_allele_count=True) + ds.ingest_samples( + [os.path.join(TESTS_INPUT_DIR, s) for s in ["small.bcf", "small3.bcf"]] + ) + + qc_default = tiledbvcf.sample_qc(uri) + qc_with_config = tiledbvcf.sample_qc(uri, config={"sm.tile_cache_size": "0"}) + assert_dfs_equal(qc_default, qc_with_config) diff --git a/apis/python/tests/test_tiledbvcf.py b/apis/python/tests/test_tiledbvcf.py deleted file mode 100755 index c726ffe33..000000000 --- a/apis/python/tests/test_tiledbvcf.py +++ /dev/null @@ -1,2523 +0,0 @@ -import numpy as np -import subprocess -import os -import pandas as pd -import pyarrow as pa -import glob -import shutil -import platform -import pytest -import tiledbvcf -import tiledb - -# Directory containing this file -CONTAINING_DIR = os.path.abspath(os.path.dirname(__file__)) - -# Test inputs directory -TESTS_INPUT_DIR = os.path.abspath( - os.path.join(CONTAINING_DIR, "../../../libtiledbvcf/test/inputs") -) - - -def _check_dfs(expected, actual): - def assert_series(s1, s2): - if np.issubdtype(s2.dtype, np.floating): - assert np.isclose(s1, s2, equal_nan=True).all() - elif np.issubdtype(s2.dtype, np.integer): - assert s1.astype("int64").equals(s2.astype("int64")) - else: - assert s1.equals(s2) - - for k in expected: - assert_series(expected[k], actual[k]) - - for k in actual: - 
assert_series(expected[k], actual[k]) - - -def check_if_compatible(uri): - try: - with tiledb.open(uri): - return True - except tiledb.libtiledb.TileDBError as e: - if "incompatible format version" in str(e).lower(): - raise pytest.skip.Exception( - "Test skipped due to incompatible format version" - ) - raise pytest.skip.Exception(f"Test skipped due to TileDB error: {str(e)}") - - -@pytest.fixture -def test_ds_v4(): - return tiledbvcf.Dataset( - os.path.join(TESTS_INPUT_DIR, "arrays/v4/ingested_2samples") - ) - - -@pytest.fixture -def test_ds(): - return tiledbvcf.Dataset( - os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples") - ) - - -@pytest.fixture -def test_ds_attrs(): - return tiledbvcf.Dataset( - os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples_GT_DP_PL") - ) - - -def test_basic_count(test_ds): - assert test_ds.count() == 14 - - -def test_retrieve_attributes(test_ds): - builtin_attrs = [ - "sample_name", - "contig", - "pos_start", - "pos_end", - "alleles", - "id", - "fmt", - "info", - "filters", - "qual", - "query_bed_end", - "query_bed_start", - "query_bed_line", - ] - assert sorted(test_ds.attributes(attr_type="builtin")) == sorted(builtin_attrs) - - info_attrs = [ - "info_BaseQRankSum", - "info_ClippingRankSum", - "info_DP", - "info_DS", - "info_END", - "info_HaplotypeScore", - "info_InbreedingCoeff", - "info_MLEAC", - "info_MLEAF", - "info_MQ", - "info_MQ0", - "info_MQRankSum", - "info_ReadPosRankSum", - ] - assert test_ds.attributes(attr_type="info") == info_attrs - - fmt_attrs = [ - "fmt_AD", - "fmt_DP", - "fmt_GQ", - "fmt_GT", - "fmt_MIN_DP", - "fmt_PL", - "fmt_SB", - ] - assert test_ds.attributes(attr_type="fmt") == fmt_attrs - - -def test_retrieve_samples(test_ds): - assert test_ds.samples() == ["HG00280", "HG01762"] - - -def test_read_unsupported_regions_type(test_ds): - unsupported_region = 3.14 - unsupported_type_error = f'"regions" parameter cannot have type: {type(unsupported_region)}' - wrong_dimension_region = 
np.array([["1:12700-13400"], ["1:12700-13400"]]) - ndarray_wrong_dimension_error = f'"regions" parameter of type {type(wrong_dimension_region)} must be 1-dimensional' - with pytest.raises(Exception, match=unsupported_type_error): - test_ds.read(regions=unsupported_region) - with pytest.raises(Exception, match=ndarray_wrong_dimension_error): - test_ds.read(regions=wrong_dimension_region) - with pytest.raises(Exception, match=unsupported_type_error): - test_ds.read_arrow(regions=unsupported_region) - with pytest.raises(Exception, match=ndarray_wrong_dimension_error): - test_ds.read_arrow(regions=wrong_dimension_region) - with pytest.raises(Exception, match=unsupported_type_error): - for variant in test_ds.read_iter(regions=unsupported_region): - print(variant) - with pytest.raises(Exception, match=ndarray_wrong_dimension_error): - for variant in test_ds.read_iter(regions=wrong_dimension_region): - print(variant) - - -def test_read_attrs(test_ds_attrs): - attrs = ["sample_name"] - df = test_ds_attrs.read(attrs=attrs) - assert df.columns.values.tolist() == attrs - - attrs = ["sample_name", "fmt_GT"] - df = test_ds_attrs.read(attrs=attrs) - assert df.columns.values.tolist() == attrs - - attrs = ["sample_name"] - df = test_ds_attrs.read(attrs=attrs) - assert df.columns.values.tolist() == attrs - - -def test_basic_reads(test_ds): - expected_df = pd.DataFrame( - { - "sample_name": pd.Series( - [ - "HG00280", - "HG01762", - "HG00280", - "HG01762", - "HG00280", - "HG01762", - "HG00280", - "HG00280", - "HG00280", - "HG00280", - "HG00280", - "HG00280", - "HG00280", - "HG00280", - ] - ), - "pos_start": pd.Series( - [ - 12141, - 12141, - 12546, - 12546, - 13354, - 13354, - 13375, - 13396, - 13414, - 13452, - 13520, - 13545, - 17319, - 17480, - ], - dtype=np.int32, - ), - "pos_end": pd.Series( - [ - 12277, - 12277, - 12771, - 12771, - 13374, - 13389, - 13395, - 13413, - 13451, - 13519, - 13544, - 13689, - 17479, - 17486, - ], - dtype=np.int32, - ), - } - 
).sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - - for use_arrow in [False, True]: - func = test_ds.read_arrow if use_arrow else test_ds.read - - df = func(attrs=["sample_name", "pos_start", "pos_end"]) - if use_arrow: - df = df.to_pandas() - - _check_dfs( - expected_df, - df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]), - ) - - # Region intersection - df = test_ds.read( - attrs=["sample_name", "pos_start", "pos_end"], regions=["1:12700-13400"] - ) - expected_df = pd.DataFrame( - { - "sample_name": pd.Series( - ["HG00280", "HG01762", "HG00280", "HG01762", "HG00280", "HG00280"] - ), - "pos_start": pd.Series( - [12546, 12546, 13354, 13354, 13375, 13396], dtype=np.int32 - ), - "pos_end": pd.Series( - [12771, 12771, 13374, 13389, 13395, 13413], dtype=np.int32 - ), - } - ).sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - _check_dfs( - expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - ) - df = test_ds.read_arrow( - attrs=["sample_name", "pos_start", "pos_end"], regions=["1:12700-13400"] - ).to_pandas() - _check_dfs( - expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - ) - - # Regions as string - df = test_ds.read( - attrs=["sample_name", "pos_start", "pos_end"], regions="1:12700-13400" - ) - _check_dfs( - expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - ) - df = test_ds.read_arrow( - attrs=["sample_name", "pos_start", "pos_end"], regions="1:12700-13400" - ).to_pandas() - _check_dfs( - expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - ) - - # Regions as numpy.ndarray - df = test_ds.read( - attrs=["sample_name", "pos_start", "pos_end"], regions=np.array(["1:12700-13400"]) - ) - _check_dfs( - expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - ) - df = test_ds.read_arrow( - attrs=["sample_name", "pos_start", "pos_end"], regions=np.array(["1:12700-13400"]) - 
).to_pandas() - _check_dfs( - expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - ) - - # Region and sample intersection - df = test_ds.read( - attrs=["sample_name", "pos_start", "pos_end"], - regions=["1:12700-13400"], - samples=["HG01762"], - ) - expected_df = pd.DataFrame( - { - "sample_name": pd.Series(["HG01762", "HG01762"]), - "pos_start": pd.Series([12546, 13354], dtype=np.int32), - "pos_end": pd.Series([12771, 13389], dtype=np.int32), - } - ).sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - _check_dfs( - expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - ) - - # Sample only - df = test_ds.read( - attrs=["sample_name", "pos_start", "pos_end"], samples=["HG01762"] - ) - expected_df = pd.DataFrame( - { - "sample_name": pd.Series(["HG01762", "HG01762", "HG01762"]), - "pos_start": pd.Series([12141, 12546, 13354], dtype=np.int32), - "pos_end": pd.Series([12277, 12771, 13389], dtype=np.int32), - } - ).sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - _check_dfs( - expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - ) - - -def test_multiple_counts(test_ds): - assert test_ds.count() == 14 - assert test_ds.count() == 14 - assert test_ds.count(regions=["1:12700-13400"]) == 6 - assert test_ds.count(samples=["HG00280"], regions=["1:12700-13400"]) == 4 - assert test_ds.count() == 14 - assert test_ds.count(samples=["HG01762"]) == 3 - assert test_ds.count(samples=["HG00280"]) == 11 - - -def test_empty_region(test_ds): - assert test_ds.count(regions=["12:1-1000000"]) == 0 - - -def test_missing_sample_raises_exception(test_ds): - with pytest.raises(RuntimeError): - test_ds.count(samples=["abcde"]) - - -# TODO remove skip -@pytest.mark.skip -def test_bad_contig_raises_exception(test_ds): - with pytest.raises(RuntimeError): - test_ds.count(regions=["chr1:1-1000000"]) - with pytest.raises(RuntimeError): - test_ds.count(regions=["1"]) - with 
pytest.raises(RuntimeError): - test_ds.count(regions=["1:100-"]) - with pytest.raises(RuntimeError): - test_ds.count(regions=["1:-100"]) - - -def test_bad_attr_raises_exception(test_ds): - with pytest.raises(RuntimeError): - test_ds.read(attrs=["abcde"], regions=["1:12700-13400"]) - - -def test_read_write_mode_exceptions(): - ds = tiledbvcf.Dataset(os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples")) - samples = [os.path.join(TESTS_INPUT_DIR, s) for s in ["small.bcf", "small2.bcf"]] - - with pytest.raises(Exception): - ds.create_dataset() - - with pytest.raises(Exception): - ds.ingest_samples(samples) - - ds = tiledbvcf.Dataset( - os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples"), mode="w" - ) - with pytest.raises(Exception): - ds.count() - - -def test_incomplete_reads(): - # Using undocumented "0 MB" budget to test incomplete reads. - uri = os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples") - cfg = tiledbvcf.ReadConfig(memory_budget_mb=0) - test_ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) - - df = test_ds.read(attrs=["pos_end"], regions=["1:12700-13400"]) - assert not test_ds.read_completed() - assert len(df) == 2 - _check_dfs( - pd.DataFrame.from_dict({"pos_end": np.array([12771, 12771], dtype=np.int32)}), - df, - ) - - df = test_ds.continue_read() - assert not test_ds.read_completed() - assert len(df) == 2 - _check_dfs( - pd.DataFrame.from_dict({"pos_end": np.array([13374, 13389], dtype=np.int32)}), - df, - ) - - df = test_ds.continue_read() - assert test_ds.read_completed() - assert len(df) == 2 - _check_dfs( - pd.DataFrame.from_dict({"pos_end": np.array([13395, 13413], dtype=np.int32)}), - df, - ) - - # test incomplete via read_arrow - table = test_ds.read_arrow(attrs=["pos_end"], regions=["1:12700-13400"]) - assert not test_ds.read_completed() - assert len(table) == 2 - _check_dfs( - pd.DataFrame.from_dict({"pos_end": np.array([12771, 12771], dtype=np.int32)}), - table.to_pandas(), - ) - - table = test_ds.continue_read_arrow() 
- assert not test_ds.read_completed() - assert len(table) == 2 - _check_dfs( - pd.DataFrame.from_dict({"pos_end": np.array([13374, 13389], dtype=np.int32)}), - table.to_pandas(), - ) - - table = test_ds.continue_read_arrow() - assert test_ds.read_completed() - assert len(table) == 2 - _check_dfs( - pd.DataFrame.from_dict({"pos_end": np.array([13395, 13413], dtype=np.int32)}), - table.to_pandas(), - ) - - -def test_incomplete_read_generator(): - # Using undocumented "0 MB" budget to test incomplete reads. - uri = os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples") - cfg = tiledbvcf.ReadConfig(memory_budget_mb=0) - test_ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) - expected_df = pd.DataFrame.from_dict( - { - "pos_end": np.array( - [12771, 12771, 13374, 13389, 13395, 13413], dtype=np.int32 - ) - } - ) - - # NOTE: Running multiple test shows that the iterator can be reused - - # Regions as string - dfs = [] - for df in test_ds.read_iter(attrs=["pos_end"], regions="1:12700-13400"): - dfs.append(df) - overall_df = pd.concat(dfs, ignore_index=True) - assert len(overall_df) == 6 - _check_dfs(expected_df, overall_df) - - # Regions as list - dfs = [] - for df in test_ds.read_iter(attrs=["pos_end"], regions=["1:12700-13400"]): - dfs.append(df) - overall_df = pd.concat(dfs, ignore_index=True) - assert len(overall_df) == 6 - _check_dfs(expected_df, overall_df) - - # Regions as numpy.ndarray - dfs = [] - for df in test_ds.read_iter(attrs=["pos_end"], regions=np.array(["1:12700-13400"])): - dfs.append(df) - overall_df = pd.concat(dfs, ignore_index=True) - assert len(overall_df) == 6 - _check_dfs(expected_df, overall_df) - - -def test_read_filters(test_ds): - df = test_ds.read( - attrs=["sample_name", "pos_start", "pos_end", "filters"], - regions=["1:12700-13400"], - ) - expected_df = pd.DataFrame( - { - "sample_name": pd.Series( - ["HG00280", "HG01762", "HG00280", "HG01762", "HG00280", "HG00280"] - ), - "pos_start": pd.Series( - [12546, 12546, 13354, 13354, 13375, 
13396], dtype=np.int32 - ), - "pos_end": pd.Series( - [12771, 12771, 13374, 13389, 13395, 13413], dtype=np.int32 - ), - "filters": pd.Series( - map( - lambda lst: np.array(lst, dtype=object), - [None, None, ["LowQual"], None, None, None], - ) - ), - } - ).sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - _check_dfs( - expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - ) - - -def test_read_var_length_filters(tmp_path): - uri = os.path.join(tmp_path, "dataset") - ds = tiledbvcf.Dataset(uri, mode="w") - samples = [os.path.join(TESTS_INPUT_DIR, s) for s in ["varLenFilter.vcf.gz"]] - ds.create_dataset() - ds.ingest_samples(samples) - - ds = tiledbvcf.Dataset(uri, mode="r") - df = ds.read(["pos_start", "filters"]) - - expected_df = pd.DataFrame( - { - "pos_start": pd.Series( - [ - 12141, - 12546, - 13354, - 13375, - 13396, - 13414, - 13452, - 13520, - 13545, - 17319, - 17480, - ], - dtype=np.int32, - ), - "filters": pd.Series( - map( - lambda lst: np.array(lst, dtype=object), - [ - ["PASS"], - ["PASS"], - ["ANEUPLOID", "LowQual"], - ["PASS"], - ["PASS"], - ["ANEUPLOID", "LOWQ", "LowQual"], - ["PASS"], - ["PASS"], - ["PASS"], - ["LowQual"], - ["PASS"], - ], - ) - ), - } - ).sort_values(ignore_index=True, by=["pos_start"]) - - _check_dfs(expected_df, df.sort_values(ignore_index=True, by=["pos_start"])) - - -def test_read_alleles(test_ds): - df = test_ds.read( - attrs=["sample_name", "pos_start", "pos_end", "alleles"], - regions=["1:12100-13360", "1:13500-17350"], - ) - expected_df = pd.DataFrame( - { - "sample_name": pd.Series( - [ - "HG00280", - "HG01762", - "HG00280", - "HG01762", - "HG00280", - "HG01762", - "HG00280", - "HG00280", - "HG00280", - "HG00280", - ] - ), - "pos_start": pd.Series( - [12141, 12141, 12546, 12546, 13354, 13354, 13452, 13520, 13545, 17319], - dtype=np.int32, - ), - "pos_end": pd.Series( - [12277, 12277, 12771, 12771, 13374, 13389, 13519, 13544, 13689, 17479], - dtype=np.int32, - ), - "alleles": 
pd.Series( - map( - lambda lst: np.array(lst, dtype=object), - [ - ["C", ""], - ["C", ""], - ["G", ""], - ["G", ""], - ["T", ""], - ["T", ""], - ["G", ""], - ["G", ""], - ["G", ""], - ["T", ""], - ], - ) - ), - } - ).sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - _check_dfs( - expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - ) - - -def test_read_multiple_alleles(tmp_path): - uri = os.path.join(tmp_path, "dataset") - ds = tiledbvcf.Dataset(uri, mode="w") - samples = [os.path.join(TESTS_INPUT_DIR, s) for s in ["small3.bcf", "small.bcf"]] - ds.create_dataset() - ds.ingest_samples(samples) - - ds = tiledbvcf.Dataset(uri, mode="r") - df = ds.read( - attrs=["sample_name", "pos_start", "alleles", "id", "filters"], - regions=["1:70100-1300000"], - ) - expected_df = pd.DataFrame( - { - "sample_name": pd.Series(["HG00280", "HG00280"]), - "pos_start": pd.Series([866511, 1289367], dtype=np.int32), - "alleles": pd.Series( - map( - lambda lst: np.array(lst, dtype=object), - [["T", "CCCCTCCCT", "C", "CCCCTCCCTCCCT", "CCCCT"], ["CTG", "C"]], - ) - ), - "id": pd.Series([".", "rs1497816"]), - "filters": pd.Series( - map( - lambda lst: np.array(lst, dtype=object), - [["LowQual"], ["LowQual"]], - ) - ), - } - ).sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - _check_dfs( - expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - ) - - -def test_read_var_len_attrs(test_ds): - df = test_ds.read( - attrs=["sample_name", "pos_start", "pos_end", "fmt_DP", "fmt_PL"], - regions=["1:12100-13360", "1:13500-17350"], - ) - expected_df = pd.DataFrame( - { - "sample_name": pd.Series( - [ - "HG00280", - "HG01762", - "HG00280", - "HG01762", - "HG00280", - "HG01762", - "HG00280", - "HG00280", - "HG00280", - "HG00280", - ] - ), - "pos_start": pd.Series( - [12141, 12141, 12546, 12546, 13354, 13354, 13452, 13520, 13545, 17319], - dtype=np.int32, - ), - "pos_end": pd.Series( - [12277, 12277, 12771, 12771, 13374, 
13389, 13519, 13544, 13689, 17479], - dtype=np.int32, - ), - "fmt_DP": pd.Series([0, 0, 0, 0, 15, 64, 10, 6, 0, 0], dtype=np.int32), - "fmt_PL": pd.Series( - map( - lambda lst: np.array(lst, dtype=np.int32), - [ - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [0, 24, 360], - [0, 66, 990], - [0, 21, 210], - [0, 6, 90], - [0, 0, 0], - [0, 0, 0], - ], - ) - ), - } - ).sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - - _check_dfs( - expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - ) - - -def test_sample_args(test_ds, tmp_path): - sample_file = os.path.join(tmp_path, "1_sample.txt") - with open(sample_file, "w") as file: - file.write("HG00280") - - region = ["1:12141-12141"] - df1 = test_ds.read(["sample_name"], regions=region, samples=["HG00280"]) - df2 = test_ds.read(["sample_name"], regions=region, samples_file=sample_file) - _check_dfs(df1, df2) - - with pytest.raises(TypeError): - test_ds.read( - attrs=["sample_name"], - regions=region, - samples=["HG00280"], - samples_file=sample_file, - ) - - -def test_read_null_attrs(tmp_path): - uri = os.path.join(tmp_path, "dataset") - ds = tiledbvcf.Dataset(uri, mode="w") - samples = [os.path.join(TESTS_INPUT_DIR, s) for s in ["small3.bcf", "small.bcf"]] - ds.create_dataset() - ds.ingest_samples(samples) - - ds = tiledbvcf.Dataset(uri, mode="r") - df = ds.read( - attrs=[ - "sample_name", - "pos_start", - "pos_end", - "info_BaseQRankSum", - "info_DP", - "fmt_DP", - "fmt_MIN_DP", - ], - regions=["1:12700-13400", "1:69500-69800"], - ) - expected_df = pd.DataFrame( - { - "sample_name": pd.Series( - [ - "HG00280", - "HG00280", - "HG00280", - "HG00280", - "HG01762", - "HG01762", - "HG00280", - "HG00280", - "HG00280", - "HG00280", - "HG00280", - "HG00280", - ] - ), - "pos_start": pd.Series( - [ - 12546, - 13354, - 13375, - 13396, - 12546, - 13354, - 69371, - 69511, - 69512, - 69761, - 69762, - 69771, - ], - dtype=np.int32, - ), - "pos_end": pd.Series( - [ - 12771, - 13374, - 
13395, - 13413, - 12771, - 13389, - 69510, - 69511, - 69760, - 69761, - 69770, - 69834, - ], - dtype=np.int32, - ), - "info_BaseQRankSum": pd.Series( - [ - None, - None, - None, - None, - None, - None, - None, - np.array([-0.787], dtype=np.float32), - None, - np.array([1.97], dtype=np.float32), - None, - None, - ] - ), - "info_DP": pd.Series( - [ - None, - None, - None, - None, - None, - None, - None, - np.array([89], dtype=np.int32), - None, - np.array([24], dtype=np.int32), - None, - None, - ] - ), - "fmt_DP": pd.Series( - [0, 15, 6, 2, 0, 64, 180, 88, 97, 24, 23, 21], dtype=np.int32 - ), - "fmt_MIN_DP": pd.Series([0, 14, 3, 1, 0, 30, 20, None, 24, None, 23, 19]), - } - ).sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - _check_dfs( - expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]) - ) - - -def test_read_config(): - uri = os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples") - cfg = tiledbvcf.ReadConfig() - ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) - - cfg = tiledbvcf.ReadConfig( - memory_budget_mb=512, - region_partition=(0, 3), - tiledb_config=["sm.tile_cache_size=0", "sm.compute_concurrency_level=1"], - ) - ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) - - with pytest.raises(TypeError): - cfg = tiledbvcf.ReadConfig(abc=123) - - # Expect an exception when passing both cfg and tiledb_config - with pytest.raises(Exception): - cfg = tiledbvcf.ReadConfig() - tiledb_config = {"foo": "bar"} - ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg, tiledb_config=tiledb_config) - - -# This test is skipped because running it in the same process as all the normal -# tests will cause it to fail (the first context created in a process determines -# the number of TBB threads allowed). 
-@pytest.mark.skip -def test_tbb_threads_config(): - uri = os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples") - cfg = tiledbvcf.ReadConfig(tiledb_config=["sm.num_tbb_threads=3"]) - ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) - - cfg = tiledbvcf.ReadConfig(tiledb_config=["sm.num_tbb_threads=4"]) - with pytest.raises(RuntimeError): - ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) - - -def test_read_limit(): - uri = os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples") - cfg = tiledbvcf.ReadConfig(limit=3) - ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) - df = ds.read( - attrs=["sample_name", "pos_start", "pos_end", "fmt_DP", "fmt_PL"], - regions=["1:12100-13360", "1:13500-17350"], - ) - assert len(df) == 3 - - -def test_region_partitioned_read(): - uri = os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples") - - cfg = tiledbvcf.ReadConfig(region_partition=(0, 2)) - ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) - df = ds.read( - attrs=["sample_name", "pos_start", "pos_end"], - regions=["1:12000-13000", "1:17000-18000"], - ) - assert len(df) == 4 - - cfg = tiledbvcf.ReadConfig(region_partition=(1, 2)) - ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) - df = ds.read( - attrs=["sample_name", "pos_start", "pos_end"], - regions=["1:12000-13000", "1:17000-18000"], - ) - assert len(df) == 2 - - # Too many partitions still produces results - cfg = tiledbvcf.ReadConfig(region_partition=(1, 3)) - ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) - df = ds.read( - attrs=["sample_name", "pos_start", "pos_end"], - regions=["1:12000-13000", "1:17000-18000"], - ) - assert len(df) == 2 - - # Error: index >= num partitions - cfg = tiledbvcf.ReadConfig(region_partition=(2, 2)) - with pytest.raises(RuntimeError): - ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) - - -def test_sample_partitioned_read(): - uri = os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples") - - cfg = tiledbvcf.ReadConfig(sample_partition=(0, 2)) - ds = tiledbvcf.Dataset(uri, mode="r", 
cfg=cfg) - df = ds.read( - attrs=["sample_name", "pos_start", "pos_end"], regions=["1:12000-18000"] - ) - assert len(df) == 11 - assert (df.sample_name == "HG00280").all() - - cfg = tiledbvcf.ReadConfig(sample_partition=(1, 2)) - ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) - df = ds.read( - attrs=["sample_name", "pos_start", "pos_end"], regions=["1:12000-18000"] - ) - assert len(df) == 3 - assert (df.sample_name == "HG01762").all() - - # Error: too many partitions - cfg = tiledbvcf.ReadConfig(sample_partition=(1, 3)) - ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) - with pytest.raises(RuntimeError): - df = ds.read( - attrs=["sample_name", "pos_start", "pos_end"], regions=["1:12000-18000"] - ) - - # Error: index >= num partitions - cfg = tiledbvcf.ReadConfig(sample_partition=(2, 2)) - with pytest.raises(RuntimeError): - ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) - - -def test_sample_and_region_partitioned_read(): - uri = os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples") - - cfg = tiledbvcf.ReadConfig(region_partition=(0, 2), sample_partition=(0, 2)) - ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) - df = ds.read( - attrs=["sample_name", "pos_start", "pos_end"], - regions=["1:12000-13000", "1:17000-18000"], - ) - assert len(df) == 2 - assert (df.sample_name == "HG00280").all() - - cfg = tiledbvcf.ReadConfig(region_partition=(0, 2), sample_partition=(1, 2)) - ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) - df = ds.read( - attrs=["sample_name", "pos_start", "pos_end"], - regions=["1:12000-13000", "1:17000-18000"], - ) - assert len(df) == 2 - assert (df.sample_name == "HG01762").all() - - cfg = tiledbvcf.ReadConfig(region_partition=(1, 2), sample_partition=(0, 2)) - ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) - df = ds.read( - attrs=["sample_name", "pos_start", "pos_end"], - regions=["1:12000-13000", "1:17000-18000"], - ) - assert len(df) == 2 - assert (df.sample_name == "HG00280").all() - - cfg = tiledbvcf.ReadConfig(region_partition=(1, 2), 
sample_partition=(1, 2)) - ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg) - df = ds.read( - attrs=["sample_name", "pos_start", "pos_end"], - regions=["1:12000-13000", "1:17000-18000"], - ) - assert len(df) == 0 - - -@pytest.mark.skipif(os.environ.get("CI") != "true", reason="CI only") -def test_large_export_correctness(): - uri = "s3://tiledb-inc-demo-data/tiledbvcf-arrays/v4/vcf-samples-20" - - ds = tiledbvcf.Dataset(uri) - df = ds.read( - attrs=[ - "sample_name", - "contig", - "pos_start", - "pos_end", - "query_bed_start", - "query_bed_end", - ], - samples=["v2-DjrIAzkP", "v2-YMaDHIoW", "v2-usVwJUmo", "v2-ZVudhauk"], - bed_file=os.path.join( - TESTS_INPUT_DIR, "E001_15_coreMarks_dense_filtered.bed.gz" - ), - ) - - # total number of exported records - assert df.shape[0] == 1172081 - - # number of unique exported records - record_index = ["sample_name", "contig", "pos_start"] - assert df[record_index].drop_duplicates().shape[0] == 1168430 - - -def test_basic_ingest(tmp_path): - # Create the dataset - uri = os.path.join(tmp_path, "dataset") - ds = tiledbvcf.Dataset(uri, mode="w") - samples = [os.path.join(TESTS_INPUT_DIR, s) for s in ["small.bcf", "small2.bcf"]] - ds.create_dataset() - ds.ingest_samples(samples) - - # Open it back in read mode and check some queries - ds = tiledbvcf.Dataset(uri, mode="r") - assert ds.count() == 14 - assert ds.count(regions=["1:12700-13400"]) == 6 - assert ds.count(samples=["HG00280"], regions=["1:12700-13400"]) == 4 - - -def test_disable_ingestion_tasks(tmp_path): - # Create the dataset - uri = os.path.join(tmp_path, "dataset") - ds = tiledbvcf.Dataset(uri, mode="w") - samples = [os.path.join(TESTS_INPUT_DIR, s) for s in ["small.bcf", "small3.bcf"]] - ds.create_dataset( - enable_allele_count=False, enable_variant_stats=False, enable_sample_stats=False - ) - ds.ingest_samples(samples) - - # TODO: remove this workaround when sc-19721 is resolved - if platform.system() != "Linux": - return - - # Validate that stats arrays were not 
created - ac_uri = os.path.join(tmp_path, "dataset", "allele_count") - vs_uri = os.path.join(tmp_path, "dataset", "variant_stats") - ss_uri = os.path.join(tmp_path, "dataset", "sample_stats") - - assert not os.path.exists(ac_uri) - assert not os.path.exists(vs_uri) - assert not os.path.exists(ss_uri) - - -def test_ingestion_tasks(tmp_path): - # Create the dataset - uri = os.path.join(tmp_path, "dataset") - ds = tiledbvcf.Dataset(uri, mode="w") - samples = [os.path.join(TESTS_INPUT_DIR, s) for s in ["small.bcf", "small3.bcf"]] - ds.create_dataset(enable_allele_count=True, enable_variant_stats=True) - ds.ingest_samples(samples) - - # TODO: remove this workaround when sc-19721 is resolved - if platform.system() != "Linux": - return - - # query allele_count array with TileDB - ac_uri = tiledb.Group(uri)["allele_count"].uri - - check_if_compatible(ac_uri) - - contig = "1" - region = slice(69896) - with tiledb.open(ac_uri) as A: - df = A.query(attrs=["alt", "count"], dims=["pos"]).df[contig, region] - - assert df["pos"].array == 69896 - assert df["alt"].array == "C" - assert df["count"].array == 1 - - # query variant_stats array with TileDB - vs_uri = tiledb.Group(uri)["variant_stats"].uri - - contig = "1" - region = slice(12140) - with tiledb.open(vs_uri) as A: - df = A.query(attrs=["allele", "ac"], dims=["pos"]).df[contig, region] - - assert df["pos"].array == 12140 - assert df["allele"].array == "C" - assert df["ac"].array == 4 - - # Test raw sample_stats - - expected_df = pd.DataFrame( - { - "sample": ["HG00280", "HG01762"], - "dp_sum": [879, 64], - "dp_sum2": [56375, 4096], - "dp_count": [68, 2], - "dp_min": [0, 0], - "dp_max": [180, 64], - "gq_sum": [1489, 99], - "gq_sum2": [79129, 9801], - "gq_count": [68, 2], - "gq_min": [0, 0], - "gq_max": [99, 99], - "n_records": [70, 3], - "n_called": [70, 3], - "n_not_called": [0, 0], - "n_hom_ref": [64, 3], - "n_het": [3, 0], - "n_singleton": [4, 0], - "n_snp": [7, 0], - "n_insertion": [2, 0], - "n_deletion": [1, 0], - 
"n_transition": [6, 0], - "n_transversion": [1, 0], - "n_star": [0, 0], - "n_multiallelic": [5, 0], - } - ).astype("uint64", errors="ignore") - - ss_uri = tiledb.Group(uri)["sample_stats"].uri - with tiledb.open(ss_uri) as A: - df = A.df[:] - - # Convert to uint64 for comparison to expected_df - df = df.astype("uint64", errors="ignore") - - assert df.equals(expected_df) - - # Test sample_qc - expected_qc = pd.DataFrame( - { - "sample": ["HG00280", "HG01762"], - "dp_mean": [12.92647, 32.0], - "dp_stddev": [25.728399, 32.0], - "dp_min": [0, 0], - "dp_max": [180, 64], - "gq_mean": [21.897058, 49.5], - "gq_stddev": [26.156845, 49.5], - "gq_min": [0, 0], - "gq_max": [99, 99], - "call_rate": [1.0, 1.0], - "n_called": [70, 3], - "n_not_called": [0, 0], - "n_hom_ref": [64, 3], - "n_het": [3, 0], - "n_hom_var": [3, 0], - "n_non_ref": [6, 0], - "n_singleton": [4, 0], - "n_snp": [7, 0], - "n_insertion": [2, 0], - "n_deletion": [1, 0], - "n_transition": [6, 0], - "n_transversion": [1, 0], - "n_star": [0, 0], - "r_ti_tv": [6.0, np.nan], - "r_het_hom_var": [1.0, np.nan], - "r_insertion_deletion": [2.0, np.nan], - "n_records": [70, 3], - "n_multiallelic": [5, 0], - } - ) - - qc = tiledbvcf.sample_qc(uri) - _check_dfs(expected_qc, qc) - - -def test_incremental_ingest(tmp_path): - uri = os.path.join(tmp_path, "dataset") - ds = tiledbvcf.Dataset(uri, mode="w") - ds.create_dataset() - ds.ingest_samples([os.path.join(TESTS_INPUT_DIR, "small.bcf")]) - ds.ingest_samples([os.path.join(TESTS_INPUT_DIR, "small2.bcf")]) - - # Open it back in read mode and check some queries - ds = tiledbvcf.Dataset(uri, mode="r") - assert ds.count() == 14 - assert ds.count(regions=["1:12700-13400"]) == 6 - assert ds.count(samples=["HG00280"], regions=["1:12700-13400"]) == 4 - - -def test_ingest_disable_merging(tmp_path): - # Create the dataset - uri = os.path.join(tmp_path, "dataset_disable_merging") - - cfg = tiledbvcf.ReadConfig(memory_budget_mb=1024) - attrs = ["sample_name", "contig", "pos_start", 
"pos_end"] - - ds = tiledbvcf.Dataset(uri, mode="w") - samples = [ - os.path.join(TESTS_INPUT_DIR, s) for s in ["v2-DjrIAzkP-downsampled.vcf.gz"] - ] - ds.create_dataset() - ds.ingest_samples(samples, contig_fragment_merging=False) - - # Open it back in read mode and check some queries - ds = tiledbvcf.Dataset(uri, cfg=cfg, mode="r", verbose=False) - df = ds.read(attrs=attrs) - assert ds.count() == 246 - assert ds.count(regions=["chrX:9032893-9032893"]) == 1 - - # Create the dataset - uri = os.path.join(tmp_path, "dataset_merging_separate") - ds2 = tiledbvcf.Dataset(uri, mode="w", verbose=False) - samples = [ - os.path.join(TESTS_INPUT_DIR, s) for s in ["v2-DjrIAzkP-downsampled.vcf.gz"] - ] - ds2.create_dataset() - ds2.ingest_samples(samples, contigs_to_keep_separate=["chr1"]) - - # Open it back in read mode and check some queries - ds2 = tiledbvcf.Dataset(uri, cfg=cfg, mode="r", verbose=False) - df2 = ds2.read(attrs=attrs) - assert df.equals(df2) - - assert ds.count() == 246 - assert ds.count(regions=["chrX:9032893-9032893"]) == 1 - - -def test_ingest_merging_separate(tmp_path): - # Create the dataset - uri = os.path.join(tmp_path, "dataset_merging_separate") - ds = tiledbvcf.Dataset(uri, mode="w") - samples = [ - os.path.join(TESTS_INPUT_DIR, s) for s in ["v2-DjrIAzkP-downsampled.vcf.gz"] - ] - ds.create_dataset() - ds.ingest_samples(samples, contigs_to_keep_separate=["chr1"]) - - # Open it back in read mode and check some queries - ds = tiledbvcf.Dataset(uri, mode="r") - assert ds.count() == 246 - assert ds.count(regions=["chrX:9032893-9032893"]) == 1 - - -def test_ingest_merging(tmp_path): - # Create the dataset - uri = os.path.join(tmp_path, "dataset_merging") - ds = tiledbvcf.Dataset(uri, mode="w") - samples = [ - os.path.join(TESTS_INPUT_DIR, s) for s in ["v2-DjrIAzkP-downsampled.vcf.gz"] - ] - ds.create_dataset() - ds.ingest_samples(samples, contigs_to_allow_merging=["chr1", "chr2"]) - - # Open it back in read mode and check some queries - ds = 
tiledbvcf.Dataset(uri, mode="r") - assert ds.count() == 246 - assert ds.count(regions=["chrX:9032893-9032893"]) == 1 - - -def test_ingest_mode_merged(tmp_path): - # tiledbvcf.config_logging("debug") - # Create the dataset - uri = os.path.join(tmp_path, "dataset_merging") - ds = tiledbvcf.Dataset(uri, mode="w") - samples = [ - os.path.join(TESTS_INPUT_DIR, s) for s in ["v2-DjrIAzkP-downsampled.vcf.gz"] - ] - ds.create_dataset() - # ingest only merged contigs (pseudo-contigs) - ds.ingest_samples(samples, contig_mode="merged") - - # Open it back in read mode and check some queries - ds = tiledbvcf.Dataset(uri, mode="r") - assert ds.count() == 19 - assert ds.count(regions=["chrX:9032893-9032893"]) == 0 - - -@pytest.fixture -def test_stats_bgzipped_inputs(tmp_path): - tmp_path_contents = os.listdir(tmp_path) - if "stats" in tmp_path_contents: - shutil.rmtree(os.path.join(tmp_path, "stats")) - shutil.copytree( - os.path.join(TESTS_INPUT_DIR, "stats"), os.path.join(tmp_path, "stats") - ) - raw_inputs = glob.glob(os.path.join(tmp_path, "stats", "*.vcf")) - # print(f"raw inputs: {raw_inputs}") - for vcf_file in raw_inputs: - subprocess.run( - "bcftools view --no-version -Oz -o " + vcf_file + ".gz " + vcf_file, - shell=True, - check=True, - ) - bgzipped_inputs = glob.glob(os.path.join(tmp_path, "stats", "*.gz")) - for vcf_file in bgzipped_inputs: - assert subprocess.run("bcftools index " + vcf_file, shell=True).returncode == 0 - if "outputs" in tmp_path_contents: - shutil.rmtree(os.path.join(tmp_path, "outputs")) - if "stats_test" in tmp_path_contents: - shutil.rmtree(os.path.join(tmp_path, "stats_test")) - return bgzipped_inputs - - -@pytest.fixture -def test_stats_sample_names(test_stats_bgzipped_inputs): - assert len(test_stats_bgzipped_inputs) == 8 - return [os.path.basename(file).split(".")[0] for file in test_stats_bgzipped_inputs] - - -@pytest.fixture -def test_stats_v3_ingestion(tmp_path, test_stats_bgzipped_inputs): - assert len(test_stats_bgzipped_inputs) == 8 - # 
print(f"bgzipped inputs: {test_stats_bgzipped_inputs}") - ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="w") - ds.create_dataset( - enable_variant_stats=True, enable_allele_count=True, variant_stats_version=3 - ) - ds.ingest_samples(test_stats_bgzipped_inputs) - ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="r") - return ds - - -# Ok to skip is missing bcftools in Windows CI job -@pytest.mark.skipif( - os.environ.get("CI") == "true" - and platform.system() == "Windows" - and shutil.which("bcftools") is None, - reason="no bcftools", -) -def test_ingest_with_stats_v3( - tmp_path, test_stats_v3_ingestion, test_stats_sample_names -): - data_frame = test_stats_v3_ingestion.read( - samples=test_stats_sample_names, - attrs=["contig", "pos_start", "id", "qual", "info_TILEDB_IAF", "sample_name"], - set_af_filter="<0.2", - ) - assert data_frame.shape == (1, 8) - assert data_frame.query("sample_name == 'second'")["qual"].iloc[0] == pytest.approx( - 343.73 - ) - assert ( - data_frame[data_frame["sample_name"] == "second"]["info_TILEDB_IAF"].iloc[0][0] - == 0.9375 - ) - data_frame = test_stats_v3_ingestion.read( - samples=test_stats_sample_names, - attrs=["contig", "pos_start", "id", "qual", "info_TILEDB_IAF", "sample_name"], - scan_all_samples=True, - ) - assert ( - data_frame[ - (data_frame["sample_name"] == "second") & (data_frame["pos_start"] == 4) - ]["info_TILEDB_IAF"].iloc[0][0] - == 0.9375 - ) - - ###################### - # read_variant_stats # - ###################### - - # test errors - no_parameter_error = '"region" or "regions" parameter is required' - exclusive_parameter_error = ( - '"region" and "regions" parameters are mutually exclusive' - ) - format_error = '"region" parameter must have format ":-"' - empty_contig_error = "Region contig cannot be empty" - base_1_error = "Regions must be 1-based" - interval_error = '"100-1" is not a valid region interval' - with pytest.raises(Exception, match=no_parameter_error): - 
test_stats_v3_ingestion.read_variant_stats() - with pytest.raises(Exception, match=no_parameter_error): - test_stats_v3_ingestion.read_variant_stats_arrow() - with pytest.raises(Exception, match=exclusive_parameter_error): - test_stats_v3_ingestion.read_variant_stats("chr1:1-100", regions=["chr1:1-100"]) - with pytest.raises(Exception, match=exclusive_parameter_error): - test_stats_v3_ingestion.read_variant_stats_arrow( - "chr1:1-100", regions=["chr1:1-100"] - ) - with pytest.raises(Exception, match=format_error): - test_stats_v3_ingestion.read_variant_stats(regions=[""]) - with pytest.raises(Exception, match=format_error): - test_stats_v3_ingestion.read_variant_stats_arrow(regions=[""]) - with pytest.raises(Exception, match=format_error): - test_stats_v3_ingestion.read_variant_stats(regions=["chr1"]) - with pytest.raises(Exception, match=format_error): - test_stats_v3_ingestion.read_variant_stats_arrow(regions=["chr1"]) - with pytest.raises(Exception, match=format_error): - test_stats_v3_ingestion.read_variant_stats(regions=["chr1:-"]) - with pytest.raises(Exception, match=format_error): - test_stats_v3_ingestion.read_variant_stats_arrow(regions=["chr1:-"]) - with pytest.raises(Exception, match=empty_contig_error): - test_stats_v3_ingestion.read_variant_stats(regions=[":1-100"]) - with pytest.raises(Exception, match=empty_contig_error): - test_stats_v3_ingestion.read_variant_stats_arrow(regions=[":1-100"]) - with pytest.raises(Exception, match=base_1_error): - test_stats_v3_ingestion.read_variant_stats(regions=["chr1:0-100"]) - with pytest.raises(Exception, match=base_1_error): - test_stats_v3_ingestion.read_variant_stats_arrow(regions=["chr1:0-100"]) - with pytest.raises(Exception, match=interval_error): - test_stats_v3_ingestion.read_variant_stats(regions=["chr1:100-1"]) - with pytest.raises(Exception, match=interval_error): - test_stats_v3_ingestion.read_variant_stats_arrow(regions=["chr1:100-1"]) - - # test empty region - assert 
test_stats_v3_ingestion.read_variant_stats(regions=["chr3:1-10000"]).empty - - # test types and deprecated region parameter - region1 = "chr1:1-10000" - df = test_stats_v3_ingestion.read_variant_stats(region1) - tbl = test_stats_v3_ingestion.read_variant_stats_arrow(region1) - assert isinstance(df, pd.DataFrame) - assert isinstance(tbl, pa.Table) - assert df.shape == (13, 6) - assert df.equals(tbl.to_pandas()) - df = test_stats_v3_ingestion.read_variant_stats(regions=[region1]) - tbl = test_stats_v3_ingestion.read_variant_stats_arrow(regions=[region1]) - assert isinstance(df, pd.DataFrame) - assert isinstance(tbl, pa.Table) - assert df.shape == (13, 6) - assert df.equals(tbl.to_pandas()) - - # test a region on a different contig - region2 = "chr2:1-10000" - df = test_stats_v3_ingestion.read_variant_stats(regions=[region2]) - tbl = test_stats_v3_ingestion.read_variant_stats_arrow(regions=[region2]) - assert df.shape == (2, 6) - assert df.equals(tbl.to_pandas()) - - # test multiple regions from different contigs and their ordering - regions = [region1, region2] - contigs = ["chr1"] * 13 + ["chr2"] * 2 - df = test_stats_v3_ingestion.read_variant_stats(regions=regions) - assert df.shape == (15, 6) - assert contigs == list(df["contig"].values) - df2 = test_stats_v3_ingestion.read_variant_stats(regions=reversed(regions)) - assert df.equals(df2) - tbl = test_stats_v3_ingestion.read_variant_stats_arrow(regions=regions) - tbl2 = test_stats_v3_ingestion.read_variant_stats_arrow(regions=reversed(regions)) - assert tbl.equals(tbl2) - assert df.equals(tbl.to_pandas()) - assert df2.equals(tbl2.to_pandas()) - - # test overlapping regions on different contigs and their order - region1 = "chr1:1-1" - df = test_stats_v3_ingestion.read_variant_stats(regions=[region1]) - assert df.shape == (2, 6) - region2 = "chr1:1-2" - df = test_stats_v3_ingestion.read_variant_stats(regions=[region2]) - assert df.shape == (5, 6) - region3 = "chr1:3-4" - df = 
test_stats_v3_ingestion.read_variant_stats(regions=[region3]) - assert df.shape == (6, 6) - region4 = "chr1:2-5" - df = test_stats_v3_ingestion.read_variant_stats(regions=[region4]) - assert df.shape == (11, 6) - regions_chr1 = [region1, region2, region3, region4] - df = test_stats_v3_ingestion.read_variant_stats(regions=regions_chr1) - df2 = test_stats_v3_ingestion.read_variant_stats(regions=reversed(regions_chr1)) - assert df.shape == (13, 6) - assert df.equals(df2) - region5 = "chr2:1-1" - df = test_stats_v3_ingestion.read_variant_stats(regions=[region5]) - assert df.shape == (1, 6) - region6 = "chr2:3-3" - df = test_stats_v3_ingestion.read_variant_stats(regions=[region6]) - assert df.shape == (1, 6) - regions_chr2 = [region5, region6] - df = test_stats_v3_ingestion.read_variant_stats(regions=regions_chr2) - df2 = test_stats_v3_ingestion.read_variant_stats(regions=reversed(regions_chr2)) - assert df.shape == (2, 6) - assert df.equals(df2) - regions = regions_chr1 + regions_chr2 - df = test_stats_v3_ingestion.read_variant_stats(regions=regions) - df2 = test_stats_v3_ingestion.read_variant_stats(regions=reversed(regions)) - assert df.shape == (15, 6) - assert contigs == list(df["contig"].values) - assert df.equals(df2) - regions = regions_chr2 + regions_chr1 - df = test_stats_v3_ingestion.read_variant_stats(regions=regions) - df2 = test_stats_v3_ingestion.read_variant_stats(regions=reversed(regions)) - assert df.shape == (15, 6) - assert contigs == list(df["contig"].values) - assert df.equals(df2) - - # test scan_all_samples - ac = [8, 8, 5, 6, 5, 4, 4, 4, 4, 1, 15, 1, 2, 2, 2] - an = [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 3, 3, 2, 2] - af = [ - 0.5, - 0.5, - 0.3125, - 0.375, - 0.3125, - 0.25, - 0.25, - 0.25, - 0.25, - 0.0625, - 0.9375, - 0.33333334, - 0.6666667, - 1.0, - 1.0, - ] - df = test_stats_v3_ingestion.read_variant_stats(regions=regions) - assert ac == list(df["ac"].values) - assert an == list(df["an"].values) - assert af == list(df["af"].values) - 
ac = [8, 8, 5, 6, 5, 4, 4, 4, 4, 1, 15, 1, 2, 2, 2] - an = [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16] - af = [ - 0.5, - 0.5, - 0.3125, - 0.375, - 0.3125, - 0.25, - 0.25, - 0.25, - 0.25, - 0.0625, - 0.9375, - 0.0625, - 0.125, - 0.125, - 0.125, - ] - df = test_stats_v3_ingestion.read_variant_stats( - regions=regions, - scan_all_samples=True, - ) - assert ac == list(df["ac"].values) - assert an == list(df["an"].values) - assert af == list(df["af"].values) - - # test drop_ref - alleles = [ - "T,C", - "ref", - "G,GTTTA", - "G,T", - "ref", - "C,A", - "C,G", - "C,T", - "ref", - "G,GTTTA", - "ref", - "C,T", - "ref", - "G,GTTTA", - "G,GTTTA", - ] - df = test_stats_v3_ingestion.read_variant_stats(regions=regions) - assert alleles == list(df["alleles"].values) - alleles = [ - "T,C", - "G,GTTTA", - "G,T", - "C,A", - "C,G", - "C,T", - "G,GTTTA", - "C,T", - "G,GTTTA", - "G,GTTTA", - ] - df = test_stats_v3_ingestion.read_variant_stats( - regions=regions, - drop_ref=True, - ) - assert alleles == list(df["alleles"].values) - - ###################### - # read_allele_count # - ###################### - - # test errors - with pytest.raises(Exception, match=no_parameter_error): - test_stats_v3_ingestion.read_allele_count() - with pytest.raises(Exception, match=no_parameter_error): - test_stats_v3_ingestion.read_allele_count_arrow() - with pytest.raises(Exception, match=exclusive_parameter_error): - test_stats_v3_ingestion.read_allele_count("chr1:1-100", regions=["chr1:1-100"]) - with pytest.raises(Exception, match=exclusive_parameter_error): - test_stats_v3_ingestion.read_allele_count_arrow( - "chr1:1-100", regions=["chr1:1-100"] - ) - with pytest.raises(Exception, match=format_error): - test_stats_v3_ingestion.read_allele_count(regions=[""]) - with pytest.raises(Exception, match=format_error): - test_stats_v3_ingestion.read_allele_count_arrow(regions=[""]) - with pytest.raises(Exception, match=format_error): - 
test_stats_v3_ingestion.read_allele_count(regions=["chr1"]) - with pytest.raises(Exception, match=format_error): - test_stats_v3_ingestion.read_allele_count_arrow(regions=["chr1"]) - with pytest.raises(Exception, match=format_error): - test_stats_v3_ingestion.read_allele_count(regions=["chr1:-"]) - with pytest.raises(Exception, match=format_error): - test_stats_v3_ingestion.read_allele_count_arrow(regions=["chr1:-"]) - with pytest.raises(Exception, match=empty_contig_error): - test_stats_v3_ingestion.read_allele_count(regions=[":1-100"]) - with pytest.raises(Exception, match=empty_contig_error): - test_stats_v3_ingestion.read_allele_count_arrow(regions=[":1-100"]) - with pytest.raises(Exception, match=base_1_error): - test_stats_v3_ingestion.read_allele_count(regions=["chr1:0-100"]) - with pytest.raises(Exception, match=base_1_error): - test_stats_v3_ingestion.read_allele_count_arrow(regions=["chr1:0-100"]) - with pytest.raises(Exception, match=interval_error): - test_stats_v3_ingestion.read_allele_count(regions=["chr1:100-1"]) - with pytest.raises(Exception, match=interval_error): - test_stats_v3_ingestion.read_allele_count_arrow(regions=["chr1:100-1"]) - - # test empty region - assert test_stats_v3_ingestion.read_allele_count(regions=["chr3:1-10000"]).empty - - # test types and deprecated region parameter - region1 = "chr1:1-10000" - pos = (0, 1, 1, 2, 2, 2, 3) - count = (8, 5, 3, 4, 2, 2, 1) - df = test_stats_v3_ingestion.read_allele_count(region1) - tbl = test_stats_v3_ingestion.read_allele_count_arrow(region1) - assert isinstance(df, pd.DataFrame) - assert isinstance(tbl, pa.Table) - assert df.shape == (7, 7) - assert df.equals(tbl.to_pandas()) - assert sum(df["pos"] == pos) == 7 - assert sum(df["count"] == count) == 7 - df = test_stats_v3_ingestion.read_allele_count(regions=[region1]) - tbl = test_stats_v3_ingestion.read_allele_count_arrow(regions=[region1]) - assert isinstance(df, pd.DataFrame) - assert isinstance(tbl, pa.Table) - assert df.shape == (7, 7) - 
assert df.equals(tbl.to_pandas()) - assert sum(df["pos"] == pos) == 7 - assert sum(df["count"] == count) == 7 - - # test a region on a different contig - region2 = "chr2:1-10000" - df = test_stats_v3_ingestion.read_allele_count(regions=[region2]) - tbl = test_stats_v3_ingestion.read_allele_count_arrow(regions=[region2]) - assert df.shape == (2, 7) - assert df.equals(tbl.to_pandas()) - - # test multiple regions from different contigs and their ordering - regions = [region1, region2] - contigs = ["chr1"] * 7 + ["chr2"] * 2 - df = test_stats_v3_ingestion.read_allele_count(regions=regions) - assert df.shape == (9, 7) - assert contigs == list(df["contig"].values) - df2 = test_stats_v3_ingestion.read_allele_count(regions=reversed(regions)) - assert df.equals(df2) - tbl = test_stats_v3_ingestion.read_allele_count_arrow(regions=regions) - tbl2 = test_stats_v3_ingestion.read_allele_count_arrow(regions=reversed(regions)) - assert tbl.equals(tbl2) - assert df.equals(tbl.to_pandas()) - assert df2.equals(tbl2.to_pandas()) - - # test overlapping regions on different contigs and their order - region1 = "chr1:1-1" - df = test_stats_v3_ingestion.read_allele_count(regions=[region1]) - assert df.shape == (1, 7) - region2 = "chr1:1-2" - df = test_stats_v3_ingestion.read_allele_count(regions=[region2]) - assert df.shape == (3, 7) - region3 = "chr1:3-4" - df = test_stats_v3_ingestion.read_allele_count(regions=[region3]) - assert df.shape == (4, 7) - region4 = "chr1:2-5" - df = test_stats_v3_ingestion.read_allele_count(regions=[region4]) - assert df.shape == (6, 7) - regions_chr1 = [region1, region2, region3, region4] - df = test_stats_v3_ingestion.read_allele_count(regions=regions_chr1) - df2 = test_stats_v3_ingestion.read_allele_count(regions=reversed(regions_chr1)) - assert df.shape == (7, 7) - assert df.equals(df2) - region5 = "chr2:1-1" - df = test_stats_v3_ingestion.read_allele_count(regions=[region5]) - assert df.shape == (1, 7) - region6 = "chr2:3-3" - df = 
test_stats_v3_ingestion.read_allele_count(regions=[region6]) - assert df.shape == (1, 7) - regions_chr2 = [region5, region6] - df = test_stats_v3_ingestion.read_allele_count(regions=regions_chr2) - df2 = test_stats_v3_ingestion.read_allele_count(regions=reversed(regions_chr2)) - assert df.shape == (2, 7) - assert df.equals(df2) - regions = regions_chr1 + regions_chr2 - df = test_stats_v3_ingestion.read_allele_count(regions=regions) - df2 = test_stats_v3_ingestion.read_allele_count(regions=reversed(regions)) - assert df.shape == (9, 7) - assert contigs == list(df["contig"].values) - assert df.equals(df2) - regions = regions_chr2 + regions_chr1 - df = test_stats_v3_ingestion.read_allele_count(regions=regions) - df2 = test_stats_v3_ingestion.read_allele_count(regions=reversed(regions)) - assert df.shape == (9, 7) - assert contigs == list(df["contig"].values) - assert df.equals(df2) - - ######################### - # read_allele_frequency # - ######################### - - region = "chr1:1-10000" - df = tiledbvcf.allele_frequency.read_allele_frequency( - os.path.join(tmp_path, "stats_test"), region - ) - assert df.pos.is_monotonic_increasing - df["an_check"] = (df.ac / df.af).round(0).astype("int32") - assert df.an_check.equals(df.an) - df = test_stats_v3_ingestion.read_variant_stats(region) - assert df.shape == (13, 6) - - -@pytest.mark.skipif( - os.environ.get("CI") == "true" - and platform.system() == "Windows" - and shutil.which("bcftools") is None, - reason="no bcftools", -) -def test_delete_samples(tmp_path, test_stats_v3_ingestion, test_stats_sample_names): - # assert test_stats_v3_ingestion.samples() == test_stats_sample_names - assert "second" in test_stats_sample_names - assert "fifth" in test_stats_sample_names - assert "third" in test_stats_sample_names - ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="w") - # tiledbvcf.config_logging("trace") - ds.delete_samples(["second", "fifth"]) - ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, 
"stats_test"), mode="r") - sample_names = ds.samples() - assert "second" not in sample_names - assert "fifth" not in sample_names - assert "third" in sample_names - - -# Ok to skip is missing bcftools in Windows CI job -@pytest.mark.skipif( - os.environ.get("CI") == "true" - and platform.system() == "Windows" - and shutil.which("bcftools") is None, - reason="no bcftools", -) -def test_ingest_with_stats_v2(tmp_path): - # tiledbvcf.config_logging("debug") - tmp_path_contents = os.listdir(tmp_path) - if "stats" in tmp_path_contents: - shutil.rmtree(os.path.join(tmp_path, "stats")) - shutil.copytree( - os.path.join(TESTS_INPUT_DIR, "stats"), os.path.join(tmp_path, "stats") - ) - raw_inputs = glob.glob(os.path.join(tmp_path, "stats", "*.vcf")) - # print(f"raw inputs: {raw_inputs}") - for vcf_file in raw_inputs: - subprocess.run( - "bcftools view --no-version -Oz -o " + vcf_file + ".gz " + vcf_file, - shell=True, - check=True, - ) - bgzipped_inputs = glob.glob(os.path.join(tmp_path, "stats", "*.gz")) - # print(f"bgzipped inputs: {bgzipped_inputs}") - for vcf_file in bgzipped_inputs: - assert subprocess.run("bcftools index " + vcf_file, shell=True).returncode == 0 - if "outputs" in tmp_path_contents: - shutil.rmtree(os.path.join(tmp_path, "outputs")) - if "stats_test" in tmp_path_contents: - shutil.rmtree(os.path.join(tmp_path, "stats_test")) - # tiledbvcf.config_logging("trace") - ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="w") - ds.create_dataset(enable_variant_stats=True, enable_allele_count=True) - ds.ingest_samples(bgzipped_inputs) - ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="r") - sample_names = [os.path.basename(file).split(".")[0] for file in bgzipped_inputs] - data_frame = ds.read( - samples=sample_names, - attrs=["contig", "pos_start", "id", "qual", "info_TILEDB_IAF", "sample_name"], - set_af_filter="<0.2", - ) - assert data_frame.shape == (1, 8) - assert data_frame.query("sample_name == 
'second'")["qual"].iloc[0] == pytest.approx( - 343.73 - ) - assert ( - data_frame[data_frame["sample_name"] == "second"]["info_TILEDB_IAF"].iloc[0][0] - == 0.9375 - ) - data_frame = ds.read( - samples=sample_names, - attrs=["contig", "pos_start", "id", "qual", "info_TILEDB_IAF", "sample_name"], - scan_all_samples=True, - ) - assert ( - data_frame[ - (data_frame["sample_name"] == "second") & (data_frame["pos_start"] == 4) - ]["info_TILEDB_IAF"].iloc[0][0] - == 0.9375 - ) - ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="r") - df = ds.read_variant_stats("chr1:1-10000") - assert df.shape == (13, 6) - df = tiledbvcf.allele_frequency.read_allele_frequency( - os.path.join(tmp_path, "stats_test"), "chr1:1-10000" - ) - assert df.pos.is_monotonic_increasing - df["an_check"] = (df.ac / df.af).round(0).astype("int32") - assert df.an_check.equals(df.an) - df = ds.read_variant_stats("chr1:1-10000") - assert df.shape == (13, 6) - df = ds.read_allele_count("chr1:1-10000") - assert df.shape == (7, 7) - assert sum(df["pos"] == (0, 1, 1, 2, 2, 2, 3)) == 7 - assert sum(df["count"] == (8, 5, 3, 4, 2, 2, 1)) == 7 - - -# Ok to skip is missing bcftools in Windows CI job -@pytest.mark.skipif( - os.environ.get("CI") == "true" - and platform.system() == "Windows" - and shutil.which("bcftools") is None, - reason="no bcftools", -) -def test_ingest_polyploid(tmp_path): - tmp_path_contents = os.listdir(tmp_path) - if "polyploid" in tmp_path_contents: - shutil.rmtree(os.path.join(tmp_path, "polyploid")) - shutil.copytree( - os.path.join(TESTS_INPUT_DIR, "polyploid"), os.path.join(tmp_path, "polyploid") - ) - raw_inputs = glob.glob(os.path.join(tmp_path, "polyploid", "*.vcf")) - # print(f"raw inputs: {raw_inputs}") - for vcf_file in raw_inputs: - subprocess.run( - "bcftools view --no-version -Oz -o " + vcf_file + ".gz " + vcf_file, - shell=True, - check=True, - ) - bgzipped_inputs = glob.glob(os.path.join(tmp_path, "polyploid", "*.gz")) - # print(f"bgzipped inputs: 
{bgzipped_inputs}") - for vcf_file in bgzipped_inputs: - assert subprocess.run("bcftools index " + vcf_file, shell=True).returncode == 0 - if "polyploid" in tmp_path_contents: - shutil.rmtree(os.path.join(tmp_path, "polyploid")) - if "outputs" in tmp_path_contents: - shutil.rmtree(os.path.join(tmp_path, "outputs")) - if "polyploid_test" in tmp_path_contents: - shutil.rmtree(os.path.join(tmp_path, "polyploid_test")) - # tiledbvcf.config_logging("trace") - ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "polyploid_test"), mode="w") - ds.create_dataset(enable_variant_stats=True) - ds.ingest_samples(bgzipped_inputs) - ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "polyploid_test"), mode="r") - sample_names = [os.path.basename(file).split(".")[0] for file in bgzipped_inputs] - data_frame = ds.read( - samples=sample_names, - attrs=["contig", "pos_start", "id", "qual", "info_TILEDB_IAF", "sample_name"], - set_af_filter="<0.8", - ) - # print(data_frame) - - -def test_ingest_mode_separate(tmp_path): - # tiledbvcf.config_logging("debug") - # Create the dataset - uri = os.path.join(tmp_path, "dataset_merging") - ds = tiledbvcf.Dataset(uri, mode="w") - samples = [ - os.path.join(TESTS_INPUT_DIR, s) for s in ["v2-DjrIAzkP-downsampled.vcf.gz"] - ] - ds.create_dataset() - # ingest only merged contigs (pseudo-contigs) - ds.ingest_samples( - samples, contigs_to_keep_separate=["chr1"], contig_mode="separate" - ) - - # Open it back in read mode and check some queries - ds = tiledbvcf.Dataset(uri, mode="r") - assert ds.count() == 17 - assert ds.count(regions=["chrX:9032893-9032893"]) == 0 - - -def test_vcf_attrs(tmp_path): - # Create the dataset with vcf info and fmt attributes - uri = os.path.join(tmp_path, "vcf_attrs_dataset") - ds = tiledbvcf.Dataset(uri, mode="w") - vcf_uri = os.path.join(TESTS_INPUT_DIR, "v2-DjrIAzkP-downsampled.vcf.gz") - ds.create_dataset(vcf_attrs=vcf_uri) - - # Open it back in read mode and check attributes - ds = tiledbvcf.Dataset(uri, mode="r") - - 
queryable_attrs = [ - "alleles", - "contig", - "filters", - "fmt", - "fmt_DP", - "fmt_GQ", - "fmt_GT", - "fmt_MIN_DP", - "fmt_PS", - "fmt_SB", - "fmt_STR_MAX_LEN", - "fmt_STR_PERIOD", - "fmt_STR_TIMES", - "fmt_VAR_CONTEXT", - "fmt_VAR_TYPE", - "id", - "info", - "info_AC", - "info_AC_AFR", - "info_AC_AMR", - "info_AC_Adj", - "info_AC_CONSANGUINEOUS", - "info_AC_EAS", - "info_AC_FEMALE", - "info_AC_FIN", - "info_AC_Hemi", - "info_AC_Het", - "info_AC_Hom", - "info_AC_MALE", - "info_AC_NFE", - "info_AC_OTH", - "info_AC_POPMAX", - "info_AC_SAS", - "info_AF", - "info_AF_AFR", - "info_AF_AMR", - "info_AF_Adj", - "info_AF_EAS", - "info_AF_FIN", - "info_AF_NFE", - "info_AF_OTH", - "info_AF_SAS", - "info_AGE_HISTOGRAM_HET", - "info_AGE_HISTOGRAM_HOM", - "info_AN", - "info_AN_AFR", - "info_AN_AMR", - "info_AN_Adj", - "info_AN_CONSANGUINEOUS", - "info_AN_EAS", - "info_AN_FEMALE", - "info_AN_FIN", - "info_AN_MALE", - "info_AN_NFE", - "info_AN_OTH", - "info_AN_POPMAX", - "info_AN_SAS", - "info_BaseQRankSum", - "info_CCC", - "info_CSQ", - "info_ClippingRankSum", - "info_DB", - "info_DOUBLETON_DIST", - "info_DP", - "info_DP_HIST", - "info_DS", - "info_END", - "info_ESP_AC", - "info_ESP_AF_GLOBAL", - "info_ESP_AF_POPMAX", - "info_FS", - "info_GQ_HIST", - "info_GQ_MEAN", - "info_GQ_STDDEV", - "info_HWP", - "info_HaplotypeScore", - "info_Hemi_AFR", - "info_Hemi_AMR", - "info_Hemi_EAS", - "info_Hemi_FIN", - "info_Hemi_NFE", - "info_Hemi_OTH", - "info_Hemi_SAS", - "info_Het_AFR", - "info_Het_AMR", - "info_Het_EAS", - "info_Het_FIN", - "info_Het_NFE", - "info_Het_OTH", - "info_Het_SAS", - "info_Hom_AFR", - "info_Hom_AMR", - "info_Hom_CONSANGUINEOUS", - "info_Hom_EAS", - "info_Hom_FIN", - "info_Hom_NFE", - "info_Hom_OTH", - "info_Hom_SAS", - "info_InbreedingCoeff", - "info_K1_RUN", - "info_K2_RUN", - "info_K3_RUN", - "info_KG_AC", - "info_KG_AF_GLOBAL", - "info_KG_AF_POPMAX", - "info_MLEAC", - "info_MLEAF", - "info_MQ", - "info_MQ0", - "info_MQRankSum", - "info_NCC", - 
"info_NEGATIVE_TRAIN_SITE", - "info_OLD_VARIANT", - "info_POPMAX", - "info_POSITIVE_TRAIN_SITE", - "info_QD", - "info_ReadPosRankSum", - "info_VQSLOD", - "info_clinvar_conflicted", - "info_clinvar_measureset_id", - "info_clinvar_mut", - "info_clinvar_pathogenic", - "info_culprit", - "pos_end", - "pos_start", - "qual", - "query_bed_end", - "query_bed_line", - "query_bed_start", - "sample_name", - ] - - assert ds.attributes(attr_type="info") == [] - assert ds.attributes(attr_type="fmt") == [] - assert sorted(ds.attributes()) == sorted(queryable_attrs) - - -@pytest.mark.parametrize("compress", [True, False]) -def test_sample_compression(tmp_path, compress): - # Create the dataset - dataset_uri = os.path.join(tmp_path, "sample_compression") - array_uri = os.path.join(dataset_uri, "data") - ds = tiledbvcf.Dataset(dataset_uri, mode="w") - ds.create_dataset(compress_sample_dim=compress) - - check_if_compatible(array_uri) - - # Check for the presence of the Zstd filter - found_zstd = False - with tiledb.open(array_uri) as A: - for filter in A.domain.dim("sample").filters: - found_zstd = found_zstd or "Zstd" in str(filter) - - assert found_zstd == compress - - -@pytest.mark.parametrize("level", [1, 4, 16, 22]) -def test_compression_level(tmp_path, level): - # Create the dataset - dataset_uri = os.path.join(tmp_path, "compression_level") - array_uri = os.path.join(dataset_uri, "data") - ds = tiledbvcf.Dataset(dataset_uri, mode="w") - ds.create_dataset(compression_level=level) - - check_if_compatible(array_uri) - - # Check for the expected compression level - with tiledb.open(array_uri) as A: - for i in range(A.schema.nattr): - attr = A.schema.attr(i) - for filter in attr.filters: - if "Zstd" in str(filter): - assert filter.level == level - - -# Ok to skip is missing bcftools in Windows CI job -@pytest.mark.skipif( - os.environ.get("CI") == "true" - and platform.system() == "Windows" - and shutil.which("bcftools") is None, - reason="no bcftools", -) -def 
def test_gvcf_export(tmp_path):
    """Ingest gVCF inputs and verify the sample sets returned by region queries,
    with and without IAF (internal allele frequency) reporting and filtering.

    Requires the ``bcftools`` executable on PATH to compress and index the
    plain-text ``.vcf`` inputs before ingestion.
    """
    # Compress the input VCFs. List-form subprocess calls (shell=False) stay
    # correct even if tmp_path contains spaces or shell metacharacters.
    vcf_inputs = glob.glob(os.path.join(TESTS_INPUT_DIR, "gvcf-export", "*.vcf"))
    assert vcf_inputs, "no gvcf-export inputs found; test would silently pass"
    for vcf_input in vcf_inputs:
        vcf_output = os.path.join(tmp_path, os.path.basename(vcf_input)) + ".gz"
        subprocess.run(
            ["bcftools", "view", "--no-version", "-Oz", "-o", vcf_output, vcf_input],
            check=True,
        )

    # Index the compressed VCFs
    vcf_files = glob.glob(os.path.join(tmp_path, "*.gz"))
    for vcf_file in vcf_files:
        subprocess.run(["bcftools", "index", vcf_file], check=True)

    # Ingest the VCFs
    uri = os.path.join(tmp_path, "vcf.tdb")
    ds = tiledbvcf.Dataset(uri=uri, mode="w")
    ds.create_dataset()
    ds.ingest_samples(vcf_files)
    ds = tiledbvcf.Dataset(uri=uri, mode="r")

    # Each case maps a query region to the samples expected to overlap it.
    tests = [
        {"region": "chr1:100-120", "samples": ["s0", "s1", "s2"]},
        {"region": "chr1:110-120", "samples": ["s0", "s1"]},
        {"region": "chr1:149-149", "samples": ["s0", "s1", "s3"]},
        {"region": "chr1:150-150", "samples": ["s0", "s1", "s3", "s4"]},
    ]

    # No IAF filtering or reporting
    for test in tests:
        df = ds.read(regions=test["region"])
        assert set(df["sample_name"].unique()) == set(test["samples"])

    attrs = [
        "sample_name",
        "contig",
        "pos_start",
        "alleles",
        "fmt_GT",
        "info_TILEDB_IAF",
    ]

    # IAF reporting
    for test in tests:
        df = ds.read(attrs=attrs, regions=test["region"])
        assert set(df["sample_name"].unique()) == set(test["samples"])

    # IAF filtering and reporting; "<=1.0" admits every record, so the
    # sample sets must match the unfiltered results.
    for test in tests:
        df = ds.read(attrs=attrs, regions=test["region"], set_af_filter="<=1.0")
        assert set(df["sample_name"].unique()) == set(test["samples"])


def test_flag_export(tmp_path):
    """Check that INFO flag attributes (info_DB, info_DS) export as 0/1 values
    matching the flags present in small.vcf.gz."""
    # Create the dataset
    uri = os.path.join(tmp_path, "dataset")
    ds = tiledbvcf.Dataset(uri, mode="w")
    samples = [os.path.join(TESTS_INPUT_DIR, s) for s in ["small.vcf.gz"]]
    ds.create_dataset()
    ds.ingest_samples(samples)

    # Read info flags
    ds = tiledbvcf.Dataset(uri, mode="r")
    df = ds.read(attrs=["pos_start", "info_DB", "info_DS"])
    # Sort because the reader does not guarantee record order.
    df = df.sort_values(by=["pos_start"])

    # Check if flags match the expected values
    expected_db = [1, 1, 1, 0, 0, 1]
    assert df["info_DB"].tolist() == expected_db

    expected_ds = [1, 1, 0, 0, 1, 1]
    assert df["info_DS"].tolist() == expected_ds


def test_bed_filestore(tmp_path, test_ds_v4):
    """Query regions supplied as a BED file stored in a TileDB Filestore array,
    exercising both the pandas and Arrow read paths."""
    # tiledbvcf.config_logging("debug")

    expected_df = pd.DataFrame(
        {
            "sample_name": pd.Series(
                ["HG00280", "HG01762", "HG00280", "HG01762", "HG00280"]
            ),
            "pos_start": pd.Series(
                [12141, 12141, 12546, 12546, 17319], dtype=np.int32
            ),
            "pos_end": pd.Series(
                [12277, 12277, 12771, 12771, 17479], dtype=np.int32
            ),
        }
    ).sort_values(ignore_index=True, by=["sample_name", "pos_start"])

    # Create BED file
    bed_file = os.path.join(tmp_path, "test.bed")

    regions = [
        (1, 12000, 13000),
        (1, 17000, 17479),
    ]

    with open(bed_file, "w") as f:
        for region in regions:
            f.write(f"{region[0]}\t{region[1]}\t{region[2]}\n")

    # Create BED filestore from BED file
    bed_filestore = os.path.join(tmp_path, "test.bed.filestore")
    tiledb.Array.create(bed_filestore, tiledb.ArraySchema.from_file(bed_file))
    tiledb.Filestore.copy_from(bed_filestore, bed_file)

    # Read with the filestore as the bed_file argument, via both read paths.
    for use_arrow in [False, True]:
        func = test_ds_v4.read_arrow if use_arrow else test_ds_v4.read

        df = func(attrs=["sample_name", "pos_start", "pos_end"], bed_file=bed_filestore)
        if use_arrow:
            df = df.to_pandas()

        # print(df)

        _check_dfs(
            expected_df,
            df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]),
        )


def test_bed_array(tmp_path, test_ds_v4):
    """Query regions supplied as a sparse TileDB array whose metadata aliases
    map contig/start/end to the array's chrom/chromStart/chromEnd columns."""
    expected_df = pd.DataFrame(
        {
            "sample_name": pd.Series(
                ["HG00280", "HG01762", "HG00280", "HG01762", "HG00280"]
            ),
            "pos_start": pd.Series(
                [12141, 12141, 12546, 12546, 17319], dtype=np.int32
            ),
            "pos_end": pd.Series(
                [12277, 12277, 12771, 12771, 17479], dtype=np.int32
            ),
        }
    ).sort_values(ignore_index=True, by=["sample_name", "pos_start"])

    # Create bed array
    bed_array = os.path.join(tmp_path, "bed_array")
    tiledb.from_pandas(
        bed_array,
        pd.DataFrame(
            {
                "chrom": ["1", "1"],
                "chromStart": [12000, 17000],
                "chromEnd": [13000, 17479],
            }
        ),
        sparse=True,
        index_col=["chrom", "chromStart"],
    )

    # Add aliases to the array metadata so tiledbvcf can find the BED columns.
    with tiledb.Array(bed_array, "w") as A:
        A.meta["alias contig"] = "chrom"
        A.meta["alias start"] = "chromStart"
        A.meta["alias end"] = "chromEnd"

    # Read with the array as the bed_file argument, via both read paths.
    for use_arrow in [False, True]:
        func = test_ds_v4.read_arrow if use_arrow else test_ds_v4.read

        df = func(attrs=["sample_name", "pos_start", "pos_end"], bed_file=bed_array)
        if use_arrow:
            df = df.to_pandas()

        _check_dfs(
            expected_df,
            df.sort_values(ignore_index=True, by=["sample_name", "pos_start"]),
        )


def test_info_end(tmp_path):
    """
    This test checks that the info_END attribute is handled correctly, even when the
    VCF header incorrectly defines the END attribute as a string.

    The test also checks that info_END contains the original values from the VCF,
    including the missing values.
    """

    expected_end = pd.DataFrame(
        {
            "pos_end": pd.Series(
                [
                    12277, 12771, 13374, 13395, 13413, 13451, 13519, 13544, 13689, 17479,
                    17486, 30553, 35224, 35531, 35786, 69096, 69103, 69104, 69109, 69110,
                    69111, 69112, 69114, 69115, 69122, 69123, 69128, 69129, 69130, 69192,
                    69195, 69196, 69215, 69222, 69227, 69228, 69261, 69262, 69269, 69270,
                    69346, 69349, 69352, 69353, 69370, 69510, 69511, 69760, 69761, 69770,
                    69834, 69835, 69838, 69861, 69863, 69866, 69896, 69897, 69912, 69938,
                    69939, 69941, 69946, 69947, 69948, 69949, 69953, 70012, 866511, 1289369,
                ],
                dtype=np.int32,
            ),
            # Expected values are strings because the small3.vcf.gz defines END as a string
            "info_END": pd.Series(
                [
                    "12277", "12771", "13374", "13395", "13413", "13451", "13519", "13544", "13689", "17479",
                    "17486", "30553", "35224", "35531", "35786", "69096", "69103", "69104", "69109", "69110",
                    "69111", "69112", "69114", "69115", "69122", "69123", "69128", "69129", "69130", "69192",
                    "69195", "69196", "69215", "69222", "69227", "69228", "69261", "69262", "69269", None,
                    "69346", "69349", "69352", "69353", "69370", "69510", None, "69760", None, "69770",
                    "69834", "69835", "69838", "69861", "69863", "69866", "69896", None, "69912", "69938",
                    "69939", "69941", "69946", "69947", "69948", "69949", "69953", "70012", None, None,
                ],
                dtype=object,
            ),
        }
    )

    # Ingest the data
    uri = os.path.join(tmp_path, "dataset")
    ds = tiledbvcf.Dataset(uri, mode="w")
    samples = [os.path.join(TESTS_INPUT_DIR, s) for s in ["small3.vcf.gz"]]
    ds.create_dataset()
    ds.ingest_samples(samples)

    # Read the data
    ds = tiledbvcf.Dataset(uri)
    df = ds.read(attrs=["sample_name", "pos_start", "pos_end", "info_END"])

    # Sort the results because VCF uses an unordered reader
    df.sort_values(ignore_index=True, by=["sample_name", "pos_start"], inplace=True)

    # Drop the columns that are not used for comparison
    df.drop(columns=["sample_name", "pos_start"], inplace=True)

    # Check the results (expected first, matching every other _check_dfs call site)
    _check_dfs(expected_end, df)


def test_context_manager():
    """Verify Dataset works as a context manager and that a closed dataset
    raises on further use while other open datasets remain usable."""
    ds1_uri = os.path.join(TESTS_INPUT_DIR, "arrays/v4/ingested_2samples")
    expected_count1 = 14
    ds2_uri = os.path.join(TESTS_INPUT_DIR, "arrays/v3/synth-array")
    expected_count2 = 19565

    # Test the context manager
    with tiledbvcf.Dataset(ds1_uri) as ds:
        assert ds.count() == expected_count1

    with tiledbvcf.Dataset(ds2_uri) as ds:
        assert ds.count() == expected_count2

    # Open the datasets outside the context manager
    ds1 = tiledbvcf.Dataset(ds1_uri)
    assert ds1.count() == expected_count1

    ds2 = tiledbvcf.Dataset(ds2_uri)
    assert ds2.count() == expected_count2

    # Check that an exception is raised when trying to access a closed dataset
    ds1.close()
    with pytest.raises(Exception):
        assert ds1.count() == expected_count1

    # Closing ds1 must not affect ds2.
    assert ds2.count() == expected_count2

    ds2.close()
    with pytest.raises(Exception):
        assert ds2.count() == expected_count2


def test_delete_dataset(tmp_path):
    """Verify Dataset.delete removes the on-disk dataset."""
    uri = os.path.join(tmp_path, "delete_dataset")

    with tiledbvcf.Dataset(uri, mode="w") as ds:
        ds.create_dataset()

    # Check that the dataset exists
    assert os.path.exists(uri)

    # Delete the dataset
    tiledbvcf.Dataset.delete(uri)

    # Check that the dataset does not exist
    assert not os.path.exists(uri)


def test_equality_old_new_format():
    """Verify the old- and new-format test arrays contain identical data."""
    old_ds = tiledbvcf.Dataset(os.path.join(TESTS_INPUT_DIR, "arrays/old_format"))
    new_ds = tiledbvcf.Dataset(os.path.join(TESTS_INPUT_DIR, "arrays/new_format"))

    assert old_ds.count() == new_ds.count()
    assert old_ds.samples() == new_ds.samples()
    assert old_ds.read().equals(new_ds.read())