1616from os import PathLike
1717from pathlib import Path
1818from types import UnionType
19- from typing import TYPE_CHECKING , Any , Literal , Optional , TypeVar , Union , get_args , get_origin , get_args , get_origin
19+ from typing import TYPE_CHECKING , Any , Literal , Optional , TypeVar , Union , get_args , get_origin
2020
2121import h5py
2222import numpy as np
2323import rich
2424import xarray as xr
2525import yaml
26+ from autograd .numpy .numpy_boxes import ArrayBox
2627from autograd .tracer import isbox
27- from pydantic import BaseModel , ConfigDict , Field , PrivateAttr , field_validator
28+ from pydantic import BaseModel , ConfigDict , Field , PrivateAttr , field_validator , model_validator
29+ from pydantic .functional_validators import ModelWrapValidatorHandler
2830
2931from tidy3d .exceptions import FileError
3032from tidy3d .log import log
@@ -155,8 +157,8 @@ def _fmt_ann_literal(ann: Any) -> str:
155157T = TypeVar ("T" , bound = "Tidy3dBaseModel" )
156158
157159
158- def field_allows_scalar (field : ModelField ) -> bool :
159- annotation = field .outer_type_
160+ def field_allows_scalar (field : FieldInfo ) -> bool :
161+ annotation = field .annotation
160162
161163 def allows_scalar (a : Any ) -> bool :
162164 origin = get_origin (a )
@@ -169,6 +171,7 @@ def allows_scalar(a: Any) -> bool:
169171
170172 return allows_scalar (annotation )
171173
174+
172175@total_ordering
173176class Tidy3dBaseModel (BaseModel ):
174177 """Base pydantic model that all Tidy3d components inherit from.
@@ -216,12 +219,6 @@ def _validate_name_no_special_characters(cls: type[T], name: Optional[str]) -> O
216219 )
217220 return name
218221
219- def __init__ (self , ** kwargs : Any ) -> None :
220- """Init method, includes post-init validators."""
221- log .begin_capture ()
222- super ().__init__ (** kwargs )
223- log .end_capture (self )
224-
225222 def __init_subclass__ (cls : type [T ], ** kwargs : Any ) -> None :
226223 """Injects a constant discriminator field before Pydantic builds the model.
227224
@@ -239,6 +236,16 @@ def __init_subclass__(cls: type[T], **kwargs: Any) -> None:
239236 setattr (cls , TYPE_TAG_STR , tag )
240237 TYPE_TO_CLASS_MAP [tag ] = cls
241238
239+ if "__tidy3d_end_capture__" not in cls .__dict__ :
240+
241+ @model_validator (mode = "after" )
242+ def __tidy3d_end_capture__ (self : T ) -> T :
243+ if log ._capture :
244+ log .end_capture (self )
245+ return self
246+
247+ cls .__tidy3d_end_capture__ = __tidy3d_end_capture__
248+
242249 super ().__init_subclass__ (** kwargs )
243250
244251 @classmethod
@@ -248,33 +255,95 @@ def __pydantic_init_subclass__(cls: type[T], **kwargs: Any) -> None:
248255 # add docstring once pydantic is done constructing the class
249256 cls .__doc__ = cls .generate_docstring ()
250257
258+ @model_validator (mode = "wrap" )
259+ @classmethod
260+ def _capture_validation_warnings (
261+ cls : type [T ],
262+ data : Any ,
263+ handler : ModelWrapValidatorHandler [T ],
264+ ) -> T :
265+ if not log ._capture :
266+ return handler (data )
267+
268+ log .begin_capture ()
269+ try :
270+ return handler (data )
271+ except Exception :
272+ log .abort_capture ()
273+ raise
274+
251275 def __hash__ (self ) -> int :
252276 """Hash method."""
253- try :
254- return super ().__hash__ (self )
255- except TypeError :
256- return hash (self .model_dump_json ())
277+ return self ._recursive_hash (self )
278+
279+ @staticmethod
280+ def _recursive_hash (value : Any ) -> int :
281+ # Handle Autograd ArrayBoxes
282+ if isinstance (value , ArrayBox ):
283+ # Unwrap the underlying numpy array and recurse
284+ return Tidy3dBaseModel ._recursive_hash (value ._value )
285+ if isinstance (value , np .ndarray ):
286+ # numpy arrays are not hashable by default, use byte representation
287+ v_hash = hashlib .md5 (value .tobytes ()).hexdigest ()
288+ return hash (v_hash )
289+ if isinstance (value , (xr .DataArray , xr .Dataset )):
290+ # we choose to not hash data arrays as this would require a lot of careful handling of units, metadata.
291+ # technically this is incorrect, but should never lead to bugs in current implementation
292+ return hash (str (value .__class__ .__name__ ))
293+ if isinstance (value , str ):
294+ # this if-case is necessary because length-1 string would lead to infinite recursion in sequence case below
295+ return hash (value )
296+ if isinstance (value , Sequence ):
297+ # this assumes all objects in lists are hashable by default and do not require special handling
298+ v_hash = tuple ([Tidy3dBaseModel ._recursive_hash (vi ) for vi in value ])
299+ return hash (v_hash )
300+ if isinstance (value , dict ):
301+ to_hash_list = []
302+ for k , v in value .items ():
303+ v_hash = Tidy3dBaseModel ._recursive_hash (v )
304+ to_hash_list .append ((k , v_hash ))
305+ return hash (tuple (to_hash_list ))
306+ if isinstance (value , Tidy3dBaseModel ):
307+ # This function needs to take special care because of mutable attributes inside of frozen pydantic models
308+ to_hash_list = []
309+ for k , v in dict (value ).items ():
310+ if k == "attrs" :
311+ continue
312+ v_hash = Tidy3dBaseModel ._recursive_hash (v )
313+ to_hash_list .append ((k , v_hash ))
314+ # attrs is mutable, use serialized output as safe hashing option
315+ if value .attrs :
316+ attrs_str = value ._attrs_digest ()
317+ attrs_hash = hash (attrs_str )
318+ to_hash_list .append (("attrs" , attrs_hash ))
319+ return hash (tuple (to_hash_list ))
320+ return hash (value )
257321
258322 def _hash_self (self ) -> str :
259323 """Hash this component with ``hashlib`` in a way that is the same every session."""
260324 bf = io .BytesIO ()
261325 self .to_hdf5 (bf )
262326 return hashlib .md5 (bf .getvalue ()).hexdigest ()
263327
264-
265- @pydantic . validator ( "*" , pre = True , allow_reuse = True )
266- def coerce_numpy_scalars_for_model (cls , v : Any , field : ModelField ) -> Any :
328+ @ model_validator ( mode = "before" )
329+ @classmethod
330+ def coerce_numpy_scalars_for_model (cls , data : Any ) -> Any :
267331 """
268- Wildcard field validator: coerce numpy scalars / size-1 arrays to native Python
332+ coerce numpy scalars / size-1 arrays to native Python
269333 scalars, but only for fields whose annotations allow scalars.
270334 """
271- if not field_allows_scalar (field ):
272- return v
335+ if not isinstance (data , dict ):
336+ return data
337+
338+ for name , field in cls .model_fields .items ():
339+ if name not in data or not field_allows_scalar (field ):
340+ continue
273341
274- if isinstance (v , np .generic ) or (isinstance (v , np .ndarray ) and v .size == 1 ):
275- return v .item ()
342+ v = data [name ]
343+ if isinstance (v , np .generic ) or (isinstance (v , np .ndarray ) and v .size == 1 ):
344+ data [name ] = v .item ()
276345
277- return v
346+ return data
278347
279348 @classmethod
280349 def _get_type_value (cls , obj : dict [str , Any ]) -> str :
0 commit comments