Skip to content

Commit d7da3d9

Browse files
committed
add vlen-bytes
1 parent 3991406 commit d7da3d9

4 files changed

Lines changed: 29 additions & 23 deletions

File tree

src/zarr/core/dtype/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Structured,
1313
)
1414
from zarr.core.dtype.npy.time import DateTime64, TimeDelta64
15+
from zarr.core.dtype.npy.vlen_bytes import VariableLengthBytes
1516

1617
if TYPE_CHECKING:
1718
from zarr.core.common import ZarrFormat
@@ -88,6 +89,7 @@
8889
| FixedLengthBytes
8990
| Structured
9091
| TimeDType
92+
| VariableLengthBytes
9193
)
9294
# mypy has trouble inferring the type of variablelengthstring dtype, because its class definition
9395
# depends on the installed numpy version. That's why the type: ignore statement is needed here.
@@ -100,6 +102,7 @@
100102
FixedLengthBytes,
101103
Structured,
102104
*TIME_DTYPE,
105+
VariableLengthBytes,
103106
)
104107

105108
# This type models inputs that can be coerced to a ZDType

src/zarr/core/dtype/npy/vlen_bytes.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1+
import base64
12
from dataclasses import dataclass
23
from typing import ClassVar, Literal, Self, TypeGuard, overload
34

45
import numpy as np
56

67
from zarr.core.common import JSON, ZarrFormat
7-
from zarr.core.dtype.common import HasObjectCodec
8-
from zarr.core.dtype.wrapper import TBaseDType, ZDType
8+
from zarr.core.dtype.common import HasObjectCodec, v3_unstable_dtype_warning
9+
from zarr.core.dtype.npy.common import check_json_str
10+
from zarr.core.dtype.wrapper import DTypeJSON_V2, DTypeJSON_V3, TBaseDType, ZDType
911

1012

1113
@dataclass(frozen=True, kw_only=True)
12-
class VariableLengthString(ZDType[np.dtypes.ObjectDType, str], HasObjectCodec): # type: ignore[no-redef]
14+
class VariableLengthBytes(ZDType[np.dtypes.ObjectDType, bytes], HasObjectCodec):
1315
dtype_cls = np.dtypes.ObjectDType
1416
_zarr_v3_name: ClassVar[Literal["variable_length_bytes"]] = "variable_length_bytes"
1517
object_codec_id = "vlen-bytes"
@@ -39,12 +41,13 @@ def check_json_v3(cls, data: JSON) -> TypeGuard[Literal["variable_length_utf8"]]
3941
def to_json(self, zarr_format: Literal[2]) -> Literal["|O"]: ...
4042

4143
@overload
42-
def to_json(self, zarr_format: Literal[3]) -> Literal["variable_length_utf8"]: ...
44+
def to_json(self, zarr_format: Literal[3]) -> Literal["variable_length_bytes"]: ...
4345

44-
def to_json(self, zarr_format: ZarrFormat) -> Literal["|O", "variable_length_utf8"]:
46+
def to_json(self, zarr_format: ZarrFormat) -> Literal["|O", "variable_length_bytes"]:
4547
if zarr_format == 2:
4648
return "|O"
4749
elif zarr_format == 3:
50+
v3_unstable_dtype_warning(self)
4851
return self._zarr_v3_name
4952
raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover
5053

@@ -54,22 +57,19 @@ def _from_json_unchecked(
5457
) -> Self:
5558
return cls()
5659

57-
def default_scalar(self) -> str:
58-
return ""
60+
def default_scalar(self) -> bytes:
61+
return b""
5962

6063
def to_json_scalar(self, data: object, *, zarr_format: ZarrFormat) -> str:
61-
return data # type: ignore[return-value]
64+
return base64.standard_b64encode(data).decode("ascii") # type: ignore[arg-type]
6265

63-
def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> str:
64-
"""
65-
Strings pass through
66-
"""
67-
if not check_json_str(data):
68-
raise TypeError(f"Invalid type: {data}. Expected a string.")
69-
return data
66+
def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> bytes:
67+
if check_json_str(data):
68+
return base64.standard_b64decode(data.encode("ascii"))
69+
raise TypeError(f"Invalid type: {data}. Expected a string.") # pragma: no cover
7070

7171
def check_scalar(self, data: object) -> bool:
72-
return isinstance(data, str)
72+
return isinstance(data, bytes | str)
7373

74-
def _cast_scalar_unchecked(self, data: object) -> str:
75-
return str(data)
74+
def _cast_scalar_unchecked(self, data: object) -> bytes:
75+
return bytes(data) # type: ignore[no-any-return, call-overload]

src/zarr/core/dtype/wrapper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545

4646
# This the upper bound for the scalar types we support. It's numpy scalars + str,
4747
# because the new variable-length string dtype in numpy does not have a corresponding scalar type
48-
TBaseScalar = np.generic | str
48+
TBaseScalar = np.generic | str | bytes
4949
# This is the bound for the dtypes that we support. If we support non-numpy dtypes,
5050
# then this bound will need to be widened.
5151
TBaseDType = np.dtype[np.generic]
@@ -174,8 +174,8 @@ def cast_scalar(self, data: object) -> TScalar_co:
174174
if self.check_scalar(data):
175175
return self._cast_scalar_unchecked(data)
176176
msg = (
177-
f"The value {data} failed a type check. "
178-
f"It cannot be safely cast to a scalar compatible with {self.dtype_cls}. "
177+
f"The value {data!r} failed a type check. "
178+
f"It cannot be safely cast to a scalar compatible with {self}. "
179179
f"Consult the documentation for {self} to determine the possible values that can "
180180
"be cast to scalars of the wrapped data type."
181181
)

tests/test_regression/test_regression.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from zarr.core.array import Array
1414
from zarr.core.chunk_key_encodings import V2ChunkKeyEncoding
1515
from zarr.core.dtype.npy.string import VariableLengthString
16+
from zarr.core.dtype.npy.vlen_bytes import VariableLengthBytes
1617
from zarr.storage import LocalStore
1718

1819
if TYPE_CHECKING:
@@ -33,7 +34,7 @@ def runner_installed() -> bool:
3334
@dataclass(kw_only=True)
3435
class ArrayParams:
3536
values: np.ndarray[tuple[int], np.dtype[np.generic]]
36-
fill_value: np.generic | str | int
37+
fill_value: np.generic | str | int | bytes
3738
filters: tuple[numcodecs.abc.Codec, ...] = ()
3839
compressor: numcodecs.abc.Codec
3940

@@ -92,8 +93,10 @@ def source_array(tmp_path: Path, request: pytest.FixtureRequest) -> Array:
9293
compressor = array_params.compressor
9394
chunk_key_encoding = V2ChunkKeyEncoding(separator="/")
9495
dtype: ZDTypeLike
95-
if array_params.values.dtype == np.dtype("|O"):
96+
if array_params.values.dtype == np.dtype("|O") and array_params.filters == (VLenUTF8(),):
9697
dtype = VariableLengthString() # type: ignore[assignment]
98+
elif array_params.values.dtype == np.dtype("|O") and array_params.filters == (VLenBytes(),):
99+
dtype = VariableLengthBytes()
97100
else:
98101
dtype = array_params.values.dtype
99102
z = zarr.create_array(

0 commit comments

Comments
 (0)