Skip to content

Commit 4b50f88

Browse files
tomwhitejeromekelleher
authored andcommitted
Allow string dtype to be specified using StringDType() in create_empty_group_array
1 parent 4ef9b4e commit 4b50f88

2 files changed

Lines changed: 15 additions & 5 deletions

File tree

bio2zarr/zarr_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
import tqdm
99
import zarr
10+
from numpy.dtypes import StringDType
1011
from zarr.codecs.blosc import BloscCodec, BloscShuffle
1112

1213
logger = logging.getLogger(__name__)
@@ -169,7 +170,7 @@ def create_empty_group_array(
169170
codecs = []
170171
if filters is not None:
171172
codecs = [numcodecs.get_codec(f) for f in filters]
172-
if dtype == STRING_DTYPE_NAME:
173+
if dtype == STRING_DTYPE_NAME or dtype == StringDType():
173174
codecs.append(numcodecs.VLenUTF8())
174175
if len(codecs) == 0:
175176
codecs = None

tests/test_zarr_utils.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy.testing as nt
77
import pytest
88
import zarr
9+
from numpy.dtypes import StringDType
910
from zarr.codecs.blosc import BloscCodec, BloscShuffle
1011

1112
from bio2zarr import zarr_utils
@@ -442,12 +443,15 @@ def test_user_filters_non_string_v2(self, tmp_path, monkeypatch):
442443
assert len(a.metadata.filters) == 1
443444
assert isinstance(a.metadata.filters[0], numcodecs.Delta)
444445

445-
def test_string_dtype_no_user_filters(self, group, zarr_format):
446+
@pytest.mark.parametrize(
447+
"string_dtype", [zarr_utils.STRING_DTYPE_NAME, StringDType()]
448+
)
449+
def test_string_dtype_no_user_filters(self, group, zarr_format, string_dtype):
446450
a = zarr_utils.create_empty_group_array(
447451
group,
448452
"x",
449453
shape=(4,),
450-
dtype=zarr_utils.STRING_DTYPE_NAME,
454+
dtype=string_dtype,
451455
chunks=(2,),
452456
compressor=zarr_utils.DEFAULT_COMPRESSOR_CONFIG,
453457
)
@@ -457,7 +461,12 @@ def test_string_dtype_no_user_filters(self, group, zarr_format):
457461
a[:] = ["a", "b", "c", "d"]
458462
nt.assert_array_equal(a[:], ["a", "b", "c", "d"])
459463

460-
def test_string_dtype_with_user_filters_v2(self, tmp_path, monkeypatch):
464+
@pytest.mark.parametrize(
465+
"string_dtype", [zarr_utils.STRING_DTYPE_NAME, StringDType()]
466+
)
467+
def test_string_dtype_with_user_filters_v2(
468+
self, tmp_path, monkeypatch, string_dtype
469+
):
461470
monkeypatch.setattr(zarr_utils, "ZARR_FORMAT", 2)
462471
root = zarr.open_group(tmp_path / "s", mode="w", zarr_format=2)
463472
# Place a VLenUTF8 in the user filter list; the helper should
@@ -468,7 +477,7 @@ def test_string_dtype_with_user_filters_v2(self, tmp_path, monkeypatch):
468477
root,
469478
"x",
470479
shape=(4,),
471-
dtype=zarr_utils.STRING_DTYPE_NAME,
480+
dtype=string_dtype,
472481
chunks=(2,),
473482
compressor=zarr_utils.DEFAULT_COMPRESSOR_CONFIG,
474483
filters=[{"id": "vlen-utf8"}],

0 commit comments

Comments
 (0)