2222import rich
2323import xarray as xr
2424import yaml
25+ from autograd .numpy .numpy_boxes import ArrayBox
2526from autograd .tracer import isbox
26- from pydantic import BaseModel , ConfigDict , Field , PrivateAttr , field_validator
27+ from pydantic import BaseModel , ConfigDict , Field , PrivateAttr , field_validator , model_validator
28+ from pydantic .functional_validators import ModelWrapValidatorHandler
2729
2830from tidy3d .exceptions import FileError
2931from tidy3d .log import log
@@ -201,12 +203,6 @@ def _validate_name_no_special_characters(cls: type[T], name: Optional[str]) -> O
201203 )
202204 return name
203205
204- def __init__ (self , ** kwargs : Any ) -> None :
205- """Init method, includes post-init validators."""
206- log .begin_capture ()
207- super ().__init__ (** kwargs )
208- log .end_capture (self )
209-
210206 def __init_subclass__ (cls : type [T ], ** kwargs : Any ) -> None :
211207 """Injects a constant discriminator field before Pydantic builds the model.
212208
@@ -224,6 +220,16 @@ def __init_subclass__(cls: type[T], **kwargs: Any) -> None:
224220 setattr (cls , TYPE_TAG_STR , tag )
225221 TYPE_TO_CLASS_MAP [tag ] = cls
226222
223+ if "__tidy3d_end_capture__" not in cls .__dict__ :
224+
225+ @model_validator (mode = "after" )
226+ def __tidy3d_end_capture__ (self : T ) -> T :
227+ if log ._capture :
228+ log .end_capture (self )
229+ return self
230+
231+ cls .__tidy3d_end_capture__ = __tidy3d_end_capture__
232+
227233 super ().__init_subclass__ (** kwargs )
228234
229235 @classmethod
@@ -233,12 +239,69 @@ def __pydantic_init_subclass__(cls: type[T], **kwargs: Any) -> None:
233239 # add docstring once pydantic is done constructing the class
234240 cls .__doc__ = cls .generate_docstring ()
235241
242+ @model_validator (mode = "wrap" )
243+ @classmethod
244+ def _capture_validation_warnings (
245+ cls : type [T ],
246+ data : Any ,
247+ handler : ModelWrapValidatorHandler [T ],
248+ ) -> T :
249+ if not log ._capture :
250+ return handler (data )
251+
252+ log .begin_capture ()
253+ try :
254+ return handler (data )
255+ except Exception :
256+ log .abort_capture ()
257+ raise
258+
236259 def __hash__ (self ) -> int :
237260 """Hash method."""
238- try :
239- return super ().__hash__ (self )
240- except TypeError :
241- return hash (self .model_dump_json ())
261+ return self ._recursive_hash (self )
262+
263+ @staticmethod
264+ def _recursive_hash (value : Any ) -> int :
265+ # Handle Autograd ArrayBoxes
266+ if isinstance (value , ArrayBox ):
267+ # Unwrap the underlying numpy array and recurse
268+ return Tidy3dBaseModel ._recursive_hash (value ._value )
269+ if isinstance (value , np .ndarray ):
270+ # numpy arrays are not hashable by default, use byte representation
271+ v_hash = hashlib .md5 (value .tobytes ()).hexdigest ()
272+ return hash (v_hash )
273+ if isinstance (value , (xr .DataArray , xr .Dataset )):
274+ # we choose to not hash data arrays as this would require a lot of careful handling of units, metadata.
275+ # technically this is incorrect, but should never lead to bugs in current implementation
276+ return hash (str (value .__class__ .__name__ ))
277+ if isinstance (value , str ):
278+ # this if-case is necessary because length-1 string would lead to infinite recursion in sequence case below
279+ return hash (value )
280+ if isinstance (value , Sequence ):
281+ # this assumes all objects in lists are hashable by default and do not require special handling
282+ v_hash = tuple ([Tidy3dBaseModel ._recursive_hash (vi ) for vi in value ])
283+ return hash (v_hash )
284+ if isinstance (value , dict ):
285+ to_hash_list = []
286+ for k , v in value .items ():
287+ v_hash = Tidy3dBaseModel ._recursive_hash (v )
288+ to_hash_list .append ((k , v_hash ))
289+ return hash (tuple (to_hash_list ))
290+ if isinstance (value , Tidy3dBaseModel ):
291+ # This function needs to take special care because of mutable attributes inside of frozen pydantic models
292+ to_hash_list = []
293+ for k , v in dict (value ).items ():
294+ if k == "attrs" :
295+ continue
296+ v_hash = Tidy3dBaseModel ._recursive_hash (v )
297+ to_hash_list .append ((k , v_hash ))
298+ # attrs is mutable, use serialized output as safe hashing option
299+ if value .attrs :
300+ attrs_str = value ._attrs_digest ()
301+ attrs_hash = hash (attrs_str )
302+ to_hash_list .append (("attrs" , attrs_hash ))
303+ return hash (tuple (to_hash_list ))
304+ return hash (value )
242305
243306 def _hash_self (self ) -> str :
244307 """Hash this component with ``hashlib`` in a way that is the same every session."""
0 commit comments