Skip to content

Commit 15ebfa6

Browse files
committed
make examples stand-alone and testable via script dependency modification at test time
1 parent 893540f commit 15ebfa6

3 files changed

Lines changed: 170 additions & 18 deletions

File tree

examples/custom_dtype.py

Lines changed: 93 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# /// script
22
# requires-python = ">=3.11"
33
# dependencies = [
4-
# "zarr @ {root}",
4+
# "zarr @ git+https://github.com/zarr-developers/zarr-python.git@main",
55
# "ml_dtypes==0.5.1",
66
# "pytest==8.4.1"
77
# ]
@@ -18,7 +18,7 @@
1818
import json
1919
import sys
2020
from pathlib import Path
21-
from typing import ClassVar, Literal, Self, TypeGuard
21+
from typing import ClassVar, Literal, Self, TypeGuard, overload
2222

2323
import ml_dtypes # necessary to add extra dtypes to NumPy
2424
import numpy as np
@@ -34,14 +34,17 @@
3434
check_dtype_spec_v2,
3535
)
3636

37+
# This is the int2 array data type
3738
int2_dtype_cls = type(np.dtype("int2"))
39+
40+
# This is the int2 scalar type
3841
int2_scalar_cls = ml_dtypes.int2
3942

4043

4144
class Int2(ZDType[int2_dtype_cls, int2_scalar_cls]):
4245
"""
43-
This class provides a Zarr compatibility layer around the int2 data type and the int2
44-
scalar type.
46+
This class provides a Zarr compatibility layer around the int2 data type (the ``dtype`` of a
47+
NumPy array of type int2) and the int2 scalar type (the ``dtype`` of the scalar value inside an int2 array).
4548
"""
4649

4750
# This field is as the key for the data type in the internal data type registry, and also
@@ -68,72 +71,140 @@ def to_native_dtype(self: Self) -> int2_dtype_cls:
6871

6972
@classmethod
7073
def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[DTypeConfig_V2[Literal["|b1"], None]]:
71-
"""Type check for Zarr v2-flavored JSON"""
74+
"""
75+
Type check for Zarr v2-flavored JSON.
76+
77+
This will check that the input is a dict like this:
78+
.. code-block:: json
79+
80+
{
81+
"name": "int2",
82+
"object_codec_id": None
83+
}
84+
85+
Note that this representation differs from the ``dtype`` field looks like in zarr v2 metadata.
86+
Specifically, whatever goes into the ``dtype`` field in metadata is assigned to the ``name`` field here.
87+
88+
See the Zarr docs for more information about the JSON encoding for data types.
89+
"""
7290
return (
7391
check_dtype_spec_v2(data) and data["name"] == "int2" and data["object_codec_id"] is None
7492
)
7593

7694
@classmethod
7795
def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[Literal["int2"]]:
78-
"""Type check for Zarr v3-flavored JSON"""
96+
"""
97+
Type check for Zarr V3-flavored JSON.
98+
99+
Checks that the input is the string "int2".
100+
"""
79101
return data == cls._zarr_v3_name
80102

81103
@classmethod
82104
def _from_json_v2(cls, data: DTypeJSON) -> Self:
83105
"""
84-
Create an instance of this ZDType from zarr v3-flavored JSON.
106+
Create an instance of this ZDType from Zarr V3-flavored JSON.
85107
"""
86108
if cls._check_json_v2(data):
87109
return cls()
110+
# This first does a type check on the input, and if that passes we create an instance of the ZDType.
88111
msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v2_name!r}"
89112
raise DataTypeValidationError(msg)
90113

91114
@classmethod
92115
def _from_json_v3(cls: type[Self], data: DTypeJSON) -> Self:
93116
"""
94-
Create an instance of this ZDType from zarr v3-flavored JSON.
117+
Create an instance of this ZDType from Zarr V3-flavored JSON.
118+
119+
This first does a type check on the input, and if that passes we create an instance of the ZDType.
95120
"""
96121
if cls._check_json_v3(data):
97122
return cls()
98123
msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v3_name!r}"
99124
raise DataTypeValidationError(msg)
100125

126+
@overload # type: ignore[override]
127+
def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[Literal["int2"], None]: ...
128+
129+
@overload
130+
def to_json(self, zarr_format: Literal[3]) -> Literal["int2"]: ...
131+
101132
def to_json(
102133
self, zarr_format: ZarrFormat
103134
) -> DTypeConfig_V2[Literal["int2"], None] | Literal["int2"]:
104-
"""Serialize this ZDType to v2- or v3-flavored JSON"""
135+
"""
136+
Serialize this ZDType to v2- or v3-flavored JSON
137+
138+
If the zarr_format is 2, then return a dict like this:
139+
.. code-block:: json
140+
141+
{
142+
"name": "int2",
143+
"object_codec_id": None
144+
}
145+
146+
If the zarr_format is 3, then return the string "int2"
147+
148+
"""
105149
if zarr_format == 2:
106150
return {"name": "int2", "object_codec_id": None}
107151
elif zarr_format == 3:
108152
return self._zarr_v3_name
109153
raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover
110154

111-
def _check_scalar(self, data: object) -> TypeGuard[int]:
112-
"""Check if a python object is a valid scalar"""
155+
def _check_scalar(self, data: object) -> TypeGuard[int | ml_dtypes.int2]:
156+
"""
157+
Check if a python object is a valid int2-compatible scalar
158+
159+
The strictness of this type check is an implementation degree of freedom.
160+
You could be strict here, and only accept int2 values, or be open and accept any integer
161+
or any object and rely on exceptions from the int2 constructor that will be called in
162+
cast_scalar.
163+
"""
113164
return isinstance(data, (int, int2_scalar_cls))
114165

115166
def cast_scalar(self, data: object) -> ml_dtypes.int2:
116167
"""
117-
Attempt to cast a python object to an int2. Might fail pending a type check.
168+
Attempt to cast a python object to an int2.
169+
170+
We first perform a type check to ensure that the input type is appropriate, and if that
171+
passes we call the int2 scalar constructor.
118172
"""
119173
if self._check_scalar(data):
120174
return ml_dtypes.int2(data)
121175
msg = f"Cannot convert object with type {type(data)} to a 2-bit integer."
122176
raise TypeError(msg)
123177

124178
def default_scalar(self) -> ml_dtypes.int2:
125-
"""Get the default scalar value"""
179+
"""
180+
Get the default scalar value. This will be used when automatically selecting a fill value.
181+
"""
126182
return ml_dtypes.int2(0)
127183

128184
def to_json_scalar(self, data: object, *, zarr_format: ZarrFormat) -> int:
129-
"""Convert a python object to a scalar."""
130-
return int(data)
185+
"""
186+
Convert a python object to a JSON representation of an int2 scalar.
187+
This is necessary for taking user input for the ``fill_value`` attribute in array metadata.
188+
189+
In this implementation, we optimistically convert the input to an int,
190+
and then check that it lies in the acceptable range for this data type.
191+
"""
192+
# We could add a type check here, but we don't need to for this example
193+
val: int = int(data) # type: ignore[call-overload]
194+
if val not in (-2, -1, 0, 1):
195+
raise ValueError("Invalid value. Expected -2, -1, 0, or 1.")
196+
return val
131197

132198
def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> ml_dtypes.int2:
133199
"""
134-
Read a JSON-serializable value as a scalar. The base definition of this method
135-
requires that it take a zarr_format parameter, because some data types serialize scalars
136-
differently in zarr v2 and v3
200+
Read a JSON-serializable value as an int2 scalar.
201+
202+
We first perform a type check to ensure that the JSON value is well-formed, then call the
203+
int2 scalar constructor.
204+
205+
The base definition of this method requires that it take a zarr_format parameter because
206+
other data types serialize scalars differently in zarr v2 and v3, but we don't use this here.
207+
137208
"""
138209
if self._check_scalar(data):
139210
return ml_dtypes.int2(data)
@@ -167,4 +238,8 @@ def test_custom_dtype(tmp_path: Path, zarr_format: Literal[2, 3]) -> None:
167238

168239

169240
if __name__ == "__main__":
241+
# Run the example with printed output, and a dummy pytest configuration file specified.
242+
# Without the dummy configuration file, at test time pytest will attempt to use the
243+
# configuration file in the project root, which will error because Zarr is using some
244+
# plugins that are not installed in this example.
170245
sys.exit(pytest.main(["-s", __file__, f"-c {__file__}"]))

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ test = [
8080
"mypy",
8181
"hypothesis",
8282
"pytest-xdist",
83+
"packaging",
84+
"tomlkit",
85+
"uv"
8386
]
8487
remote_tests = [
8588
'zarr[remote]',

tests/test_examples.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from __future__ import annotations
2+
3+
import re
4+
import subprocess
5+
from pathlib import Path
6+
from typing import Final
7+
8+
import pytest
9+
import tomlkit
10+
from packaging.requirements import Requirement
11+
12+
examples_dir = "examples"
13+
script_paths = Path(examples_dir).glob("*.py")
14+
15+
PEP_723_REGEX: Final = r"(?m)^# /// (?P<type>[a-zA-Z0-9-]+)$\s(?P<content>(^#(| .*)$\s)+)^# ///$"
16+
17+
# This is the absolute path to the local Zarr installation. Moving this test to a different directory will break it.
18+
ZARR_PROJECT_PATH = Path(".").absolute()
19+
20+
21+
def set_dep(script: str, dependency: str) -> str:
22+
"""
23+
Set a dependency in a PEP-723 script header.
24+
If the package is already in the list, it will be replaced.
25+
If the package is not already in the list, it will be added.
26+
27+
Source code modified from
28+
https://packaging.python.org/en/latest/specifications/inline-script-metadata/#reference-implementation
29+
"""
30+
match = re.search(PEP_723_REGEX, script)
31+
32+
if match is None:
33+
raise ValueError(f"PEP-723 header not found in {script}")
34+
35+
content = "".join(
36+
line[2:] if line.startswith("# ") else line[1:]
37+
for line in match.group("content").splitlines(keepends=True)
38+
)
39+
40+
config = tomlkit.parse(content)
41+
for idx, dep in enumerate(tuple(config["dependencies"])):
42+
if Requirement(dep).name == Requirement(dependency).name:
43+
config["dependencies"][idx] = dependency
44+
45+
new_content = "".join(
46+
f"# {line}" if line.strip() else f"#{line}"
47+
for line in tomlkit.dumps(config).splitlines(keepends=True)
48+
)
49+
50+
start, end = match.span("content")
51+
return script[:start] + new_content + script[end:]
52+
53+
54+
def resave_script(source_path: Path, dest_path: Path) -> None:
55+
"""
56+
Read a script from source_path and save it to dest_path after inserting the absolute path to the
57+
local Zarr project directory in the PEP-723 header.
58+
"""
59+
source_text = source_path.read_text()
60+
dest_text = set_dep(source_text, f"zarr @ file:///{ZARR_PROJECT_PATH}")
61+
dest_path.write_text(dest_text)
62+
63+
64+
@pytest.mark.parametrize("script_path", script_paths)
65+
def test_scripts_can_run(script_path: Path, tmp_path: Path) -> None:
66+
dest_path = tmp_path / script_path.name
67+
# We resave the script after inserting the absolute path to the local Zarr project directory,
68+
# and then test its behavior.
69+
# This allows the example to be useful to users who don't have Zarr installed, but also testable.
70+
resave_script(script_path, dest_path)
71+
result = subprocess.run(["uv", "run", str(dest_path)], capture_output=True, text=True)
72+
assert result.returncode == 0, (
73+
f"Script at {script_path} failed to run. Output: {result.stdout} Error: {result.stderr}"
74+
)

0 commit comments

Comments
 (0)