Skip to content

Commit 7619bba

Browse files
committed
fix: rebase changes for v2
1 parent 514ebe7 commit 7619bba

8 files changed

Lines changed: 172 additions & 52 deletions

File tree

tests/test_package/test_log.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import numpy as np
88
import pytest
9-
from pydantic import ValidationError
9+
from pydantic import ValidationError, model_validator
1010

1111
import tidy3d as td
1212
from tidy3d.exceptions import Tidy3dError
@@ -255,6 +255,35 @@ def test_logging_warning_capture():
255255
assert str(error_without) == str(error_with)
256256

257257

258+
def test_warning_capture_during_model_validation():
259+
from tidy3d.components.base import Tidy3dBaseModel
260+
from tidy3d.log import log
261+
262+
class _CaptureChild(Tidy3dBaseModel):
263+
x: int
264+
265+
@model_validator(mode="after")
266+
def _warn_child(self):
267+
log.warning("child warning")
268+
return self
269+
270+
class _CaptureParent(Tidy3dBaseModel):
271+
child: _CaptureChild
272+
273+
@model_validator(mode="after")
274+
def _warn_parent(self):
275+
log.warning("parent warning")
276+
return self
277+
278+
td.log.set_capture(True)
279+
_CaptureParent(child={"x": 1})
280+
warning_list = td.log.captured_warnings()
281+
td.log.set_capture(False)
282+
283+
assert {"loc": [], "msg": "parent warning"} in warning_list
284+
assert {"loc": ["child"], "msg": "child warning"} in warning_list
285+
286+
258287
def test_log_suppression():
259288
with td.log as suppressed_log:
260289
assert td.log._counts is not None

tidy3d/components/autograd/functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def add_at(x: NDArray, indices_x: tuple, y: NDArray) -> NDArray:
262262

263263

264264
@primitive
265-
def _straight_through_clip(x, a_min, a_max):
265+
def _straight_through_clip(x: NDArray, a_min: Any, a_max: Any) -> NDArray:
266266
"""Passthrough clip can be used to preserve gradients at the endpoints of the clip range where
267267
there is a discontinuity in the derivative. This is useful when values are at the endpoints but may
268268
have a gradient away from the boundary or in cases where numerical precision causes a function that is
@@ -271,7 +271,7 @@ def _straight_through_clip(x, a_min, a_max):
271271
return anp.clip(x, a_min=a_min, a_max=a_max)
272272

273273

274-
def _straight_through_clip_vjp(ans, x, a_min, a_max):
274+
def _straight_through_clip_vjp(ans: Any, x: NDArray, a_min: Any, a_max: Any) -> NDArray:
275275
"""Preserve original gradient information in the backward pass up until a tolerance beyond the clip bounds."""
276276
tolerance = 1e-5
277277
mask = (x >= a_min - tolerance) & (x <= a_max + tolerance)

tidy3d/components/base.py

Lines changed: 92 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,17 @@
1616
from os import PathLike
1717
from pathlib import Path
1818
from types import UnionType
19-
from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union, get_args, get_origin, get_args, get_origin
19+
from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union, get_args, get_origin
2020

2121
import h5py
2222
import numpy as np
2323
import rich
2424
import xarray as xr
2525
import yaml
26+
from autograd.numpy.numpy_boxes import ArrayBox
2627
from autograd.tracer import isbox
27-
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator
28+
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator, model_validator
29+
from pydantic.functional_validators import ModelWrapValidatorHandler
2830

2931
from tidy3d.exceptions import FileError
3032
from tidy3d.log import log
@@ -155,8 +157,8 @@ def _fmt_ann_literal(ann: Any) -> str:
155157
T = TypeVar("T", bound="Tidy3dBaseModel")
156158

157159

158-
def field_allows_scalar(field: ModelField) -> bool:
159-
annotation = field.outer_type_
160+
def field_allows_scalar(field: FieldInfo) -> bool:
161+
annotation = field.annotation
160162

161163
def allows_scalar(a: Any) -> bool:
162164
origin = get_origin(a)
@@ -169,6 +171,7 @@ def allows_scalar(a: Any) -> bool:
169171

170172
return allows_scalar(annotation)
171173

174+
172175
@total_ordering
173176
class Tidy3dBaseModel(BaseModel):
174177
"""Base pydantic model that all Tidy3d components inherit from.
@@ -216,12 +219,6 @@ def _validate_name_no_special_characters(cls: type[T], name: Optional[str]) -> O
216219
)
217220
return name
218221

219-
def __init__(self, **kwargs: Any) -> None:
220-
"""Init method, includes post-init validators."""
221-
log.begin_capture()
222-
super().__init__(**kwargs)
223-
log.end_capture(self)
224-
225222
def __init_subclass__(cls: type[T], **kwargs: Any) -> None:
226223
"""Injects a constant discriminator field before Pydantic builds the model.
227224
@@ -239,6 +236,16 @@ def __init_subclass__(cls: type[T], **kwargs: Any) -> None:
239236
setattr(cls, TYPE_TAG_STR, tag)
240237
TYPE_TO_CLASS_MAP[tag] = cls
241238

239+
if "__tidy3d_end_capture__" not in cls.__dict__:
240+
241+
@model_validator(mode="after")
242+
def __tidy3d_end_capture__(self: T) -> T:
243+
if log._capture:
244+
log.end_capture(self)
245+
return self
246+
247+
cls.__tidy3d_end_capture__ = __tidy3d_end_capture__
248+
242249
super().__init_subclass__(**kwargs)
243250

244251
@classmethod
@@ -248,33 +255,95 @@ def __pydantic_init_subclass__(cls: type[T], **kwargs: Any) -> None:
248255
# add docstring once pydantic is done constructing the class
249256
cls.__doc__ = cls.generate_docstring()
250257

258+
@model_validator(mode="wrap")
259+
@classmethod
260+
def _capture_validation_warnings(
261+
cls: type[T],
262+
data: Any,
263+
handler: ModelWrapValidatorHandler[T],
264+
) -> T:
265+
if not log._capture:
266+
return handler(data)
267+
268+
log.begin_capture()
269+
try:
270+
return handler(data)
271+
except Exception:
272+
log.abort_capture()
273+
raise
274+
251275
def __hash__(self) -> int:
252276
"""Hash method."""
253-
try:
254-
return super().__hash__(self)
255-
except TypeError:
256-
return hash(self.model_dump_json())
277+
return self._recursive_hash(self)
278+
279+
@staticmethod
280+
def _recursive_hash(value: Any) -> int:
281+
# Handle Autograd ArrayBoxes
282+
if isinstance(value, ArrayBox):
283+
# Unwrap the underlying numpy array and recurse
284+
return Tidy3dBaseModel._recursive_hash(value._value)
285+
if isinstance(value, np.ndarray):
286+
# numpy arrays are not hashable by default, use byte representation
287+
v_hash = hashlib.md5(value.tobytes()).hexdigest()
288+
return hash(v_hash)
289+
if isinstance(value, (xr.DataArray, xr.Dataset)):
290+
# we choose to not hash data arrays as this would require a lot of careful handling of units, metadata.
291+
# technically this is incorrect, but should never lead to bugs in current implementation
292+
return hash(str(value.__class__.__name__))
293+
if isinstance(value, str):
294+
# this if-case is necessary because length-1 string would lead to infinite recursion in sequence case below
295+
return hash(value)
296+
if isinstance(value, Sequence):
297+
# this assumes all objects in lists are hashable by default and do not require special handling
298+
v_hash = tuple([Tidy3dBaseModel._recursive_hash(vi) for vi in value])
299+
return hash(v_hash)
300+
if isinstance(value, dict):
301+
to_hash_list = []
302+
for k, v in value.items():
303+
v_hash = Tidy3dBaseModel._recursive_hash(v)
304+
to_hash_list.append((k, v_hash))
305+
return hash(tuple(to_hash_list))
306+
if isinstance(value, Tidy3dBaseModel):
307+
# This function needs to take special care because of mutable attributes inside of frozen pydantic models
308+
to_hash_list = []
309+
for k, v in dict(value).items():
310+
if k == "attrs":
311+
continue
312+
v_hash = Tidy3dBaseModel._recursive_hash(v)
313+
to_hash_list.append((k, v_hash))
314+
# attrs is mutable, use serialized output as safe hashing option
315+
if value.attrs:
316+
attrs_str = value._attrs_digest()
317+
attrs_hash = hash(attrs_str)
318+
to_hash_list.append(("attrs", attrs_hash))
319+
return hash(tuple(to_hash_list))
320+
return hash(value)
257321

258322
def _hash_self(self) -> str:
259323
"""Hash this component with ``hashlib`` in a way that is the same every session."""
260324
bf = io.BytesIO()
261325
self.to_hdf5(bf)
262326
return hashlib.md5(bf.getvalue()).hexdigest()
263327

264-
265-
@pydantic.validator("*", pre=True, allow_reuse=True)
266-
def coerce_numpy_scalars_for_model(cls, v: Any, field: ModelField) -> Any:
328+
@model_validator(mode="before")
329+
@classmethod
330+
def coerce_numpy_scalars_for_model(cls, data: Any) -> Any:
267331
"""
268-
Wildcard field validator: coerce numpy scalars / size-1 arrays to native Python
332+
coerce numpy scalars / size-1 arrays to native Python
269333
scalars, but only for fields whose annotations allow scalars.
270334
"""
271-
if not field_allows_scalar(field):
272-
return v
335+
if not isinstance(data, dict):
336+
return data
337+
338+
for name, field in cls.model_fields.items():
339+
if name not in data or not field_allows_scalar(field):
340+
continue
273341

274-
if isinstance(v, np.generic) or (isinstance(v, np.ndarray) and v.size == 1):
275-
return v.item()
342+
v = data[name]
343+
if isinstance(v, np.generic) or (isinstance(v, np.ndarray) and v.size == 1):
344+
data[name] = v.item()
276345

277-
return v
346+
return data
278347

279348
@classmethod
280349
def _get_type_value(cls, obj: dict[str, Any]) -> str:

tidy3d/components/medium.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@
1111
import autograd.numpy as np
1212
import numpy as npo
1313
from autograd.differential_operators import tensor_jacobian_product
14+
from numpy.typing import NDArray
15+
from pydantic import (
16+
Field,
17+
NonNegativeFloat,
18+
PositiveFloat,
19+
PositiveInt,
20+
field_validator,
21+
model_validator,
22+
)
1423

1524
from tidy3d.components.autograd.utils import pack_complex_vec
1625
from tidy3d.constants import (
@@ -2419,11 +2428,11 @@ def _derivative_field_cmp_custom(
24192428
E_der_map: ElectromagneticFieldDataset,
24202429
spatial_data: SpatialDataArray,
24212430
dim: str,
2422-
freqs: NDArray,
2431+
freqs: ArrayFloat,
24232432
bounds: Optional[Bound] = None,
24242433
component: str = "real",
24252434
interp_method: Optional[InterpMethod] = None,
2426-
) -> NDArray:
2435+
) -> ArrayGeneric:
24272436
"""Compute the derivative with respect to a material property component."""
24282437
param_coords = {axis: np.asarray(spatial_data.coords[axis]) for axis in "xyz"}
24292438
eps_shape = [len(param_coords[axis]) for axis in "xyz"]
@@ -2609,7 +2618,18 @@ def _interp_axis(
26092618
elif component == "real":
26102619
values = values.real
26112620

2612-
return values.sum(axis=-1).reshape(eps_shape)
2621+
vjp_array = values.sum(axis=-1).reshape(eps_shape)
2622+
2623+
# match derivative dtype to the underlying dataset
2624+
target_array = getattr(spatial_data, "values", None)
2625+
if target_array is None and hasattr(spatial_data, "data"):
2626+
target_array = spatial_data.data
2627+
if target_array is not None:
2628+
target_dtype = np.asarray(target_array).dtype
2629+
if not np.issubdtype(target_dtype, np.complexfloating):
2630+
vjp_array = np.real(vjp_array).astype(target_dtype, copy=False)
2631+
2632+
return vjp_array
26132633

26142634

26152635
""" Dispersive Media """

tidy3d/components/simulation.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4258,19 +4258,6 @@ def _validate_low_freq_smoothing(self) -> Self:
42584258
)
42594259
return self
42604260

4261-
@model_validator(mode="after")
4262-
def _validate_scene(self) -> Self:
4263-
_ = self.scene
4264-
self._validate_no_structures_pml()
4265-
self._validate_tfsf_nonuniform_grid()
4266-
self._validate_tfsf_aux_sources()
4267-
self._validate_nonlinear_specs()
4268-
self._validate_custom_source_time()
4269-
self._validate_mode_objects()
4270-
self._warn_rf_license()
4271-
self._validate_internal_abc_no_fully_anisotropic()
4272-
return self
4273-
42744261
@model_validator(mode="after")
42754262
def _warn_rf_license(self) -> Self:
42764263
"""
@@ -4545,16 +4532,18 @@ def aux_fields(self) -> list[str]:
45454532
fields += medium.nonlinear_spec.aux_fields
45464533
return fields
45474534

4548-
def _validate_internal_abc_no_fully_anisotropic(self) -> None:
4535+
@model_validator(mode="after")
4536+
def _validate_internal_abc_no_fully_anisotropic(self) -> Self:
45494537
"""Error if internal absorber intersect fully anisotropic mediums."""
45504538

45514539
total_structures = [self.scene.background_structure, *list(self.structures)]
45524540

45534541
for abc in self._shifted_internal_absorbers:
4554-
mediums = Scene.intersecting_media(abc, total_structures)
4542+
mediums = Scene.intersecting_media(abc, tuple(total_structures))
45554543

45564544
if any(isinstance(med, FullyAnisotropicMedium) for med in mediums):
45574545
raise SetupError("A 'InternalAbsorber' cannot cross a 'FullyAnisotropicMedium'.")
4546+
return self
45584547

45594548
""" Pre submit validation (before web.upload()) """
45604549

tidy3d/log.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,9 @@ class Logger:
122122
highest level of the captures messages.
123123
124124
Messages can also be captured for post-processing. That can be enabled through 'set_capture' to
125-
record all warnings emitted during model validation. A structured copy of all validation
126-
messages can then be recovered through 'captured_warnings'.
125+
record warnings emitted during model validation (and other explicit begin/end capture regions,
126+
e.g. validation routines like ``validate_pre_upload``). A structured copy of captured warnings
127+
can then be recovered through 'captured_warnings'.
127128
"""
128129

129130
_static_cache = set()
@@ -178,8 +179,8 @@ def __exit__(
178179
def begin_capture(self) -> None:
179180
"""Start capturing log stack for consolidated validation log.
180181
181-
This method is used before any model validation starts and is included in the initialization
182-
of 'BaseModel'. It must be followed by a corresponding 'end_capture'.
182+
This method should be called before a validation routine starts. It must be followed by a
183+
corresponding 'end_capture'.
183184
"""
184185
if not self._capture:
185186
return
@@ -190,11 +191,23 @@ def begin_capture(self) -> None:
190191
else:
191192
self._stack = [stack_item]
192193

194+
def abort_capture(self) -> None:
195+
"""Undo the last ``begin_capture()`` call.
196+
197+
This is used when validation fails before reaching the corresponding ``end_capture()``.
198+
"""
199+
if not self._stack:
200+
return
201+
202+
self._stack.pop()
203+
if len(self._stack) == 0:
204+
self._stack = None
205+
193206
def end_capture(self, model: BaseModel) -> None:
194207
"""End capturing log stack for consolidated validation log.
195208
196-
This method is used after all model validations and is included in the initialization of
197-
'BaseModel'. It must follow a corresponding 'begin_capture'.
209+
This method should be called after a validation routine ends. It must follow a
210+
corresponding 'begin_capture'.
198211
"""
199212
if not self._stack:
200213
return

tidy3d/plugins/expressions/metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def generate_validation_data(expr: Expression) -> dict[str, xr.Dataset]:
3232
dict[str, xr.Dataset]
3333
The combined validation data.
3434
"""
35-
metrics = set(expr.filter(target_type=Metric))
35+
metrics = set(expr.filter(target_type=Metric)) # type: ignore[type-abstract]
3636
combined_data = {k: v for metric in metrics for k, v in metric._validation_data.items()}
3737
return combined_data
3838

tidy3d/plugins/invdes/region.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ class TopologyDesignRegion(DesignRegion):
181181
"Supplying ``False`` will completely leave out the override structure.",
182182
)
183183

184-
priority: int = pd.Field(
184+
priority: Optional[int] = Field(
185185
None,
186186
title="Priority",
187187
description="Priority of the structure applied in structure overlapping region. "

0 commit comments

Comments
 (0)