Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 44 additions & 8 deletions apis/python/src/tiledbvcf/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import namedtuple
from typing import Generator, List

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
Expand Down Expand Up @@ -279,7 +280,7 @@ def read_arrow(
self,
attrs: List[str] = DEFAULT_ATTRS,
samples: (str, List[str]) = None,
regions: (str, List[str]) = None,
regions: (str, List[str], np.ndarray) = None,
samples_file: str = None,
bed_file: str = None,
skip_check_samples: bool = False,
Expand Down Expand Up @@ -324,10 +325,18 @@ def read_arrow(

if isinstance(regions, str):
regions = [regions]
elif isinstance(regions, np.ndarray):
if regions.ndim != 1:
raise Exception(
f'"regions" parameter of type {type(regions)} must be 1-dimensional'
)
regions = regions.tolist()
if isinstance(regions, list):
regions = map(str, self._prepare_regions(regions))
else:
elif regions is None:
regions = ""
else:
raise Exception(f'"regions" parameter cannot have type: {type(regions)}')

if isinstance(samples, str):
samples = [samples]
Expand Down Expand Up @@ -526,7 +535,7 @@ def read(
self,
attrs: List[str] = DEFAULT_ATTRS,
samples: (str, List[str]) = None,
regions: (str, List[str]) = None,
regions: (str, List[str], np.ndarray) = None,
samples_file: str = None,
bed_file: str = None,
skip_check_samples: bool = False,
Expand Down Expand Up @@ -571,10 +580,19 @@ def read(

if isinstance(regions, str):
regions = [regions]
elif isinstance(regions, np.ndarray):
if regions.ndim != 1:
raise Exception(
f'"regions" parameter of type {type(regions)} must be 1-dimensional'
)
regions = regions.tolist()
if isinstance(regions, list):
regions = map(str, self._prepare_regions(regions))
else:
elif regions is None:
regions = ""
else:
raise Exception(f'"regions" parameter cannot have type: {type(regions)}')

if isinstance(samples, str):
samples = [samples]

Expand All @@ -596,7 +614,7 @@ def read(
def export(
self,
samples: (str, List[str]) = None,
regions: (str, List[str]) = None,
regions: (str, List[str], np.ndarray) = None,
samples_file: str = None,
bed_file: str = None,
skip_check_samples: bool = False,
Expand Down Expand Up @@ -639,10 +657,19 @@ def export(

if isinstance(regions, str):
regions = [regions]
elif isinstance(regions, np.ndarray):
if regions.ndim != 1:
raise Exception(
f'"regions" parameter of type {type(regions)} must be 1-dimensional'
)
regions = regions.tolist()
if isinstance(regions, list):
regions = map(str, self._prepare_regions(regions))
else:
elif regions is None:
regions = ""
else:
raise Exception(f'"regions" parameter cannot have type: {type(regions)}')

if isinstance(samples, str):
samples = [samples]

Expand Down Expand Up @@ -671,7 +698,7 @@ def read_iter(
self,
attrs: List[str] = DEFAULT_ATTRS,
samples: (str, List[str]) = None,
regions: (str, List[str]) = None,
regions: (str, List[str], np.ndarray) = None,
samples_file: str = None,
bed_file: str = None,
):
Expand All @@ -696,10 +723,19 @@ def read_iter(

if isinstance(regions, str):
regions = [regions]
elif isinstance(regions, np.ndarray):
if regions.ndim != 1:
raise Exception(
f'"regions" parameter of type {type(regions)} must be 1-dimensional'
)
regions = regions.tolist()
if isinstance(regions, list):
regions = map(str, self._prepare_regions(regions))
else:
elif regions is None:
regions = ""
else:
raise Exception(f'"regions" parameter cannot have type: {type(regions)}')

if isinstance(samples, str):
samples = [samples]

Expand Down
99 changes: 76 additions & 23 deletions apis/python/tests/test_tiledbvcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,27 @@ def test_retrieve_samples(test_ds):
assert test_ds.samples() == ["HG00280", "HG01762"]


def test_read_unsupported_regions_type(test_ds):
    """Verify that an unsupported ``regions`` type raises an exception.

    ``read``, ``read_arrow`` and ``read_iter`` must reject a ``regions``
    value that is neither str, list, 1-D numpy.ndarray, nor None, and must
    reject an ndarray that is not 1-dimensional.
    """
    # Local import: only needed to regex-escape the expected messages below.
    import re

    unsupported_region = 3.14
    # pytest.raises(match=...) applies the string as a regex via re.search,
    # so the expected text must be escaped — the interpolated type repr
    # (e.g. "<class 'numpy.ndarray'>") contains the '.' metacharacter.
    unsupported_type_error = re.escape(
        f'"regions" parameter cannot have type: {type(unsupported_region)}'
    )
    wrong_dimension_region = np.array([["1:12700-13400"], ["1:12700-13400"]])
    ndarray_wrong_dimension_error = re.escape(
        f'"regions" parameter of type {type(wrong_dimension_region)} must be 1-dimensional'
    )
    with pytest.raises(Exception, match=unsupported_type_error):
        test_ds.read(regions=unsupported_region)
    with pytest.raises(Exception, match=ndarray_wrong_dimension_error):
        test_ds.read(regions=wrong_dimension_region)
    with pytest.raises(Exception, match=unsupported_type_error):
        test_ds.read_arrow(regions=unsupported_region)
    with pytest.raises(Exception, match=ndarray_wrong_dimension_error):
        test_ds.read_arrow(regions=wrong_dimension_region)
    # NOTE(review): read_iter appears to be a generator, so the validation
    # error presumably surfaces on iteration rather than on the call itself
    # — hence the consuming loop inside the raises context. TODO confirm.
    with pytest.raises(Exception, match=unsupported_type_error):
        for variant in test_ds.read_iter(regions=unsupported_region):
            print(variant)
    with pytest.raises(Exception, match=ndarray_wrong_dimension_error):
        for variant in test_ds.read_iter(regions=wrong_dimension_region):
            print(variant)


def test_read_attrs(test_ds_attrs):
attrs = ["sample_name"]
df = test_ds_attrs.read(attrs=attrs)
Expand Down Expand Up @@ -233,6 +254,40 @@ def test_basic_reads(test_ds):
_check_dfs(
expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"])
)
df = test_ds.read_arrow(
attrs=["sample_name", "pos_start", "pos_end"], regions=["1:12700-13400"]
).to_pandas()
_check_dfs(
expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"])
)

# Regions as string
df = test_ds.read(
attrs=["sample_name", "pos_start", "pos_end"], regions="1:12700-13400"
)
_check_dfs(
expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"])
)
df = test_ds.read_arrow(
attrs=["sample_name", "pos_start", "pos_end"], regions="1:12700-13400"
).to_pandas()
_check_dfs(
expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"])
)

# Regions as numpy.ndarray
df = test_ds.read(
attrs=["sample_name", "pos_start", "pos_end"], regions=np.array(["1:12700-13400"])
)
_check_dfs(
expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"])
)
df = test_ds.read_arrow(
attrs=["sample_name", "pos_start", "pos_end"], regions=np.array(["1:12700-13400"])
).to_pandas()
_check_dfs(
expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"])
)

# Region and sample intersection
df = test_ds.read(
Expand Down Expand Up @@ -382,41 +437,39 @@ def test_incomplete_read_generator():
uri = os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples")
cfg = tiledbvcf.ReadConfig(memory_budget_mb=0)
test_ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg)

dfs = []
for df in test_ds.read_iter(attrs=["pos_end"], regions=["1:12700-13400"]):
dfs.append(df)
overall_df = pd.concat(dfs, ignore_index=True)

assert len(overall_df) == 6
_check_dfs(
pd.DataFrame.from_dict(
expected_df = pd.DataFrame.from_dict(
{
"pos_end": np.array(
[12771, 12771, 13374, 13389, 13395, 13413], dtype=np.int32
)
}
),
overall_df,
)
)

# NOTE: Running multiple test shows that the iterator can be reused

# Regions as string
dfs = []
for df in test_ds.read_iter(attrs=["pos_end"], regions="1:12700-13400"):
dfs.append(df)
overall_df = pd.concat(dfs, ignore_index=True)
assert len(overall_df) == 6
_check_dfs(expected_df, overall_df)

# Test that the iterator can be used again
# Regions as list
dfs = []
for df in test_ds.read_iter(attrs=["pos_end"], regions=["1:12700-13400"]):
dfs.append(df)
overall_df = pd.concat(dfs, ignore_index=True)
assert len(overall_df) == 6
_check_dfs(expected_df, overall_df)

# Regions as numpy.ndarray
dfs = []
for df in test_ds.read_iter(attrs=["pos_end"], regions=np.array(["1:12700-13400"])):
dfs.append(df)
overall_df = pd.concat(dfs, ignore_index=True)
assert len(overall_df) == 6
_check_dfs(
pd.DataFrame.from_dict(
{
"pos_end": np.array(
[12771, 12771, 13374, 13389, 13395, 13413], dtype=np.int32
)
}
),
overall_df,
)
_check_dfs(expected_df, overall_df)


def test_read_filters(test_ds):
Expand Down
Loading