11# /// script
22# requires-python = ">=3.11"
33# dependencies = [
4- # "zarr @ {root} ",
4+ # "zarr @ git+https://github.com/zarr-developers/zarr-python.git@main ",
55# "ml_dtypes==0.5.1",
66# "pytest==8.4.1"
77# ]
1818import json
1919import sys
2020from pathlib import Path
21- from typing import ClassVar , Literal , Self , TypeGuard
21+ from typing import ClassVar , Literal , Self , TypeGuard , overload
2222
2323import ml_dtypes # necessary to add extra dtypes to NumPy
2424import numpy as np
3434 check_dtype_spec_v2 ,
3535)
3636
37+ # This is the int2 array data type
3738int2_dtype_cls = type (np .dtype ("int2" ))
39+
40+ # This is the int2 scalar type
3841int2_scalar_cls = ml_dtypes .int2
3942
4043
4144class Int2 (ZDType [int2_dtype_cls , int2_scalar_cls ]):
4245 """
43- This class provides a Zarr compatibility layer around the int2 data type and the int2
44- scalar type.
46+ This class provides a Zarr compatibility layer around the int2 data type ( the ``dtype`` of a
47+ NumPy array of type int2) and the int2 scalar type (the ``dtype`` of the scalar value inside an int2 array) .
4548 """
4649
4750 # This field is as the key for the data type in the internal data type registry, and also
@@ -68,72 +71,140 @@ def to_native_dtype(self: Self) -> int2_dtype_cls:
6871
6972 @classmethod
7073 def _check_json_v2 (cls , data : DTypeJSON ) -> TypeGuard [DTypeConfig_V2 [Literal ["|b1" ], None ]]:
71- """Type check for Zarr v2-flavored JSON"""
74+ """
75+ Type check for Zarr v2-flavored JSON.
76+
77+ This will check that the input is a dict like this:
78+ .. code-block:: json
79+
80+ {
81+ "name": "int2",
82+ "object_codec_id": None
83+ }
84+
85+ Note that this representation differs from the ``dtype`` field looks like in zarr v2 metadata.
86+ Specifically, whatever goes into the ``dtype`` field in metadata is assigned to the ``name`` field here.
87+
88+ See the Zarr docs for more information about the JSON encoding for data types.
89+ """
7290 return (
7391 check_dtype_spec_v2 (data ) and data ["name" ] == "int2" and data ["object_codec_id" ] is None
7492 )
7593
7694 @classmethod
7795 def _check_json_v3 (cls , data : DTypeJSON ) -> TypeGuard [Literal ["int2" ]]:
78- """Type check for Zarr v3-flavored JSON"""
96+ """
97+ Type check for Zarr V3-flavored JSON.
98+
99+ Checks that the input is the string "int2".
100+ """
79101 return data == cls ._zarr_v3_name
80102
81103 @classmethod
82104 def _from_json_v2 (cls , data : DTypeJSON ) -> Self :
83105 """
84- Create an instance of this ZDType from zarr v3 -flavored JSON.
106+ Create an instance of this ZDType from Zarr V3 -flavored JSON.
85107 """
86108 if cls ._check_json_v2 (data ):
87109 return cls ()
110+ # This first does a type check on the input, and if that passes we create an instance of the ZDType.
88111 msg = f"Invalid JSON representation of { cls .__name__ } . Got { data !r} , expected the string { cls ._zarr_v2_name !r} "
89112 raise DataTypeValidationError (msg )
90113
91114 @classmethod
92115 def _from_json_v3 (cls : type [Self ], data : DTypeJSON ) -> Self :
93116 """
94- Create an instance of this ZDType from zarr v3-flavored JSON.
117+ Create an instance of this ZDType from Zarr V3-flavored JSON.
118+
119+ This first does a type check on the input, and if that passes we create an instance of the ZDType.
95120 """
96121 if cls ._check_json_v3 (data ):
97122 return cls ()
98123 msg = f"Invalid JSON representation of { cls .__name__ } . Got { data !r} , expected the string { cls ._zarr_v3_name !r} "
99124 raise DataTypeValidationError (msg )
100125
126+ @overload # type: ignore[override]
127+ def to_json (self , zarr_format : Literal [2 ]) -> DTypeConfig_V2 [Literal ["int2" ], None ]: ...
128+
129+ @overload
130+ def to_json (self , zarr_format : Literal [3 ]) -> Literal ["int2" ]: ...
131+
101132 def to_json (
102133 self , zarr_format : ZarrFormat
103134 ) -> DTypeConfig_V2 [Literal ["int2" ], None ] | Literal ["int2" ]:
104- """Serialize this ZDType to v2- or v3-flavored JSON"""
135+ """
136+ Serialize this ZDType to v2- or v3-flavored JSON
137+
138+ If the zarr_format is 2, then return a dict like this:
139+ .. code-block:: json
140+
141+ {
142+ "name": "int2",
143+ "object_codec_id": None
144+ }
145+
146+ If the zarr_format is 3, then return the string "int2"
147+
148+ """
105149 if zarr_format == 2 :
106150 return {"name" : "int2" , "object_codec_id" : None }
107151 elif zarr_format == 3 :
108152 return self ._zarr_v3_name
109153 raise ValueError (f"zarr_format must be 2 or 3, got { zarr_format } " ) # pragma: no cover
110154
111- def _check_scalar (self , data : object ) -> TypeGuard [int ]:
112- """Check if a python object is a valid scalar"""
155+ def _check_scalar (self , data : object ) -> TypeGuard [int | ml_dtypes .int2 ]:
156+ """
157+ Check if a python object is a valid int2-compatible scalar
158+
159+ The strictness of this type check is an implementation degree of freedom.
160+ You could be strict here, and only accept int2 values, or be open and accept any integer
161+ or any object and rely on exceptions from the int2 constructor that will be called in
162+ cast_scalar.
163+ """
113164 return isinstance (data , (int , int2_scalar_cls ))
114165
115166 def cast_scalar (self , data : object ) -> ml_dtypes .int2 :
116167 """
117- Attempt to cast a python object to an int2. Might fail pending a type check.
168+ Attempt to cast a python object to an int2.
169+
170+ We first perform a type check to ensure that the input type is appropriate, and if that
171+ passes we call the int2 scalar constructor.
118172 """
119173 if self ._check_scalar (data ):
120174 return ml_dtypes .int2 (data )
121175 msg = f"Cannot convert object with type { type (data )} to a 2-bit integer."
122176 raise TypeError (msg )
123177
124178 def default_scalar (self ) -> ml_dtypes .int2 :
125- """Get the default scalar value"""
179+ """
180+ Get the default scalar value. This will be used when automatically selecting a fill value.
181+ """
126182 return ml_dtypes .int2 (0 )
127183
128184 def to_json_scalar (self , data : object , * , zarr_format : ZarrFormat ) -> int :
129- """Convert a python object to a scalar."""
130- return int (data )
185+ """
186+ Convert a python object to a JSON representation of an int2 scalar.
187+ This is necessary for taking user input for the ``fill_value`` attribute in array metadata.
188+
189+ In this implementation, we optimistically convert the input to an int,
190+ and then check that it lies in the acceptable range for this data type.
191+ """
192+ # We could add a type check here, but we don't need to for this example
193+ val : int = int (data ) # type: ignore[call-overload]
194+ if val not in (- 2 , - 1 , 0 , 1 ):
195+ raise ValueError ("Invalid value. Expected -2, -1, 0, or 1." )
196+ return val
131197
132198 def from_json_scalar (self , data : JSON , * , zarr_format : ZarrFormat ) -> ml_dtypes .int2 :
133199 """
134- Read a JSON-serializable value as a scalar. The base definition of this method
135- requires that it take a zarr_format parameter, because some data types serialize scalars
136- differently in zarr v2 and v3
200+ Read a JSON-serializable value as an int2 scalar.
201+
202+ We first perform a type check to ensure that the JSON value is well-formed, then call the
203+ int2 scalar constructor.
204+
205+ The base definition of this method requires that it take a zarr_format parameter because
206+ other data types serialize scalars differently in zarr v2 and v3, but we don't use this here.
207+
137208 """
138209 if self ._check_scalar (data ):
139210 return ml_dtypes .int2 (data )
@@ -167,4 +238,8 @@ def test_custom_dtype(tmp_path: Path, zarr_format: Literal[2, 3]) -> None:
167238
168239
169240if __name__ == "__main__" :
241+ # Run the example with printed output, and a dummy pytest configuration file specified.
242+ # Without the dummy configuration file, at test time pytest will attempt to use the
243+ # configuration file in the project root, which will error because Zarr is using some
244+ # plugins that are not installed in this example.
170245 sys .exit (pytest .main (["-s" , __file__ , f"-c { __file__ } " ]))
0 commit comments