Skip to content

Commit c0951c8

Browse files
committed
passing: types
1 parent 74c326a commit c0951c8

2 files changed

Lines changed: 340 additions & 51 deletions

File tree

Lines changed: 306 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,33 @@
11
"""Tests type definitions."""
22

3+
from typing import Optional
4+
35
import numpy as np
4-
import pydantic.v1 as pydantic
56
import pytest
7+
from pydantic import ValidationError
8+
from pydantic.fields import FieldInfo
69
from tidy3d.components.base import Tidy3dBaseModel
7-
from tidy3d.components.types import ArrayLike, Complex, Tuple, constrained_array
8-
9-
10-
def _test_validate_array_like():
11-
class S(Tidy3dBaseModel):
12-
f: ArrayLike[float, 2]
13-
14-
_ = S(f=np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]))
15-
with pytest.raises(pydantic.ValidationError):
16-
_ = S(f=np.array([1.0, 2.0, 3.0]))
17-
18-
class MyClass(Tidy3dBaseModel):
19-
f: constrained_array(ndim=3, shape=(1, 2, 3))
20-
21-
with pytest.raises(pydantic.ValidationError):
22-
_ = MyClass(f=np.ones((2, 2, 3)))
23-
24-
with pytest.raises(pydantic.ValidationError):
25-
_ = MyClass(f=np.ones((1, 2, 3, 4)))
10+
from tidy3d.components.types import ArrayConstraints, ArrayLike, Complex, array_alias
2611

2712

2813
def test_schemas():
2914
class S(Tidy3dBaseModel):
3015
f: ArrayLike
31-
ca: constrained_array(ndim=1, dtype=complex)
16+
ca: array_alias(ndim=1, dtype=complex)
3217
c: Complex
3318

34-
# TODO: unexpected behavior, if list with more than one element, it fails.
35-
_ = S(f=[13], c=1 + 1j, ca=1 + 1j)
19+
_ = S(f=[13], c=1 + 1j, ca=[1 + 1j])
3620
S.schema()
3721

3822

3923
def test_array_like():
4024
class MyClass(Tidy3dBaseModel):
4125
a: ArrayLike = None # can be any array-like thing
42-
b: constrained_array(ndim=2) = None # must be 2D
43-
c: constrained_array(dtype=float) = None # must be float-like
44-
d: constrained_array(ndim=1, dtype=complex) = None # 1D complex
26+
b: array_alias(ndim=2) = None # must be 2D
27+
c: array_alias(dtype=float) = None # must be float-like
28+
d: array_alias(ndim=1, dtype=complex) = None # 1D complex
4529
e: ArrayLike
46-
f: constrained_array(ndim=3, shape=(1, 2, 3)) = None # must have certain shape
30+
f: array_alias(ndim=3, shape=(1, 2, 3)) = None # must have certain shape
4731

4832
my_obj = MyClass(
4933
a=1.0 + 2j,
@@ -62,34 +46,311 @@ class MyClass(Tidy3dBaseModel):
6246
my_obj.json()
6347

6448

65-
def test_array_like_field_name():
49+
def test_array_like_field_metadata():
6650
class MyClass(Tidy3dBaseModel):
67-
a: ArrayLike # can be any array-like thing
68-
b: constrained_array(ndim=2) # must be 2D
69-
c: constrained_array(dtype=float) # must be float-like
70-
d: constrained_array(ndim=1, dtype=complex) # 1D complex
71-
e: constrained_array(ndim=3, shape=(1, 2, 3)) # must have certain shape
72-
f: ArrayLike = None
51+
a: ArrayLike
52+
b: array_alias(ndim=2)
53+
c: array_alias(dtype=float)
54+
d: array_alias(ndim=1, dtype=complex)
55+
e: array_alias(ndim=3, shape=(1, 2, 3))
56+
f: array_alias(strict=True, forbid_nan=False, scalar_to_1d=True)
57+
g: array_alias(dtype=np.int32)
7358

7459
fields = MyClass.model_fields
7560

76-
def correct_field_display(field_name, display_name):
77-
"""Make sure the field has the expected name."""
78-
assert fields[field_name]._type_display() == display_name
61+
def check_field_constraints(
62+
field_name: str,
63+
expected_dtype_name: Optional[str],
64+
expected_ndim: Optional[int],
65+
expected_shape: Optional[tuple[int, ...]],
66+
expected_forbid_nan: bool,
67+
expected_scalar_to_1d: bool,
68+
expected_strict: bool,
69+
):
70+
field_info: FieldInfo = fields[field_name]
7971

80-
correct_field_display("a", "ArrayLike")
81-
correct_field_display("b", "ArrayLike[ndim=2]")
82-
correct_field_display("c", "ArrayLike[dtype=float]")
83-
correct_field_display("d", "ArrayLike[dtype=complex, ndim=1]")
84-
correct_field_display("e", "ArrayLike[ndim=3, shape=(1, 2, 3)]")
85-
correct_field_display("f", "Optional[ArrayLike]")
72+
constraints = None
73+
for meta_item in field_info.metadata:
74+
if isinstance(meta_item, ArrayConstraints):
75+
constraints = meta_item
76+
break
77+
78+
assert (
79+
constraints is not None
80+
), f"No ArrayConstraints found in metadata for field {field_name}. Metadata: {field_info.metadata}"
81+
82+
assert (
83+
constraints.get("dtype") == expected_dtype_name
84+
), f"Field {field_name}: dtype mismatch. Expected {expected_dtype_name!r}, got {constraints.get('dtype')!r}"
85+
assert (
86+
constraints.get("ndim") == expected_ndim
87+
), f"Field {field_name}: ndim mismatch. Expected {expected_ndim}, got {constraints.get('ndim')}"
88+
assert (
89+
constraints.get("shape") == expected_shape
90+
), f"Field {field_name}: shape mismatch. Expected {expected_shape}, got {constraints.get('shape')}"
91+
assert (
92+
constraints.get("forbid_nan") == expected_forbid_nan
93+
), f"Field {field_name}: forbid_nan mismatch. Expected {expected_forbid_nan}, got {constraints.get('forbid_nan')}"
94+
assert (
95+
constraints.get("scalar_to_1d") == expected_scalar_to_1d
96+
), f"Field {field_name}: scalar_to_1d mismatch. Expected {expected_scalar_to_1d}, got {constraints.get('scalar_to_1d')}"
97+
assert (
98+
constraints.get("strict") == expected_strict
99+
), f"Field {field_name}: strict mismatch. Expected {expected_strict}, got {constraints.get('strict')}"
100+
101+
# Default values from array_alias: forbid_nan=True, scalar_to_1d=False, strict=False
102+
check_field_constraints(
103+
"a",
104+
expected_dtype_name=None,
105+
expected_ndim=None,
106+
expected_shape=None,
107+
expected_forbid_nan=True,
108+
expected_scalar_to_1d=False,
109+
expected_strict=False,
110+
)
111+
112+
check_field_constraints(
113+
"b",
114+
expected_dtype_name=None,
115+
expected_ndim=2,
116+
expected_shape=None,
117+
expected_forbid_nan=True,
118+
expected_scalar_to_1d=False,
119+
expected_strict=False,
120+
)
121+
122+
# dtype=float -> float.__name__ is "float"
123+
check_field_constraints(
124+
"c",
125+
expected_dtype_name="float",
126+
expected_ndim=None,
127+
expected_shape=None,
128+
expected_forbid_nan=True,
129+
expected_scalar_to_1d=False,
130+
expected_strict=False,
131+
)
132+
133+
# dtype=complex -> complex.__name__ is "complex"
134+
check_field_constraints(
135+
"d",
136+
expected_dtype_name="complex",
137+
expected_ndim=1,
138+
expected_shape=None,
139+
expected_forbid_nan=True,
140+
expected_scalar_to_1d=False,
141+
expected_strict=False,
142+
)
143+
144+
check_field_constraints(
145+
"e",
146+
expected_dtype_name=None,
147+
expected_ndim=3,
148+
expected_shape=(1, 2, 3),
149+
expected_forbid_nan=True,
150+
expected_scalar_to_1d=False,
151+
expected_strict=False,
152+
)
153+
154+
# Test case for field 'f' with non-default values for flags
155+
check_field_constraints(
156+
"f",
157+
expected_dtype_name=None,
158+
expected_ndim=None,
159+
expected_shape=None,
160+
expected_forbid_nan=False,
161+
expected_scalar_to_1d=True,
162+
expected_strict=True,
163+
)
164+
165+
# Test case for field 'g' with NumPy specific dtype
166+
# dtype=np.int32 -> np.int32.__name__ is "int32"
167+
check_field_constraints(
168+
"g",
169+
expected_dtype_name="int32",
170+
expected_ndim=None,
171+
expected_shape=None,
172+
expected_forbid_nan=True,
173+
expected_scalar_to_1d=False,
174+
expected_strict=False,
175+
)
86176

87177

88178
def test_hash():
89179
class MyClass(Tidy3dBaseModel):
90180
a: ArrayLike
91-
b: constrained_array(ndim=1)
92-
c: Tuple[ArrayLike, ...]
181+
b: array_alias(ndim=1)
182+
c: tuple[ArrayLike, ...]
93183

94184
c = MyClass(a=[1.0], b=[2.0, 1.0], c=([2.0, 1.0]))
95185
hash(c.json())
186+
187+
188+
def test_array_like_validation_errors():
189+
"""Tests that appropriate ValidationErrors are raised for array constraints."""
190+
191+
# input that cannot be converted to a NumPy array at all (with specific dtype)
192+
class ModelDtypeConversionFail(Tidy3dBaseModel):
193+
a: array_alias(dtype=int)
194+
195+
with pytest.raises(ValidationError, match="cannot convert"):
196+
ModelDtypeConversionFail(a="not an int")
197+
198+
# ndim mismatch
199+
class ModelNdimMismatch(Tidy3dBaseModel):
200+
a: array_alias(ndim=1)
201+
202+
with pytest.raises(ValidationError, match="expected"):
203+
ModelNdimMismatch(a=[[1, 2], [3, 4]])
204+
205+
# ndim mismatch (scalar for ndim=1, scalar_to_1d=False by default)
206+
class ModelNdimScalarDefault(Tidy3dBaseModel):
207+
a: array_alias(ndim=1)
208+
209+
with pytest.raises(ValidationError, match="expected"):
210+
ModelNdimScalarDefault(a=5)
211+
212+
# shape mismatch
213+
class ModelShapeMismatch(Tidy3dBaseModel):
214+
a: array_alias(shape=(2, 2))
215+
216+
with pytest.raises(ValidationError, match=r"expected shape"):
217+
ModelShapeMismatch(a=[[1, 2, 3], [4, 5, 6]])
218+
219+
# forbid_nan=True (default) and array contains NaN
220+
class ModelForbidNan(Tidy3dBaseModel):
221+
a: array_alias(dtype=float)
222+
223+
with pytest.raises(ValidationError, match="array contains NaN"):
224+
ModelForbidNan(a=[1.0, np.nan, 3.0])
225+
226+
# strict=True and a scalar is provided
227+
class ModelStrictScalar(Tidy3dBaseModel):
228+
a: array_alias(strict=True)
229+
230+
with pytest.raises(ValidationError, match="strict mode"):
231+
ModelStrictScalar(a=10)
232+
233+
# input results in an array with dtype=object
234+
class ModelObjectDtype(Tidy3dBaseModel):
235+
a: ArrayLike
236+
237+
with pytest.raises(ValidationError, match=r"unsupported element type"):
238+
ModelObjectDtype(a=[1, "string", object()])
239+
240+
# general conversion failure for an unhandled type
241+
class ModelGeneralConversionFail(Tidy3dBaseModel):
242+
a: ArrayLike
243+
244+
class UnconvertibleObject:
245+
pass
246+
247+
with pytest.raises(ValidationError, match="unsupported element type"):
248+
ModelGeneralConversionFail(a=UnconvertibleObject())
249+
250+
# _from_complex_dict receives a dict it doesn't understand, passes it to _coerce,
251+
# which then fails because dict becomes an object array or direct conversion fails
252+
class ModelComplexInvalidDict(Tidy3dBaseModel):
253+
a: array_alias(dtype=complex)
254+
255+
with pytest.raises(ValidationError, match=r"cannot convert"):
256+
ModelComplexInvalidDict(a={"real_part": 1, "imag_part": 2})
257+
258+
# scalar_to_1d=True with ndim=1 successfully converts scalar
259+
class ModelScalarTo1DSuccess(Tidy3dBaseModel):
260+
a: array_alias(ndim=1, scalar_to_1d=True)
261+
262+
obj_s21d = ModelScalarTo1DSuccess(a=5.0)
263+
assert np.array_equal(obj_s21d.a, np.array([5.0]))
264+
assert obj_s21d.a.ndim == 1
265+
266+
# scalar_to_1d=True but ndim is incompatible with 1D array (e.g. ndim=2)
267+
class ModelScalarTo1DWrongNdim(Tidy3dBaseModel):
268+
a: array_alias(ndim=2, scalar_to_1d=True, dtype=float)
269+
270+
with pytest.raises(ValidationError, match="expected"):
271+
ModelScalarTo1DWrongNdim(a=5.0)
272+
273+
# strict=True takes precedence over scalar_to_1d=True if input is scalar
274+
class ModelStrictAndScalarTo1D(Tidy3dBaseModel):
275+
a: array_alias(strict=True, scalar_to_1d=True, dtype=float)
276+
277+
with pytest.raises(ValidationError, match="strict mode"):
278+
ModelStrictAndScalarTo1D(a=5.0)
279+
280+
# allow NaN when forbid_nan=False
281+
class ModelAllowNan(Tidy3dBaseModel):
282+
a: array_alias(dtype=float, forbid_nan=False)
283+
284+
obj_allow_nan = ModelAllowNan(a=[1.0, np.nan])
285+
assert np.array_equal(obj_allow_nan.a, np.array([1.0, np.nan]), equal_nan=True)
286+
287+
# strict=False (default) allows non-array if it can be coerced
288+
class ModelStrictFalseCoercion(Tidy3dBaseModel):
289+
a: array_alias(dtype=int, ndim=1)
290+
291+
# should pass because [1.0, 2.0] can be coerced to np.array([1,2]) of dtype int, ndim 1
292+
obj_sf_coerce = ModelStrictFalseCoercion(a=[1.0, 2.0])
293+
assert np.array_equal(obj_sf_coerce.a, np.array([1, 2]))
294+
assert obj_sf_coerce.a.dtype == np.dtype(int)
295+
assert obj_sf_coerce.a.ndim == 1
296+
297+
# scalar_to_1d=False (default), ndim=None, scalar input -> 0D array
298+
class ModelScalarTo0D(Tidy3dBaseModel):
299+
a: array_alias(scalar_to_1d=False)
300+
301+
obj_s0d = ModelScalarTo0D(a=10)
302+
assert np.array_equal(obj_s0d.a, np.array(10))
303+
assert obj_s0d.a.ndim == 0
304+
305+
# scalar_to_1d=True, ndim=None, scalar input -> 1D array
306+
class ModelScalarTo1DNoNdim(Tidy3dBaseModel):
307+
a: array_alias(scalar_to_1d=True)
308+
309+
obj_s1d_no_ndim = ModelScalarTo1DNoNdim(a=10)
310+
assert np.array_equal(obj_s1d_no_ndim.a, np.array([10])), obj_s1d_no_ndim.a
311+
assert obj_s1d_no_ndim.a.ndim == 1
312+
313+
314+
def test_complex_type():
315+
"""Tests the Complex type for parsing and serialization."""
316+
317+
class ComplexModel(Tidy3dBaseModel):
318+
val: Complex
319+
320+
inputs = [
321+
(1 + 2j, 1 + 2j),
322+
({"real": 3, "imag": -4}, 3 - 4j),
323+
({"real": 3.5, "imag": 0}, 3.5 + 0j),
324+
(5, 5 + 0j), # int
325+
(6.7, 6.7 + 0j), # float
326+
(True, 1 + 0j), # bool (subclass of int, numbers.Number)
327+
(np.float32(2.5), 2.5 + 0j), # numpy float
328+
(np.int64(-3), -3 + 0j), # numpy int
329+
([10, -2], 10 - 2j), # list of two numbers
330+
((0.5, 1.5), 0.5 + 1.5j), # tuple of two numbers
331+
]
332+
333+
class ObjWithComplexMethod:
334+
def __complex__(self):
335+
return -1 - 1j
336+
337+
class ObjWithComplexMethodNumeric:
338+
def __init__(self, val):
339+
self._val = val
340+
341+
def __complex__(self):
342+
return self._val
343+
344+
inputs.append((ObjWithComplexMethod(), -1 - 1j))
345+
inputs.append((ObjWithComplexMethodNumeric(3 + 7j), 3 + 7j))
346+
inputs.append((ObjWithComplexMethodNumeric(5), 5 + 0j))
347+
348+
for input_val, expected_complex in inputs:
349+
model = ComplexModel(val=input_val)
350+
assert model.val == expected_complex, f"Input: {input_val}"
351+
assert isinstance(model.val, complex), f"Input: {input_val}"
352+
353+
expected_json_val = {"real": expected_complex.real, "imag": expected_complex.imag}
354+
assert (
355+
model.model_dump(mode="json")["val"] == expected_json_val
356+
), f"Input for serialization: {input_val}"

0 commit comments

Comments
 (0)