Skip to content

Commit 3bf99ff

Browse files
committed
fix: rebase changes for v2
1 parent 65a9bfa commit 3bf99ff

7 files changed

Lines changed: 152 additions & 38 deletions

File tree

tests/test_package/test_log.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import numpy as np
88
import pytest
9-
from pydantic import ValidationError
9+
from pydantic import ValidationError, model_validator
1010

1111
import tidy3d as td
1212
from tidy3d.exceptions import Tidy3dError
@@ -255,6 +255,35 @@ def test_logging_warning_capture():
255255
assert str(error_without) == str(error_with)
256256

257257

258+
def test_warning_capture_during_model_validation():
259+
from tidy3d.components.base import Tidy3dBaseModel
260+
from tidy3d.log import log
261+
262+
class _CaptureChild(Tidy3dBaseModel):
263+
x: int
264+
265+
@model_validator(mode="after")
266+
def _warn_child(self):
267+
log.warning("child warning")
268+
return self
269+
270+
class _CaptureParent(Tidy3dBaseModel):
271+
child: _CaptureChild
272+
273+
@model_validator(mode="after")
274+
def _warn_parent(self):
275+
log.warning("parent warning")
276+
return self
277+
278+
td.log.set_capture(True)
279+
_CaptureParent(child={"x": 1})
280+
warning_list = td.log.captured_warnings()
281+
td.log.set_capture(False)
282+
283+
assert {"loc": [], "msg": "parent warning"} in warning_list
284+
assert {"loc": ["child"], "msg": "child warning"} in warning_list
285+
286+
258287
def test_log_suppression():
259288
with td.log as suppressed_log:
260289
assert td.log._counts is not None

tidy3d/components/base.py

Lines changed: 74 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
import rich
2323
import xarray as xr
2424
import yaml
25+
from autograd.numpy.numpy_boxes import ArrayBox
2526
from autograd.tracer import isbox
26-
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator
27+
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator, model_validator
28+
from pydantic.functional_validators import ModelWrapValidatorHandler
2729

2830
from tidy3d.exceptions import FileError
2931
from tidy3d.log import log
@@ -201,12 +203,6 @@ def _validate_name_no_special_characters(cls: type[T], name: Optional[str]) -> O
201203
)
202204
return name
203205

204-
def __init__(self, **kwargs: Any) -> None:
205-
"""Init method, includes post-init validators."""
206-
log.begin_capture()
207-
super().__init__(**kwargs)
208-
log.end_capture(self)
209-
210206
def __init_subclass__(cls: type[T], **kwargs: Any) -> None:
211207
"""Injects a constant discriminator field before Pydantic builds the model.
212208
@@ -224,6 +220,16 @@ def __init_subclass__(cls: type[T], **kwargs: Any) -> None:
224220
setattr(cls, TYPE_TAG_STR, tag)
225221
TYPE_TO_CLASS_MAP[tag] = cls
226222

223+
if "__tidy3d_end_capture__" not in cls.__dict__:
224+
225+
@model_validator(mode="after")
226+
def __tidy3d_end_capture__(self: T) -> T:
227+
if log._capture:
228+
log.end_capture(self)
229+
return self
230+
231+
cls.__tidy3d_end_capture__ = __tidy3d_end_capture__
232+
227233
super().__init_subclass__(**kwargs)
228234

229235
@classmethod
@@ -233,12 +239,69 @@ def __pydantic_init_subclass__(cls: type[T], **kwargs: Any) -> None:
233239
# add docstring once pydantic is done constructing the class
234240
cls.__doc__ = cls.generate_docstring()
235241

242+
@model_validator(mode="wrap")
243+
@classmethod
244+
def _capture_validation_warnings(
245+
cls: type[T],
246+
data: Any,
247+
handler: ModelWrapValidatorHandler[T],
248+
) -> T:
249+
if not log._capture:
250+
return handler(data)
251+
252+
log.begin_capture()
253+
try:
254+
return handler(data)
255+
except Exception:
256+
log.abort_capture()
257+
raise
258+
236259
def __hash__(self) -> int:
237260
"""Hash method."""
238-
try:
239-
return super().__hash__(self)
240-
except TypeError:
241-
return hash(self.model_dump_json())
261+
return self._recursive_hash(self)
262+
263+
@staticmethod
264+
def _recursive_hash(value: Any) -> int:
265+
# Handle Autograd ArrayBoxes
266+
if isinstance(value, ArrayBox):
267+
# Unwrap the underlying numpy array and recurse
268+
return Tidy3dBaseModel._recursive_hash(value._value)
269+
if isinstance(value, np.ndarray):
270+
# numpy arrays are not hashable by default, use byte representation
271+
v_hash = hashlib.md5(value.tobytes()).hexdigest()
272+
return hash(v_hash)
273+
if isinstance(value, (xr.DataArray, xr.Dataset)):
274+
# we choose to not hash data arrays as this would require a lot of careful handling of units, metadata.
275+
# technically this is incorrect, but should never lead to bugs in current implementation
276+
return hash(str(value.__class__.__name__))
277+
if isinstance(value, str):
278+
# this if-case is necessary because length-1 string would lead to infinite recursion in sequence case below
279+
return hash(value)
280+
if isinstance(value, Sequence):
281+
# this assumes all objects in lists are hashable by default and do not require special handling
282+
v_hash = tuple([Tidy3dBaseModel._recursive_hash(vi) for vi in value])
283+
return hash(v_hash)
284+
if isinstance(value, dict):
285+
to_hash_list = []
286+
for k, v in value.items():
287+
v_hash = Tidy3dBaseModel._recursive_hash(v)
288+
to_hash_list.append((k, v_hash))
289+
return hash(tuple(to_hash_list))
290+
if isinstance(value, Tidy3dBaseModel):
291+
# This function needs to take special care because of mutable attributes inside of frozen pydantic models
292+
to_hash_list = []
293+
for k, v in dict(value).items():
294+
if k == "attrs":
295+
continue
296+
v_hash = Tidy3dBaseModel._recursive_hash(v)
297+
to_hash_list.append((k, v_hash))
298+
# attrs is mutable, use serialized output as safe hashing option
299+
if value.attrs:
300+
attrs_str = value._attrs_digest()
301+
attrs_hash = hash(attrs_str)
302+
to_hash_list.append(("attrs", attrs_hash))
303+
return hash(tuple(to_hash_list))
304+
return hash(value)
242305

243306
def _hash_self(self) -> str:
244307
"""Hash this component with ``hashlib`` in a way that is the same every session."""

tidy3d/components/medium.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@
1111
import autograd.numpy as np
1212
import numpy as npo
1313
from autograd.differential_operators import tensor_jacobian_product
14+
from numpy.typing import NDArray
15+
from pydantic import (
16+
Field,
17+
NonNegativeFloat,
18+
PositiveFloat,
19+
PositiveInt,
20+
field_validator,
21+
model_validator,
22+
)
1423

1524
from tidy3d.components.autograd.utils import pack_complex_vec
1625
from tidy3d.constants import (
@@ -2419,11 +2428,11 @@ def _derivative_field_cmp_custom(
24192428
E_der_map: ElectromagneticFieldDataset,
24202429
spatial_data: SpatialDataArray,
24212430
dim: str,
2422-
freqs: NDArray,
2431+
freqs: ArrayFloat,
24232432
bounds: Optional[Bound] = None,
24242433
component: str = "real",
24252434
interp_method: Optional[InterpMethod] = None,
2426-
) -> NDArray:
2435+
) -> ArrayGeneric:
24272436
"""Compute the derivative with respect to a material property component."""
24282437
param_coords = {axis: np.asarray(spatial_data.coords[axis]) for axis in "xyz"}
24292438
eps_shape = [len(param_coords[axis]) for axis in "xyz"]
@@ -2609,7 +2618,18 @@ def _interp_axis(
26092618
elif component == "real":
26102619
values = values.real
26112620

2612-
return values.sum(axis=-1).reshape(eps_shape)
2621+
vjp_array = values.sum(axis=-1).reshape(eps_shape)
2622+
2623+
# match derivative dtype to the underlying dataset
2624+
target_array = getattr(spatial_data, "values", None)
2625+
if target_array is None and hasattr(spatial_data, "data"):
2626+
target_array = spatial_data.data
2627+
if target_array is not None:
2628+
target_dtype = np.asarray(target_array).dtype
2629+
if not np.issubdtype(target_dtype, np.complexfloating):
2630+
vjp_array = np.real(vjp_array).astype(target_dtype, copy=False)
2631+
2632+
return vjp_array
26132633

26142634

26152635
""" Dispersive Media """

tidy3d/components/simulation.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4258,19 +4258,6 @@ def _validate_low_freq_smoothing(self) -> Self:
42584258
)
42594259
return self
42604260

4261-
@model_validator(mode="after")
4262-
def _validate_scene(self) -> Self:
4263-
_ = self.scene
4264-
self._validate_no_structures_pml()
4265-
self._validate_tfsf_nonuniform_grid()
4266-
self._validate_tfsf_aux_sources()
4267-
self._validate_nonlinear_specs()
4268-
self._validate_custom_source_time()
4269-
self._validate_mode_objects()
4270-
self._warn_rf_license()
4271-
self._validate_internal_abc_no_fully_anisotropic()
4272-
return self
4273-
42744261
@model_validator(mode="after")
42754262
def _warn_rf_license(self) -> Self:
42764263
"""
@@ -4545,16 +4532,18 @@ def aux_fields(self) -> list[str]:
45454532
fields += medium.nonlinear_spec.aux_fields
45464533
return fields
45474534

4548-
def _validate_internal_abc_no_fully_anisotropic(self) -> None:
4535+
@model_validator(mode="after")
4536+
def _validate_internal_abc_no_fully_anisotropic(self) -> Self:
45494537
"""Error if internal absorber intersect fully anisotropic mediums."""
45504538

45514539
total_structures = [self.scene.background_structure, *list(self.structures)]
45524540

45534541
for abc in self._shifted_internal_absorbers:
4554-
mediums = Scene.intersecting_media(abc, total_structures)
4542+
mediums = Scene.intersecting_media(abc, tuple(total_structures))
45554543

45564544
if any(isinstance(med, FullyAnisotropicMedium) for med in mediums):
45574545
raise SetupError("A 'InternalAbsorber' cannot cross a 'FullyAnisotropicMedium'.")
4546+
return self
45584547

45594548
""" Pre submit validation (before web.upload()) """
45604549

tidy3d/log.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,9 @@ class Logger:
122122
highest level of the captures messages.
123123
124124
Messages can also be captured for post-processing. That can be enabled through 'set_capture' to
125-
record all warnings emitted during model validation. A structured copy of all validation
126-
messages can then be recovered through 'captured_warnings'.
125+
record warnings emitted during model validation (and other explicit begin/end capture regions,
126+
e.g. validation routines like ``validate_pre_upload``). A structured copy of captured warnings
127+
can then be recovered through 'captured_warnings'.
127128
"""
128129

129130
_static_cache = set()
@@ -178,8 +179,8 @@ def __exit__(
178179
def begin_capture(self) -> None:
179180
"""Start capturing log stack for consolidated validation log.
180181
181-
This method is used before any model validation starts and is included in the initialization
182-
of 'BaseModel'. It must be followed by a corresponding 'end_capture'.
182+
This method should be called before a validation routine starts. It must be followed by a
183+
corresponding 'end_capture'.
183184
"""
184185
if not self._capture:
185186
return
@@ -190,11 +191,23 @@ def begin_capture(self) -> None:
190191
else:
191192
self._stack = [stack_item]
192193

194+
def abort_capture(self) -> None:
195+
"""Undo the last ``begin_capture()`` call.
196+
197+
This is used when validation fails before reaching the corresponding ``end_capture()``.
198+
"""
199+
if not self._stack:
200+
return
201+
202+
self._stack.pop()
203+
if len(self._stack) == 0:
204+
self._stack = None
205+
193206
def end_capture(self, model: BaseModel) -> None:
194207
"""End capturing log stack for consolidated validation log.
195208
196-
This method is used after all model validations and is included in the initialization of
197-
'BaseModel'. It must follow a corresponding 'begin_capture'.
209+
This method should be called after a validation routine ends. It must follow a
210+
corresponding 'begin_capture'.
198211
"""
199212
if not self._stack:
200213
return

tidy3d/plugins/expressions/metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def generate_validation_data(expr: Expression) -> dict[str, xr.Dataset]:
3232
dict[str, xr.Dataset]
3333
The combined validation data.
3434
"""
35-
metrics = set(expr.filter(target_type=Metric))
35+
metrics = set(expr.filter(target_type=Metric)) # type: ignore[type-abstract]
3636
combined_data = {k: v for metric in metrics for k, v in metric._validation_data.items()}
3737
return combined_data
3838

tidy3d/plugins/invdes/region.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ class TopologyDesignRegion(DesignRegion):
181181
"Supplying ``False`` will completely leave out the override structure.",
182182
)
183183

184-
priority: int = pd.Field(
184+
priority: Optional[int] = Field(
185185
None,
186186
title="Priority",
187187
description="Priority of the structure applied in structure overlapping region. "

0 commit comments

Comments
 (0)