Skip to content

Commit d898fcf

Browse files
committed
everything importable for doctests
1 parent 2b1c58e commit d898fcf

9 files changed

Lines changed: 51 additions & 30 deletions

File tree

tidy3d/plugins/adjoint/components/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
from .data.data_array import JaxDataArray
55
from .data.dataset import JaxPermittivityDataset
66
from .data.monitor_data import JaxModeData
7-
from .data.sim_data import JaxSimulationData
8-
from .geometry import JaxBox, JaxComplexPolySlab, JaxPolySlab
7+
from .geometry import JaxBox, JaxComplexPolySlab, JaxGeometryGroup, JaxPolySlab
98
from .medium import JaxAnisotropicMedium, JaxCustomMedium, JaxMedium
109
from .simulation import JaxSimulation
1110
from .structure import JaxStructure, JaxStructureStaticGeometry, JaxStructureStaticMedium
@@ -22,7 +21,6 @@
2221
"JaxStructureStaticMedium",
2322
"JaxStructureStaticGeometry",
2423
"JaxSimulation",
25-
"JaxSimulationData",
2624
"JaxModeData",
2725
"JaxPermittivityDataset",
2826
"JaxDataArray",

tidy3d/plugins/adjoint/components/geometry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
PolySlab,
2525
)
2626
from tidy3d.components.monitor import FieldMonitor, PermittivityMonitor
27-
from tidy3d.components.types import ArrayFloat2D, Bound, Coordinate2D # , annotate_type
27+
from tidy3d.components.types import ArrayFloat2D, Bound, Coordinate2D
2828
from tidy3d.constants import MICROMETER, fp_eps
2929
from tidy3d.exceptions import AdjointError
3030
from tidy3d.log import log

tidy3d/plugins/adjoint/components/medium.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ class JaxCustomMedium(CustomMedium, AbstractJaxMedium):
290290
jax_field=True,
291291
)
292292

293-
@model_validator(pre=True)
293+
@model_validator(mode="before")
294294
def _pre_deprecation_dataset(data):
295295
"""Don't allow permittivity as a field until we support it."""
296296
if data.get("permittivity") or data.get("conductivity"):

tidy3d/plugins/adjoint/components/simulation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from tidy3d.components.simulation import Simulation
3232
from tidy3d.components.structure import Structure
3333
from tidy3d.components.subpixel_spec import Staircasing, SubpixelSpec
34-
from tidy3d.components.types import Ax, annotate_type
34+
from tidy3d.components.types import Ax, discriminated_union
3535
from tidy3d.constants import HERTZ, SECOND
3636
from tidy3d.exceptions import AdjointError
3737
from tidy3d.log import log
@@ -69,7 +69,7 @@
6969
)
7070

7171
OutputMonitorTypes = (DiffractionMonitor, FieldMonitor, ModeMonitor)
72-
OutputMonitorType = tuple[annotate_type(Union[OutputMonitorTypes]), ...]
72+
OutputMonitorType = tuple[discriminated_union(Union[OutputMonitorTypes]), ...]
7373

7474

7575
class JaxInfo(Tidy3dBaseModel):
@@ -122,7 +122,7 @@ class JaxInfo(Tidy3dBaseModel):
122122
class JaxSimulation(Simulation, JaxObject):
123123
"""A :class:`.Simulation` registered with jax."""
124124

125-
input_structures: tuple[annotate_type(JaxStructureType), ...] = Field(
125+
input_structures: tuple[discriminated_union(JaxStructureType), ...] = Field(
126126
(),
127127
title="Input Structures",
128128
description="Tuple of jax-compatible structures"
@@ -164,7 +164,7 @@ class JaxSimulation(Simulation, JaxObject):
164164
units=SECOND,
165165
)
166166

167-
@field_validator("output_monitors", always=True)
167+
@field_validator("output_monitors")
168168
def _output_monitors_colocate_false(val):
169169
"""Make sure server-side colocation is off."""
170170
new_vals = []

tidy3d/plugins/adjoint/components/types.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any, Union
44

55
import numpy as np
6+
from pydantic_core import core_schema
67

78
from tidy3d.components.type_util import _add_schema
89

@@ -29,16 +30,16 @@ class NumpyArrayType(np.ndarray):
2930
"""Subclass of ``np.ndarray`` with a schema defined for pydantic."""
3031

3132
@classmethod
32-
# TODO[pydantic]: We couldn't refactor `__modify_schema__`, please create the `__get_pydantic_json_schema__` manually.
33-
# Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information.
34-
def __modify_schema__(cls, field_schema):
35-
"""Sets the schema of np.ndarray object."""
36-
37-
schema = dict(
38-
title="npdarray",
39-
type="numpy.ndarray",
40-
)
41-
field_schema.update(schema)
33+
def __get_pydantic_core_schema__(cls, source, handler):
34+
return core_schema.no_info_plain_validator_function(lambda v, _: np.asarray(v))
35+
36+
@classmethod
37+
def __get_pydantic_json_schema__(cls, core_schema, handler):
38+
return {
39+
"title": "npdarray",
40+
"type": "numpy.ndarray",
41+
"items": {},
42+
}
4243

4344

4445
_add_schema(JaxArrayType, title="JaxArray", field_type_str="jax.numpy.ndarray")

tidy3d/plugins/design/parameter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class ParameterNumeric(Parameter, ABC):
6363
description="(min, max) range within which are allowed values for the variable. Is inclusive of max value.",
6464
)
6565

66-
@field_validator("span", always=True)
66+
@field_validator("span")
6767
def _span_valid(val):
6868
"""Span min <= span max."""
6969
span_min, span_max = val

tidy3d/plugins/expressions/__init__.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,44 @@
11
from .base import Expression
22
from .functions import Cos, Exp, Log, Log10, Sin, Sqrt, Tan
33
from .metrics import ModeAmp, ModePower, generate_validation_data
4+
from .operators import (
5+
Abs,
6+
Add,
7+
Divide,
8+
FloorDivide,
9+
MatMul,
10+
Modulus,
11+
Multiply,
12+
Negate,
13+
Power,
14+
Subtract,
15+
)
416
from .variables import Constant, Variable
517

618
__all__ = [
7-
"Expression",
19+
"Abs",
20+
"Add",
821
"Constant",
9-
"Variable",
10-
"ModeAmp",
11-
"ModePower",
12-
"generate_validation_data",
13-
"Sin",
1422
"Cos",
15-
"Tan",
23+
"Divide",
1624
"Exp",
25+
"Expression",
26+
"FloorDivide",
1727
"Log",
1828
"Log10",
29+
"MatMul",
30+
"ModeAmp",
31+
"ModePower",
32+
"Modulus",
33+
"Multiply",
34+
"Negate",
35+
"Power",
36+
"Sin",
1937
"Sqrt",
38+
"Subtract",
39+
"Tan",
40+
"Variable",
41+
"generate_validation_data",
2042
]
2143

2244
# The following code dynamically collects all classes that are subclasses of Expression
@@ -41,4 +63,4 @@
4163
_local_vars[name] = obj
4264

4365
for cls in _model_classes:
44-
cls.update_forward_refs(**_local_vars)
66+
cls.model_rebuild(force=True)

tidy3d/plugins/smatrix/component_modelers/terminal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def _validate_3d_simulation(cls, val):
219219
)
220220
return val
221221

222-
@model_validator("radiation_monitors")
222+
@model_validator(mode="after")
223223
def _validate_radiation_monitors(self):
224224
freqs = set(self.freqs)
225225
for rad_mon in self.radiation_monitors:

tidy3d/plugins/smatrix/ports/coaxial_lumped.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def _center_not_inf(val):
8888
raise ValidationError("'center' can not contain 'td.inf' terms.")
8989
return val
9090

91-
@model_validator("inner_diameter", always=True)
91+
@model_validator(mode="after")
9292
def _ensure_inner_diameter_is_smaller(self):
9393
"""Ensures that the inner diameter is smaller than the outer diameter, so that the final
9494
shape is an annulus."""

0 commit comments

Comments
 (0)