11"""Tests type definitions."""
22
3+ from typing import Optional
4+
35import numpy as np
4- import pydantic .v1 as pydantic
56import pytest
7+ from pydantic import ValidationError
8+ from pydantic .fields import FieldInfo
69from tidy3d .components .base import Tidy3dBaseModel
7- from tidy3d .components .types import ArrayLike , Complex , Tuple , constrained_array
8-
9-
10- def _test_validate_array_like ():
11- class S (Tidy3dBaseModel ):
12- f : ArrayLike [float , 2 ]
13-
14- _ = S (f = np .array ([[1.0 , 2.0 , 3.0 ], [4.0 , 5.0 , 6.0 ]]))
15- with pytest .raises (pydantic .ValidationError ):
16- _ = S (f = np .array ([1.0 , 2.0 , 3.0 ]))
17-
18- class MyClass (Tidy3dBaseModel ):
19- f : constrained_array (ndim = 3 , shape = (1 , 2 , 3 ))
20-
21- with pytest .raises (pydantic .ValidationError ):
22- _ = MyClass (f = np .ones ((2 , 2 , 3 )))
23-
24- with pytest .raises (pydantic .ValidationError ):
25- _ = MyClass (f = np .ones ((1 , 2 , 3 , 4 )))
10+ from tidy3d .components .types import ArrayConstraints , ArrayLike , Complex , array_alias
2611
2712
2813def test_schemas ():
2914 class S (Tidy3dBaseModel ):
3015 f : ArrayLike
31- ca : constrained_array (ndim = 1 , dtype = complex )
16+ ca : array_alias (ndim = 1 , dtype = complex )
3217 c : Complex
3318
34- # TODO: unexpected behavior, if list with more than one element, it fails.
35- _ = S (f = [13 ], c = 1 + 1j , ca = 1 + 1j )
19+ _ = S (f = [13 ], c = 1 + 1j , ca = [1 + 1j ])
3620 S .schema ()
3721
3822
3923def test_array_like ():
4024 class MyClass (Tidy3dBaseModel ):
4125 a : ArrayLike = None # can be any array-like thing
42- b : constrained_array (ndim = 2 ) = None # must be 2D
43- c : constrained_array (dtype = float ) = None # must be float-like
44- d : constrained_array (ndim = 1 , dtype = complex ) = None # 1D complex
26+ b : array_alias (ndim = 2 ) = None # must be 2D
27+ c : array_alias (dtype = float ) = None # must be float-like
28+ d : array_alias (ndim = 1 , dtype = complex ) = None # 1D complex
4529 e : ArrayLike
46- f : constrained_array (ndim = 3 , shape = (1 , 2 , 3 )) = None # must have certain shape
30+ f : array_alias (ndim = 3 , shape = (1 , 2 , 3 )) = None # must have certain shape
4731
4832 my_obj = MyClass (
4933 a = 1.0 + 2j ,
@@ -62,34 +46,311 @@ class MyClass(Tidy3dBaseModel):
6246 my_obj .json ()
6347
6448
65- def test_array_like_field_name ():
49+ def test_array_like_field_metadata ():
6650 class MyClass (Tidy3dBaseModel ):
67- a : ArrayLike # can be any array-like thing
68- b : constrained_array (ndim = 2 ) # must be 2D
69- c : constrained_array (dtype = float ) # must be float-like
70- d : constrained_array (ndim = 1 , dtype = complex ) # 1D complex
71- e : constrained_array (ndim = 3 , shape = (1 , 2 , 3 )) # must have certain shape
72- f : ArrayLike = None
51+ a : ArrayLike
52+ b : array_alias (ndim = 2 )
53+ c : array_alias (dtype = float )
54+ d : array_alias (ndim = 1 , dtype = complex )
55+ e : array_alias (ndim = 3 , shape = (1 , 2 , 3 ))
56+ f : array_alias (strict = True , forbid_nan = False , scalar_to_1d = True )
57+ g : array_alias (dtype = np .int32 )
7358
7459 fields = MyClass .model_fields
7560
76- def correct_field_display (field_name , display_name ):
77- """Make sure the field has the expected name."""
78- assert fields [field_name ]._type_display () == display_name
61+ def check_field_constraints (
62+ field_name : str ,
63+ expected_dtype_name : Optional [str ],
64+ expected_ndim : Optional [int ],
65+ expected_shape : Optional [tuple [int , ...]],
66+ expected_forbid_nan : bool ,
67+ expected_scalar_to_1d : bool ,
68+ expected_strict : bool ,
69+ ):
70+ field_info : FieldInfo = fields [field_name ]
7971
80- correct_field_display ("a" , "ArrayLike" )
81- correct_field_display ("b" , "ArrayLike[ndim=2]" )
82- correct_field_display ("c" , "ArrayLike[dtype=float]" )
83- correct_field_display ("d" , "ArrayLike[dtype=complex, ndim=1]" )
84- correct_field_display ("e" , "ArrayLike[ndim=3, shape=(1, 2, 3)]" )
85- correct_field_display ("f" , "Optional[ArrayLike]" )
72+ constraints = None
73+ for meta_item in field_info .metadata :
74+ if isinstance (meta_item , ArrayConstraints ):
75+ constraints = meta_item
76+ break
77+
78+ assert (
79+ constraints is not None
80+ ), f"No ArrayConstraints found in metadata for field { field_name } . Metadata: { field_info .metadata } "
81+
82+ assert (
83+ constraints .get ("dtype" ) == expected_dtype_name
84+ ), f"Field { field_name } : dtype mismatch. Expected { expected_dtype_name !r} , got { constraints .get ('dtype' )!r} "
85+ assert (
86+ constraints .get ("ndim" ) == expected_ndim
87+ ), f"Field { field_name } : ndim mismatch. Expected { expected_ndim } , got { constraints .get ('ndim' )} "
88+ assert (
89+ constraints .get ("shape" ) == expected_shape
90+ ), f"Field { field_name } : shape mismatch. Expected { expected_shape } , got { constraints .get ('shape' )} "
91+ assert (
92+ constraints .get ("forbid_nan" ) == expected_forbid_nan
93+ ), f"Field { field_name } : forbid_nan mismatch. Expected { expected_forbid_nan } , got { constraints .get ('forbid_nan' )} "
94+ assert (
95+ constraints .get ("scalar_to_1d" ) == expected_scalar_to_1d
96+ ), f"Field { field_name } : scalar_to_1d mismatch. Expected { expected_scalar_to_1d } , got { constraints .get ('scalar_to_1d' )} "
97+ assert (
98+ constraints .get ("strict" ) == expected_strict
99+ ), f"Field { field_name } : strict mismatch. Expected { expected_strict } , got { constraints .get ('strict' )} "
100+
101+ # Default values from array_alias: forbid_nan=True, scalar_to_1d=False, strict=False
102+ check_field_constraints (
103+ "a" ,
104+ expected_dtype_name = None ,
105+ expected_ndim = None ,
106+ expected_shape = None ,
107+ expected_forbid_nan = True ,
108+ expected_scalar_to_1d = False ,
109+ expected_strict = False ,
110+ )
111+
112+ check_field_constraints (
113+ "b" ,
114+ expected_dtype_name = None ,
115+ expected_ndim = 2 ,
116+ expected_shape = None ,
117+ expected_forbid_nan = True ,
118+ expected_scalar_to_1d = False ,
119+ expected_strict = False ,
120+ )
121+
122+ # dtype=float -> float.__name__ is "float"
123+ check_field_constraints (
124+ "c" ,
125+ expected_dtype_name = "float" ,
126+ expected_ndim = None ,
127+ expected_shape = None ,
128+ expected_forbid_nan = True ,
129+ expected_scalar_to_1d = False ,
130+ expected_strict = False ,
131+ )
132+
133+ # dtype=complex -> complex.__name__ is "complex"
134+ check_field_constraints (
135+ "d" ,
136+ expected_dtype_name = "complex" ,
137+ expected_ndim = 1 ,
138+ expected_shape = None ,
139+ expected_forbid_nan = True ,
140+ expected_scalar_to_1d = False ,
141+ expected_strict = False ,
142+ )
143+
144+ check_field_constraints (
145+ "e" ,
146+ expected_dtype_name = None ,
147+ expected_ndim = 3 ,
148+ expected_shape = (1 , 2 , 3 ),
149+ expected_forbid_nan = True ,
150+ expected_scalar_to_1d = False ,
151+ expected_strict = False ,
152+ )
153+
154+ # Test case for field 'f' with non-default values for flags
155+ check_field_constraints (
156+ "f" ,
157+ expected_dtype_name = None ,
158+ expected_ndim = None ,
159+ expected_shape = None ,
160+ expected_forbid_nan = False ,
161+ expected_scalar_to_1d = True ,
162+ expected_strict = True ,
163+ )
164+
165+ # Test case for field 'g' with NumPy specific dtype
166+ # dtype=np.int32 -> np.int32.__name__ is "int32"
167+ check_field_constraints (
168+ "g" ,
169+ expected_dtype_name = "int32" ,
170+ expected_ndim = None ,
171+ expected_shape = None ,
172+ expected_forbid_nan = True ,
173+ expected_scalar_to_1d = False ,
174+ expected_strict = False ,
175+ )
86176
87177
88178def test_hash ():
89179 class MyClass (Tidy3dBaseModel ):
90180 a : ArrayLike
91- b : constrained_array (ndim = 1 )
92- c : Tuple [ArrayLike , ...]
181+ b : array_alias (ndim = 1 )
182+ c : tuple [ArrayLike , ...]
93183
94184 c = MyClass (a = [1.0 ], b = [2.0 , 1.0 ], c = ([2.0 , 1.0 ]))
95185 hash (c .json ())
186+
187+
188+ def test_array_like_validation_errors ():
189+ """Tests that appropriate ValidationErrors are raised for array constraints."""
190+
191+ # input that cannot be converted to a NumPy array at all (with specific dtype)
192+ class ModelDtypeConversionFail (Tidy3dBaseModel ):
193+ a : array_alias (dtype = int )
194+
195+ with pytest .raises (ValidationError , match = "cannot convert" ):
196+ ModelDtypeConversionFail (a = "not an int" )
197+
198+ # ndim mismatch
199+ class ModelNdimMismatch (Tidy3dBaseModel ):
200+ a : array_alias (ndim = 1 )
201+
202+ with pytest .raises (ValidationError , match = "expected" ):
203+ ModelNdimMismatch (a = [[1 , 2 ], [3 , 4 ]])
204+
205+ # ndim mismatch (scalar for ndim=1, scalar_to_1d=False by default)
206+ class ModelNdimScalarDefault (Tidy3dBaseModel ):
207+ a : array_alias (ndim = 1 )
208+
209+ with pytest .raises (ValidationError , match = "expected" ):
210+ ModelNdimScalarDefault (a = 5 )
211+
212+ # shape mismatch
213+ class ModelShapeMismatch (Tidy3dBaseModel ):
214+ a : array_alias (shape = (2 , 2 ))
215+
216+ with pytest .raises (ValidationError , match = r"expected shape" ):
217+ ModelShapeMismatch (a = [[1 , 2 , 3 ], [4 , 5 , 6 ]])
218+
219+ # forbid_nan=True (default) and array contains NaN
220+ class ModelForbidNan (Tidy3dBaseModel ):
221+ a : array_alias (dtype = float )
222+
223+ with pytest .raises (ValidationError , match = "array contains NaN" ):
224+ ModelForbidNan (a = [1.0 , np .nan , 3.0 ])
225+
226+ # strict=True and a scalar is provided
227+ class ModelStrictScalar (Tidy3dBaseModel ):
228+ a : array_alias (strict = True )
229+
230+ with pytest .raises (ValidationError , match = "strict mode" ):
231+ ModelStrictScalar (a = 10 )
232+
233+ # input results in an array with dtype=object
234+ class ModelObjectDtype (Tidy3dBaseModel ):
235+ a : ArrayLike
236+
237+ with pytest .raises (ValidationError , match = r"unsupported element type" ):
238+ ModelObjectDtype (a = [1 , "string" , object ()])
239+
240+ # general conversion failure for an unhandled type
241+ class ModelGeneralConversionFail (Tidy3dBaseModel ):
242+ a : ArrayLike
243+
244+ class UnconvertibleObject :
245+ pass
246+
247+ with pytest .raises (ValidationError , match = "unsupported element type" ):
248+ ModelGeneralConversionFail (a = UnconvertibleObject ())
249+
250+ # _from_complex_dict receives a dict it doesn't understand, passes it to _coerce,
251+ # which then fails because dict becomes an object array or direct conversion fails
252+ class ModelComplexInvalidDict (Tidy3dBaseModel ):
253+ a : array_alias (dtype = complex )
254+
255+ with pytest .raises (ValidationError , match = r"cannot convert" ):
256+ ModelComplexInvalidDict (a = {"real_part" : 1 , "imag_part" : 2 })
257+
258+ # scalar_to_1d=True with ndim=1 successfully converts scalar
259+ class ModelScalarTo1DSuccess (Tidy3dBaseModel ):
260+ a : array_alias (ndim = 1 , scalar_to_1d = True )
261+
262+ obj_s21d = ModelScalarTo1DSuccess (a = 5.0 )
263+ assert np .array_equal (obj_s21d .a , np .array ([5.0 ]))
264+ assert obj_s21d .a .ndim == 1
265+
266+ # scalar_to_1d=True but ndim is incompatible with 1D array (e.g. ndim=2)
267+ class ModelScalarTo1DWrongNdim (Tidy3dBaseModel ):
268+ a : array_alias (ndim = 2 , scalar_to_1d = True , dtype = float )
269+
270+ with pytest .raises (ValidationError , match = "expected" ):
271+ ModelScalarTo1DWrongNdim (a = 5.0 )
272+
273+ # strict=True takes precedence over scalar_to_1d=True if input is scalar
274+ class ModelStrictAndScalarTo1D (Tidy3dBaseModel ):
275+ a : array_alias (strict = True , scalar_to_1d = True , dtype = float )
276+
277+ with pytest .raises (ValidationError , match = "strict mode" ):
278+ ModelStrictAndScalarTo1D (a = 5.0 )
279+
280+ # allow NaN when forbid_nan=False
281+ class ModelAllowNan (Tidy3dBaseModel ):
282+ a : array_alias (dtype = float , forbid_nan = False )
283+
284+ obj_allow_nan = ModelAllowNan (a = [1.0 , np .nan ])
285+ assert np .array_equal (obj_allow_nan .a , np .array ([1.0 , np .nan ]), equal_nan = True )
286+
287+ # strict=False (default) allows non-array if it can be coerced
288+ class ModelStrictFalseCoercion (Tidy3dBaseModel ):
289+ a : array_alias (dtype = int , ndim = 1 )
290+
291+ # should pass because [1.0, 2.0] can be coerced to np.array([1,2]) of dtype int, ndim 1
292+ obj_sf_coerce = ModelStrictFalseCoercion (a = [1.0 , 2.0 ])
293+ assert np .array_equal (obj_sf_coerce .a , np .array ([1 , 2 ]))
294+ assert obj_sf_coerce .a .dtype == np .dtype (int )
295+ assert obj_sf_coerce .a .ndim == 1
296+
297+ # scalar_to_1d=False (default), ndim=None, scalar input -> 0D array
298+ class ModelScalarTo0D (Tidy3dBaseModel ):
299+ a : array_alias (scalar_to_1d = False )
300+
301+ obj_s0d = ModelScalarTo0D (a = 10 )
302+ assert np .array_equal (obj_s0d .a , np .array (10 ))
303+ assert obj_s0d .a .ndim == 0
304+
305+ # scalar_to_1d=True, ndim=None, scalar input -> 1D array
306+ class ModelScalarTo1DNoNdim (Tidy3dBaseModel ):
307+ a : array_alias (scalar_to_1d = True )
308+
309+ obj_s1d_no_ndim = ModelScalarTo1DNoNdim (a = 10 )
310+ assert np .array_equal (obj_s1d_no_ndim .a , np .array ([10 ])), obj_s1d_no_ndim .a
311+ assert obj_s1d_no_ndim .a .ndim == 1
312+
313+
314+ def test_complex_type ():
315+ """Tests the Complex type for parsing and serialization."""
316+
317+ class ComplexModel (Tidy3dBaseModel ):
318+ val : Complex
319+
320+ inputs = [
321+ (1 + 2j , 1 + 2j ),
322+ ({"real" : 3 , "imag" : - 4 }, 3 - 4j ),
323+ ({"real" : 3.5 , "imag" : 0 }, 3.5 + 0j ),
324+ (5 , 5 + 0j ), # int
325+ (6.7 , 6.7 + 0j ), # float
326+ (True , 1 + 0j ), # bool (subclass of int, numbers.Number)
327+ (np .float32 (2.5 ), 2.5 + 0j ), # numpy float
328+ (np .int64 (- 3 ), - 3 + 0j ), # numpy int
329+ ([10 , - 2 ], 10 - 2j ), # list of two numbers
330+ ((0.5 , 1.5 ), 0.5 + 1.5j ), # tuple of two numbers
331+ ]
332+
333+ class ObjWithComplexMethod :
334+ def __complex__ (self ):
335+ return - 1 - 1j
336+
337+ class ObjWithComplexMethodNumeric :
338+ def __init__ (self , val ):
339+ self ._val = val
340+
341+ def __complex__ (self ):
342+ return self ._val
343+
344+ inputs .append ((ObjWithComplexMethod (), - 1 - 1j ))
345+ inputs .append ((ObjWithComplexMethodNumeric (3 + 7j ), 3 + 7j ))
346+ inputs .append ((ObjWithComplexMethodNumeric (5 ), 5 + 0j ))
347+
348+ for input_val , expected_complex in inputs :
349+ model = ComplexModel (val = input_val )
350+ assert model .val == expected_complex , f"Input: { input_val } "
351+ assert isinstance (model .val , complex ), f"Input: { input_val } "
352+
353+ expected_json_val = {"real" : expected_complex .real , "imag" : expected_complex .imag }
354+ assert (
355+ model .model_dump (mode = "json" )["val" ] == expected_json_val
356+ ), f"Input for serialization: { input_val } "
0 commit comments