Skip to content

Commit db4e19c

Browse files
authored
Update regions parameter to support type numpy.ndarray (#873)
* Added support to the Python API for regions of type numpy.ndarray. Also, regions of type None are now handled explicitly, with all other types raising an exception.
* Updated Python tests to include the new numpy.ndarray regions type. Previously untested regions types were added as well.
1 parent 5f5d692 commit db4e19c

2 files changed

Lines changed: 120 additions & 31 deletions

File tree

apis/python/src/tiledbvcf/dataset.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections import namedtuple
55
from typing import Generator, List
66

7+
import numpy as np
78
import pandas as pd
89
import pyarrow as pa
910
import pyarrow.compute as pc
@@ -279,7 +280,7 @@ def read_arrow(
279280
self,
280281
attrs: List[str] = DEFAULT_ATTRS,
281282
samples: (str, List[str]) = None,
282-
regions: (str, List[str]) = None,
283+
regions: (str, List[str], np.ndarray) = None,
283284
samples_file: str = None,
284285
bed_file: str = None,
285286
skip_check_samples: bool = False,
@@ -324,10 +325,18 @@ def read_arrow(
324325

325326
if isinstance(regions, str):
326327
regions = [regions]
328+
elif isinstance(regions, np.ndarray):
329+
if regions.ndim != 1:
330+
raise Exception(
331+
f'"regions" parameter of type {type(regions)} must be 1-dimensional'
332+
)
333+
regions = regions.tolist()
327334
if isinstance(regions, list):
328335
regions = map(str, self._prepare_regions(regions))
329-
else:
336+
elif regions is None:
330337
regions = ""
338+
else:
339+
raise Exception(f'"regions" parameter cannot have type: {type(regions)}')
331340

332341
if isinstance(samples, str):
333342
samples = [samples]
@@ -526,7 +535,7 @@ def read(
526535
self,
527536
attrs: List[str] = DEFAULT_ATTRS,
528537
samples: (str, List[str]) = None,
529-
regions: (str, List[str]) = None,
538+
regions: (str, List[str], np.ndarray) = None,
530539
samples_file: str = None,
531540
bed_file: str = None,
532541
skip_check_samples: bool = False,
@@ -571,10 +580,19 @@ def read(
571580

572581
if isinstance(regions, str):
573582
regions = [regions]
583+
elif isinstance(regions, np.ndarray):
584+
if regions.ndim != 1:
585+
raise Exception(
586+
f'"regions" parameter of type {type(regions)} must be 1-dimensional'
587+
)
588+
regions = regions.tolist()
574589
if isinstance(regions, list):
575590
regions = map(str, self._prepare_regions(regions))
576-
else:
591+
elif regions is None:
577592
regions = ""
593+
else:
594+
raise Exception(f'"regions" parameter cannot have type: {type(regions)}')
595+
578596
if isinstance(samples, str):
579597
samples = [samples]
580598

@@ -596,7 +614,7 @@ def read(
596614
def export(
597615
self,
598616
samples: (str, List[str]) = None,
599-
regions: (str, List[str]) = None,
617+
regions: (str, List[str], np.ndarray) = None,
600618
samples_file: str = None,
601619
bed_file: str = None,
602620
skip_check_samples: bool = False,
@@ -639,10 +657,19 @@ def export(
639657

640658
if isinstance(regions, str):
641659
regions = [regions]
660+
elif isinstance(regions, np.ndarray):
661+
if regions.ndim != 1:
662+
raise Exception(
663+
f'"regions" parameter of type {type(regions)} must be 1-dimensional'
664+
)
665+
regions = regions.tolist()
642666
if isinstance(regions, list):
643667
regions = map(str, self._prepare_regions(regions))
644-
else:
668+
elif regions is None:
645669
regions = ""
670+
else:
671+
raise Exception(f'"regions" parameter cannot have type: {type(regions)}')
672+
646673
if isinstance(samples, str):
647674
samples = [samples]
648675

@@ -671,7 +698,7 @@ def read_iter(
671698
self,
672699
attrs: List[str] = DEFAULT_ATTRS,
673700
samples: (str, List[str]) = None,
674-
regions: (str, List[str]) = None,
701+
regions: (str, List[str], np.ndarray) = None,
675702
samples_file: str = None,
676703
bed_file: str = None,
677704
):
@@ -696,10 +723,19 @@ def read_iter(
696723

697724
if isinstance(regions, str):
698725
regions = [regions]
726+
elif isinstance(regions, np.ndarray):
727+
if regions.ndim != 1:
728+
raise Exception(
729+
f'"regions" parameter of type {type(regions)} must be 1-dimensional'
730+
)
731+
regions = regions.tolist()
699732
if isinstance(regions, list):
700733
regions = map(str, self._prepare_regions(regions))
701-
else:
734+
elif regions is None:
702735
regions = ""
736+
else:
737+
raise Exception(f'"regions" parameter cannot have type: {type(regions)}')
738+
703739
if isinstance(samples, str):
704740
samples = [samples]
705741

apis/python/tests/test_tiledbvcf.py

Lines changed: 76 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,27 @@ def test_retrieve_samples(test_ds):
125125
assert test_ds.samples() == ["HG00280", "HG01762"]
126126

127127

128+
def test_read_unsupported_regions_type(test_ds):
129+
unsupported_region = 3.14
130+
unsupported_type_error = f'"regions" parameter cannot have type: {type(unsupported_region)}'
131+
wrong_dimension_region = np.array([["1:12700-13400"], ["1:12700-13400"]])
132+
ndarray_wrong_dimension_error = f'"regions" parameter of type {type(wrong_dimension_region)} must be 1-dimensional'
133+
with pytest.raises(Exception, match=unsupported_type_error):
134+
test_ds.read(regions=unsupported_region)
135+
with pytest.raises(Exception, match=ndarray_wrong_dimension_error):
136+
test_ds.read(regions=wrong_dimension_region)
137+
with pytest.raises(Exception, match=unsupported_type_error):
138+
test_ds.read_arrow(regions=unsupported_region)
139+
with pytest.raises(Exception, match=ndarray_wrong_dimension_error):
140+
test_ds.read_arrow(regions=wrong_dimension_region)
141+
with pytest.raises(Exception, match=unsupported_type_error):
142+
for variant in test_ds.read_iter(regions=unsupported_region):
143+
print(variant)
144+
with pytest.raises(Exception, match=ndarray_wrong_dimension_error):
145+
for variant in test_ds.read_iter(regions=wrong_dimension_region):
146+
print(variant)
147+
148+
128149
def test_read_attrs(test_ds_attrs):
129150
attrs = ["sample_name"]
130151
df = test_ds_attrs.read(attrs=attrs)
@@ -233,6 +254,40 @@ def test_basic_reads(test_ds):
233254
_check_dfs(
234255
expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"])
235256
)
257+
df = test_ds.read_arrow(
258+
attrs=["sample_name", "pos_start", "pos_end"], regions=["1:12700-13400"]
259+
).to_pandas()
260+
_check_dfs(
261+
expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"])
262+
)
263+
264+
# Regions as string
265+
df = test_ds.read(
266+
attrs=["sample_name", "pos_start", "pos_end"], regions="1:12700-13400"
267+
)
268+
_check_dfs(
269+
expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"])
270+
)
271+
df = test_ds.read_arrow(
272+
attrs=["sample_name", "pos_start", "pos_end"], regions="1:12700-13400"
273+
).to_pandas()
274+
_check_dfs(
275+
expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"])
276+
)
277+
278+
# Regions as numpy.ndarray
279+
df = test_ds.read(
280+
attrs=["sample_name", "pos_start", "pos_end"], regions=np.array(["1:12700-13400"])
281+
)
282+
_check_dfs(
283+
expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"])
284+
)
285+
df = test_ds.read_arrow(
286+
attrs=["sample_name", "pos_start", "pos_end"], regions=np.array(["1:12700-13400"])
287+
).to_pandas()
288+
_check_dfs(
289+
expected_df, df.sort_values(ignore_index=True, by=["sample_name", "pos_start"])
290+
)
236291

237292
# Region and sample intersection
238293
df = test_ds.read(
@@ -382,41 +437,39 @@ def test_incomplete_read_generator():
382437
uri = os.path.join(TESTS_INPUT_DIR, "arrays/v3/ingested_2samples")
383438
cfg = tiledbvcf.ReadConfig(memory_budget_mb=0)
384439
test_ds = tiledbvcf.Dataset(uri, mode="r", cfg=cfg)
385-
386-
dfs = []
387-
for df in test_ds.read_iter(attrs=["pos_end"], regions=["1:12700-13400"]):
388-
dfs.append(df)
389-
overall_df = pd.concat(dfs, ignore_index=True)
390-
391-
assert len(overall_df) == 6
392-
_check_dfs(
393-
pd.DataFrame.from_dict(
440+
expected_df = pd.DataFrame.from_dict(
394441
{
395442
"pos_end": np.array(
396443
[12771, 12771, 13374, 13389, 13395, 13413], dtype=np.int32
397444
)
398445
}
399-
),
400-
overall_df,
401-
)
446+
)
447+
448+
# NOTE: Running multiple tests shows that the iterator can be reused
449+
450+
# Regions as string
451+
dfs = []
452+
for df in test_ds.read_iter(attrs=["pos_end"], regions="1:12700-13400"):
453+
dfs.append(df)
454+
overall_df = pd.concat(dfs, ignore_index=True)
455+
assert len(overall_df) == 6
456+
_check_dfs(expected_df, overall_df)
402457

403-
# Test that the iterator can be used again
458+
# Regions as list
404459
dfs = []
405460
for df in test_ds.read_iter(attrs=["pos_end"], regions=["1:12700-13400"]):
406461
dfs.append(df)
407462
overall_df = pd.concat(dfs, ignore_index=True)
463+
assert len(overall_df) == 6
464+
_check_dfs(expected_df, overall_df)
408465

466+
# Regions as numpy.ndarray
467+
dfs = []
468+
for df in test_ds.read_iter(attrs=["pos_end"], regions=np.array(["1:12700-13400"])):
469+
dfs.append(df)
470+
overall_df = pd.concat(dfs, ignore_index=True)
409471
assert len(overall_df) == 6
410-
_check_dfs(
411-
pd.DataFrame.from_dict(
412-
{
413-
"pos_end": np.array(
414-
[12771, 12771, 13374, 13389, 13395, 13413], dtype=np.int32
415-
)
416-
}
417-
),
418-
overall_df,
419-
)
472+
_check_dfs(expected_df, overall_df)
420473

421474

422475
def test_read_filters(test_ds):

0 commit comments

Comments
 (0)