Skip to content

Commit 5000801

Browse files
committed
(tentative): fix warnings caused by model_copy()
1 parent 09b3ec7 commit 5000801

8 files changed

Lines changed: 230 additions & 25 deletions

File tree

tests/test_components/test_simulation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3667,12 +3667,12 @@ def test_messages_contain_object_names():
36673667
polarization="Ex",
36683668
source_time=td.GaussianPulse(freq0=100e14, fwidth=10e14),
36693669
)
3670-
with pytest.raises(pydantic.ValidationError, match=name) as e:
3670+
with pytest.raises(ValidationError, match=name) as e:
36713671
_ = sim.updated_copy(sources=[source])
36723672

36733673
# Test 3) Create a monitor lying outside the simulation boundary.
36743674
# Check that an error message is generated containing the monitor's `name`.
36753675
name = "monitor_123"
36763676
monitor = td.FieldMonitor(name=name, center=(-1.0, 0, 0), size=(0.5, 0, 1), freqs=[100e14])
3677-
with pytest.raises(pydantic.ValidationError, match=name) as e:
3677+
with pytest.raises(ValidationError, match=name) as e:
36783678
_ = sim.updated_copy(monitors=[monitor])

tests/test_plugins/test_invdes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,12 +273,12 @@ def test_invdes_multi_same_length():
273273
invdes = make_invdes_multi()
274274
n = len(invdes.simulations)
275275

276-
output_monitor_names = (n + 1) * [["test"]]
276+
output_monitor_names = (n + 1) * [("test",)]
277277

278278
with pytest.raises(ValueError):
279279
_ = invdes.updated_copy(output_monitor_names=output_monitor_names)
280280

281-
output_monitor_names = [([MNT_NAME1, MNT_NAME2], None)[i % 2] for i in range(n)]
281+
output_monitor_names = [((MNT_NAME1, MNT_NAME2), None)[i % 2] for i in range(n)]
282282
invdes = invdes.updated_copy(output_monitor_names=output_monitor_names)
283283

284284
_ = invdes.designs

tidy3d/components/autograd/types.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Annotated, Literal, Optional, Union, get_origin
77

88
import autograd.numpy as anp
9-
from autograd.builtins import dict as TracedDict
9+
from autograd.builtins import dict as _AutogradTracedDict
1010
from autograd.extend import Box, defvjp, primitive
1111
from pydantic import BeforeValidator, PlainSerializer, PositiveFloat, TypeAdapter
1212

@@ -16,6 +16,30 @@
1616

1717
from .utils import get_static, hasbox
1818

19+
20+
class TracedDict(_AutogradTracedDict):
21+
"""Custom TracedDict that preserves tuple keys instead of converting them to lists."""
22+
23+
def __init__(self, *args, **kwargs):
24+
"""Initialize TracedDict while preserving tuple keys."""
25+
if args and isinstance(args[0], dict):
26+
# Convert any list keys back to tuples
27+
fixed_dict = {}
28+
for key, value in args[0].items():
29+
if isinstance(key, list):
30+
key = tuple(key)
31+
fixed_dict[key] = value
32+
super().__init__(fixed_dict, **kwargs)
33+
else:
34+
super().__init__(*args, **kwargs)
35+
36+
def __setitem__(self, key, value):
37+
"""Ensure keys remain as tuples when setting items."""
38+
if isinstance(key, list):
39+
key = tuple(key)
40+
super().__setitem__(key, value)
41+
42+
1943
# add schema to the Box
2044
_add_schema(Box, title="AutogradBox", field_type_str="autograd.tracer.Box")
2145

@@ -114,6 +138,7 @@ def _validate_box_or_container(v):
114138
"TracedArrayLike",
115139
"TracedComplex",
116140
"TracedCoordinate",
141+
"TracedDict",
117142
"TracedFloat",
118143
"TracedPoleAndResidue",
119144
"TracedPolesAndResidues",

tidy3d/components/base.py

Lines changed: 141 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from functools import total_ordering, wraps
1515
from math import ceil
1616
from pathlib import Path
17-
from typing import Any, Callable, Literal, Optional, TypeVar, Union
17+
from typing import Any, Callable, Literal, Optional, TypeVar, Union, get_args, get_origin
1818

1919
import h5py
2020
import numpy as np
@@ -189,6 +189,131 @@ def _hash_self(self) -> str:
189189
self.to_hdf5(bf)
190190
return hashlib.sha256(bf.getvalue()).hexdigest()
191191

192+
def _preprocess_update_values(self, update: Mapping[str, Any]) -> dict[str, Any]:
193+
"""Preprocess update values to convert lists to tuples where appropriate.
194+
195+
This helps avoid Pydantic v2 serialization warnings when using `model_copy()`
196+
with list values for tuple fields.
197+
"""
198+
if not update:
199+
return {}
200+
201+
def get_tuple_element_type(annotation) -> Optional[type]:
202+
"""Get the element type of a tuple annotation if it has one consistent type."""
203+
origin = get_origin(annotation)
204+
if origin is tuple:
205+
args = get_args(annotation)
206+
if args:
207+
# Check if it's a homogeneous tuple like tuple[bool, ...] or tuple[str, ...]
208+
if len(args) == 2 and args[1] is ...:
209+
return args[0]
210+
# Check if all elements have the same type
211+
if all(arg == args[0] for arg in args):
212+
return args[0]
213+
return None
214+
215+
def should_convert_to_tuple(annotation) -> tuple[bool, Optional[type]]:
216+
"""Check if the given annotation represents a tuple type and return element type if any."""
217+
origin = get_origin(annotation)
218+
219+
if origin is tuple:
220+
return True, get_tuple_element_type(annotation)
221+
222+
# Union types containing tuple
223+
if origin is Union:
224+
args = get_args(annotation)
225+
for arg in args:
226+
if get_origin(arg) is tuple:
227+
return True, get_tuple_element_type(arg)
228+
229+
return False, None
230+
231+
def convert_value(value: Any, field_info) -> Any:
232+
"""Convert value based on field type information."""
233+
annotation = field_info.annotation
234+
235+
# Handle list/tuple to tuple conversion with proper element types
236+
is_tuple, element_type = should_convert_to_tuple(annotation)
237+
238+
# Check if value is a numpy array and needs to be converted to tuple
239+
try:
240+
import numpy as np
241+
242+
if isinstance(value, np.ndarray) and is_tuple:
243+
# Convert numpy array to list first
244+
value = value.tolist()
245+
except ImportError:
246+
pass
247+
248+
# Handle autograd SequenceBox - convert to tuple
249+
if (
250+
is_tuple
251+
and hasattr(value, "__class__")
252+
and value.__class__.__name__ == "SequenceBox"
253+
):
254+
# SequenceBox is iterable, so convert it to tuple
255+
return tuple(value)
256+
257+
if isinstance(value, (list, tuple)) and is_tuple:
258+
# Convert elements based on element type
259+
if element_type is bool:
260+
# Convert integers to booleans
261+
value = [bool(item) if isinstance(item, int) else item for item in value]
262+
elif element_type is str:
263+
# Ensure all elements are strings
264+
value = [str(item) if not isinstance(item, str) else item for item in value]
265+
else:
266+
# Check if it's a numpy array or contains numpy types
267+
try:
268+
import numpy as np
269+
270+
if any(isinstance(item, np.generic) for item in value):
271+
# Convert numpy types to Python types
272+
value = [
273+
item.item() if isinstance(item, np.generic) else item
274+
for item in value
275+
]
276+
except ImportError:
277+
pass
278+
return tuple(value)
279+
280+
# Handle int to bool conversion
281+
if annotation is bool and isinstance(value, int):
282+
return bool(value)
283+
284+
# Handle dict to Tidy3dBaseModel conversion
285+
if isinstance(value, dict):
286+
# Check if the annotation is a Tidy3dBaseModel subclass
287+
origin = get_origin(annotation)
288+
if origin is None:
289+
# Not a generic type, check if it's a direct subclass
290+
try:
291+
if isinstance(annotation, type) and issubclass(annotation, Tidy3dBaseModel):
292+
return annotation(**value)
293+
except (TypeError, AttributeError):
294+
pass
295+
elif origin is Union:
296+
# For Union types, try to convert to the first matching Tidy3dBaseModel type
297+
args = get_args(annotation)
298+
for arg in args:
299+
try:
300+
if isinstance(arg, type) and issubclass(arg, Tidy3dBaseModel):
301+
return arg(**value)
302+
except (TypeError, AttributeError, ValueError):
303+
continue
304+
305+
return value
306+
307+
processed = {}
308+
for field_name, value in update.items():
309+
if field_name in self.model_fields:
310+
field_info = self.model_fields[field_name]
311+
processed[field_name] = convert_value(value, field_info)
312+
else:
313+
processed[field_name] = value
314+
315+
return processed
316+
192317
def copy(
193318
self, *, deep: bool = True, validate: bool = True, update: Mapping[str, Any] | None = None
194319
) -> Self:
@@ -209,6 +334,10 @@ def copy(
209334
if invalid:
210335
raise KeyError(f"'{self.type}' received invalid fields on copy: {invalid}")
211336

337+
# preprocess update values to convert lists to tuples where appropriate
338+
if update:
339+
update = self._preprocess_update_values(update)
340+
212341
new_model = self.model_copy(deep=deep, update=update)
213342

214343
if validate:
@@ -275,8 +404,8 @@ def updated_copy(
275404

276405
@staticmethod
277406
def _core_model_traversal(
278-
current_obj: Any, current_path_segments: list[str]
279-
) -> Iterator[tuple[Self, list[str]]]:
407+
current_obj: Any, current_path_segments: tuple[str, ...]
408+
) -> Iterator[tuple[Self, tuple[str, ...]]]:
280409
"""
281410
Recursively traverses a model structure yielding Tidy3dBaseModel instances and their paths.
282411
@@ -288,16 +417,16 @@ def _core_model_traversal(
288417
current_obj : Any
289418
The current object in the traversal, which can be a :class:`Tidy3dBaseModel`,
290419
list, tuple, or other type.
291-
current_path_segments : list[str]
292-
A list of strings representing the path segments from the initial model
420+
current_path_segments : tuple[str, ...]
421+
A tuple of strings representing the path segments from the initial model
293422
to the ``current_obj``.
294423
295424
Returns
296425
-------
297-
Iterator[tuple[Self, list[str]]]
426+
Iterator[tuple[Self, tuple[str, ...]]]
298427
An iterator yielding tuples, where the first element is a found :class:`Tidy3dBaseModel` instance
299-
and the second is a list of strings representing the path to that instance
300-
from the initial object. The path for the top-level model itself will be an empty list.
428+
and the second is a tuple of strings representing the path to that instance
429+
from the initial object. The path for the top-level model itself will be an empty tuple.
301430
"""
302431
if isinstance(current_obj, Tidy3dBaseModel):
303432
yield current_obj, current_path_segments
@@ -311,12 +440,12 @@ def _core_model_traversal(
311440

312441
field_value = getattr(current_obj, field_name)
313442
yield from Tidy3dBaseModel._core_model_traversal(
314-
field_value, [*current_path_segments, field_name]
443+
field_value, (*current_path_segments, field_name)
315444
)
316445
elif isinstance(current_obj, (list, tuple)):
317446
for index, item in enumerate(current_obj):
318447
yield from Tidy3dBaseModel._core_model_traversal(
319-
item, [*current_path_segments, str(index)]
448+
item, (*current_path_segments, str(index))
320449
)
321450

322451
def find_paths(self, target_field_name: str, target_field_value: Any = Undefined) -> list[str]:
@@ -357,7 +486,7 @@ def find_paths(self, target_field_name: str, target_field_value: Any = Undefined
357486
found_paths_set = set()
358487

359488
for sub_model_instance, path_segments_to_sub_model in Tidy3dBaseModel._core_model_traversal(
360-
self, []
489+
self, ()
361490
):
362491
if target_field_name in sub_model_instance.model_fields:
363492
passes_value_filter = True
@@ -415,7 +544,7 @@ def find_submodels(self, target_type: Self) -> list[Self]:
415544
"""
416545
found_models_dict = {}
417546

418-
for sub_model_candidate, _ in Tidy3dBaseModel._core_model_traversal(self, []):
547+
for sub_model_candidate, _ in Tidy3dBaseModel._core_model_traversal(self, ()):
419548
if isinstance(sub_model_candidate, target_type):
420549
if sub_model_candidate not in found_models_dict:
421550
found_models_dict[sub_model_candidate] = True

tidy3d/components/boundary.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@
66
from typing import Optional, Union
77

88
import numpy as np
9-
from pydantic import Field, NonNegativeFloat, NonNegativeInt, field_validator, model_validator
9+
from pydantic import (
10+
Field,
11+
NonNegativeFloat,
12+
NonNegativeInt,
13+
field_validator,
14+
model_validator,
15+
)
1016

1117
from tidy3d.compat import Self
1218
from tidy3d.constants import EPSILON_0, MU_0, PML_SIGMA
@@ -834,6 +840,14 @@ class BoundarySpec(Tidy3dBaseModel):
834840
"so explicitly setting the boundaries is recommended.",
835841
)
836842

843+
@field_validator("x", "y", "z", mode="before")
844+
@classmethod
845+
def dict_to_boundary(cls, v):
846+
"""Convert dict representation to Boundary object if needed."""
847+
if isinstance(v, dict) and "plus" in v and "minus" in v:
848+
return Boundary(**v)
849+
return v
850+
837851
def __getitem__(self, field_name: str) -> Boundary:
838852
"""Get the :class:`Boundary` field by name (``boundary_spec[field_name]``).
839853

tidy3d/components/source/freq_range.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from __future__ import annotations
44

55
import numpy as np
6-
import pydantic.v1 as pydantic
76
from numpy.typing import NDArray
7+
from pydantic import Field, PositiveFloat
88

99
from tidy3d import constants as td_const
1010
from tidy3d.components.base import Tidy3dBaseModel
@@ -35,15 +35,13 @@ class FreqRange(Tidy3dBaseModel):
3535
>>> source = freq_range.to_gaussian_pulse()
3636
"""
3737

38-
freq0: pydantic.PositiveFloat = pydantic.Field(
39-
...,
38+
freq0: PositiveFloat = Field(
4039
title="Central frequency",
4140
description="Real-valued positive central frequency.",
4241
units="Hz",
4342
)
4443

45-
fwidth: pydantic.PositiveFloat = pydantic.Field(
46-
...,
44+
fwidth: PositiveFloat = Field(
4745
title="Frequency bandwidth",
4846
description="Real-valued positive width of the frequency range (bandwidth).",
4947
units="Hz",

0 commit comments

Comments
 (0)