Skip to content

Commit 7197051

Browse files
authored
Merge pull request #568 from gdsfactory/mypy
Fix mypy
2 parents 9f0ff26 + aa5a277 commit 7197051

15 files changed

Lines changed: 133 additions & 96 deletions

.pre-commit-config.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ repos:
2424
- id: ruff-format
2525
files: ^gplugins/(femwell|gmsh|meow|sax|tidy3d|klayout|vlsir)/
2626

27+
- repo: https://github.com/pre-commit/mirrors-mypy
28+
rev: "v1.15.0"
29+
hooks:
30+
- id: mypy
31+
args: [--ignore-missing-imports, --strict, --config-file=pyproject.toml]
32+
additional_dependencies:
33+
- gdsfactory
34+
- pytest
35+
files: ^gplugins/(femwell|gmsh|meow|sax|tidy3d|klayout|vlsir)/
36+
2737
# - repo: https://github.com/shellcheck-py/shellcheck-py
2838
# rev: v0.9.0.5
2939
# hooks:

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ venv:
66

77
install:
88
uv venv --python 3.11
9-
uv pip install -e .[dev,docs,devsim,femwell,gmsh,klayout,meow,sax,schematic,tidy3d,vlsir]
9+
uv pip install -e .[dev,docs,femwell,gmsh,meow,sax,tidy3d,klayout,vlsir]
1010
uv run pre-commit install
1111

1212
dev: test-data gmsh elmer install

gplugins/common/utils/get_component_with_net_layers.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import copy
22

33
import gdsfactory as gf
4+
import klayout.db as kdb
45
from gdsfactory import Component
56
from gdsfactory.technology import LayerStack, LogicalLayer
67

@@ -35,8 +36,8 @@ def get_component_layer_stack(
3536

3637

3738
def get_component_with_net_layers(
38-
component,
39-
layer_stack,
39+
component: Component,
40+
layer_stack: LayerStack,
4041
port_names: list[str],
4142
delimiter: str = "#",
4243
new_layers_init: tuple[int, int] = (10010, 0),
@@ -74,7 +75,9 @@ def get_component_with_net_layers(
7475
net_component = net_component.remove_layers(layers=(port.layer,))
7576
for polygon in polygons:
7677
# If polygon belongs to port, create a unique new layer, and add the polygon to it
77-
if polygon.sized(3 * gf.kcl.dbu).inside(port.center):
78+
if polygon.sized(int(3 * gf.kcl.dbu)).inside(
79+
kdb.Point(*port.to_itype().center)
80+
):
7881
# if gdstk.inside(
7982
# [port.center],
8083
# gdstk.offset(gdstk.Polygon(polygon), gf.get_active_pdk().grid_size),
@@ -111,8 +114,3 @@ def get_component_with_net_layers(
111114

112115
net_component.name = f"{component.name}_net_layers"
113116
return net_component
114-
115-
116-
if __name__ == "__main__":
117-
c = get_component_with_net_layers()
118-
c.show()

gplugins/common/utils/get_effective_indices.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
from __future__ import annotations
44

5-
from typing import Literal
5+
from typing import Any, Literal, cast
66

77
import numpy as np
8+
import numpy.typing as npt
89
from scipy.optimize import fsolve
910

1011

@@ -19,9 +20,9 @@ def get_effective_indices(
1920
"""Returns the effective refractive indices for a 1D mode.
2021
2122
Args:
22-
epsilon_core: Relative permittivity of the film.
23-
epsilon_substrate: Relative permittivity of the substrate.
24-
epsilon_cladding: Relative permittivity of the cladding.
23+
core_material: Refractive index of the core material.
24+
nsubstrate: Refractive index of the substrate.
25+
clad_materialding: Refractive index of the cladding.
2526
thickness: Thickness of the film in um.
2627
wavelength: Wavelength in um.
2728
polarization: Either "te" or "tm".
@@ -66,18 +67,20 @@ def get_effective_indices(
6667

6768
k_0 = 2 * np.pi / wavelength
6869

69-
def k_f(e_eff):
70+
def k_f(e_eff: npt.NDArray[np.floating[Any]]) -> npt.NDArray[np.floating[Any]]:
7071
return k_0 * np.sqrt(epsilon_core - e_eff) / (epsilon_core if tm else 1)
7172

72-
def k_s(e_eff):
73+
def k_s(e_eff: npt.NDArray[np.floating[Any]]) -> npt.NDArray[np.floating[Any]]:
7374
return (
7475
k_0 * np.sqrt(e_eff - epsilon_substrate) / (epsilon_substrate if tm else 1)
7576
)
7677

77-
def k_c(e_eff):
78+
def k_c(e_eff: npt.NDArray[np.floating[Any]]) -> npt.NDArray[np.floating[Any]]:
7879
return k_0 * np.sqrt(e_eff - epsilon_cladding) / (epsilon_cladding if tm else 1)
7980

80-
def objective(e_eff):
81+
def objective(
82+
e_eff: npt.NDArray[np.floating[Any]],
83+
) -> npt.NDArray[np.floating[Any]]:
8184
return 1 / np.tan(k_f(e_eff) * thickness) - (
8285
k_f(e_eff) ** 2 - k_s(e_eff) * k_c(e_eff)
8386
) / (k_f(e_eff) * (k_s(e_eff) + k_c(e_eff)))
@@ -92,14 +95,14 @@ def objective(e_eff):
9295
return []
9396

9497
# and then use fsolve to get exact indices
95-
indices_temp = fsolve(objective, indices_temp)
98+
indices_temp = cast(npt.NDArray[np.floating[Any]], fsolve(objective, indices_temp))
9699

97-
indices = []
100+
indices: list[float] = []
98101
for index in indices_temp:
99102
if not any(np.isclose(index, i, atol=1e-5) for i in indices):
100103
indices.append(index)
101104

102-
return np.sqrt(indices).tolist()
105+
return cast(list[float], np.sqrt(indices).tolist())
103106

104107

105108
def test_effective_index() -> None:

gplugins/common/utils/parse_layer_stack.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def list_unique_layer_stack_z(
2626
def map_unique_layer_stack_z(
2727
layer_stack: LayerStack,
2828
include_zmax: bool = True,
29-
):
29+
) -> dict[str, set[float]]:
3030
"""Map unique LayerStack z coordinates to various layers.
3131
3232
Args:
@@ -57,7 +57,7 @@ def map_unique_layer_stack_z(
5757
def get_layer_overlaps_z(
5858
layer_stack: LayerStack,
5959
include_zmax: bool = True,
60-
):
60+
) -> dict[float, set[str]]:
6161
"""Maps layers to unique LayerStack z coordinates.
6262
6363
Args:
@@ -67,7 +67,7 @@ def get_layer_overlaps_z(
6767
"""
6868
z_grid = list_unique_layer_stack_z(layer_stack)
6969
unique_z_dict = map_unique_layer_stack_z(layer_stack, include_zmax)
70-
intersection_z_dict = {}
70+
intersection_z_dict: dict[float, set[str]] = {}
7171
for z in z_grid:
7272
current_layers = {
7373
layername for layername, layer_zs in unique_z_dict.items() if z in layer_zs
@@ -77,11 +77,12 @@ def get_layer_overlaps_z(
7777
return intersection_z_dict
7878

7979

80-
def get_layers_at_z(layer_stack: LayerStack, z: float):
80+
def get_layers_at_z(layer_stack: LayerStack, z: float) -> list[str]:
8181
"""Returns layers present at a given z-position.
8282
8383
Args:
8484
layer_stack: LayerStack
85+
z: float
8586
Returns:
8687
List of layers
8788
"""
@@ -93,11 +94,11 @@ def get_layers_at_z(layer_stack: LayerStack, z: float):
9394
raise ValueError("Requested z-value is above the minimum layer_stack z")
9495
for z_unique in intersection_z_dict.keys():
9596
if z <= z_unique:
96-
return intersection_z_dict[z_unique]
97+
return list(intersection_z_dict[z_unique])
9798
raise AssertionError("Could not find z-value in layer_stack z-range.")
9899

99100

100-
def order_layer_stack(layer_stack: LayerStack):
101+
def order_layer_stack(layer_stack: LayerStack) -> list[str]:
101102
"""Orders layer_stack according to mesh_order.
102103
103104
Args:

gplugins/common/utils/plot.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,25 @@
44
from collections.abc import Sequence
55
from functools import partial
66
from itertools import combinations
7+
from typing import Any
78

8-
import gdsfactory as gf
99
import matplotlib.pyplot as plt
1010
import numpy as np
11+
import numpy.typing as npt
12+
from matplotlib.axes import Axes
1113

1214

13-
def _check_ports(sp: dict[str, np.ndarray], ports: Sequence[str]) -> None:
15+
def _check_ports(
16+
sp: dict[str, npt.NDArray[np.floating[Any]]], ports: Sequence[str]
17+
) -> None:
1418
"""Ensure ports exist in Sparameters."""
1519
for port in ports:
1620
if port not in sp:
1721
raise ValueError(f"Did not find port {port!r} in {list(sp.keys())}")
1822

1923

2024
def plot_sparameters(
21-
sp: dict[str, np.ndarray],
25+
sp: dict[str, npt.NDArray[np.floating[Any]]],
2226
logscale: bool = True,
2327
plot_phase: bool = False,
2428
keys: tuple[str, ...] | None = None,
@@ -39,7 +43,7 @@ def plot_sparameters(
3943
4044
"""
4145
w = sp["wavelengths"] * units
42-
keys = keys or [key for key in sp if not key.lower().startswith("wav")]
46+
keys = keys or tuple(key for key in sp if not key.lower().startswith("wav"))
4347

4448
for key in keys:
4549
if with_simpler_input_keys:
@@ -74,9 +78,12 @@ def plot_sparameters(
7478

7579

7680
def plot_imbalance(
77-
sp: dict[str, np.ndarray], ports: Sequence[str], ax: plt.Axes | None = None
81+
sp: dict[str, npt.NDArray[np.floating[Any]]],
82+
ports: Sequence[str],
83+
ax: Axes | None = None,
7884
) -> None:
7985
"""Plots imbalance in dB for coupler.
86+
8087
The imbalance is always defined between two ports, so this function plots the
8188
imbalance between all unique port combinations.
8289
@@ -107,7 +114,9 @@ def plot_imbalance(
107114

108115

109116
def plot_loss(
110-
sp: dict[str, np.ndarray], ports: Sequence[str], ax: plt.Axes | None = None
117+
sp: dict[str, npt.NDArray[np.floating[Any]]],
118+
ports: Sequence[str],
119+
ax: Axes | None = None,
111120
) -> None:
112121
"""Plots loss dB for coupler.
113122
@@ -137,7 +146,9 @@ def plot_loss(
137146

138147

139148
def plot_reflection(
140-
sp: dict[str, np.ndarray], ports: Sequence[str], ax: plt.Axes | None = None
149+
sp: dict[str, npt.NDArray[np.floating[Any]]],
150+
ports: Sequence[str],
151+
ax: Axes | None = None,
141152
) -> None:
142153
"""Plots reflection in dB for coupler.
143154
@@ -172,11 +183,3 @@ def plot_reflection(
172183
plot_imbalance2x2 = partial(plot_imbalance, ports=["o1@0,o3@0", "o1@0,o4@0"])
173184
plot_reflection1x2 = partial(plot_reflection, ports=["o1@0,o1@0"])
174185
plot_reflection2x2 = partial(plot_reflection, ports=["o1@0,o1@0", "o2@0,o1@0"])
175-
176-
if __name__ == "__main__":
177-
import gplugins as sim
178-
179-
sp = sim.get_sparameters_data_tidy3d(component=gf.components.mmi1x2)
180-
# plot_sparameters(sp, logscale=False, keys=["o1@0,o2@0"])
181-
# plot_sparameters(sp, logscale=False, keys=["S21"])
182-
# plt.show()

gplugins/gmsh/define_polysurfaces.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,16 @@
66

77

88
def define_polysurfaces(
9-
polygons_dict: dict,
9+
polygons_dict: dict[str, Any],
1010
layer_stack: LayerStack,
11-
layer_physical_map: dict,
12-
layer_meshbool_map: dict,
11+
layer_physical_map: dict[str, Any],
12+
layer_meshbool_map: dict[str, Any],
1313
model: Any,
14-
resolutions: dict,
14+
resolutions: dict[str, Any] | None = None,
1515
scale_factor: float = 1,
16-
):
16+
) -> list[PolySurface]:
1717
"""Define meshwell polysurfaces dimtags from gdsfactory information."""
18-
polysurfaces_list = []
18+
polysurfaces_list: list[PolySurface] = []
1919

2020
if resolutions is None:
2121
resolutions = {}
@@ -24,6 +24,11 @@ def define_polysurfaces(
2424
if polygons_dict[layername].is_empty:
2525
continue
2626

27+
layer_stack_ = layer_stack.layers.get(layername)
28+
29+
if layer_stack_ is None:
30+
continue
31+
2732
polysurfaces_list.append(
2833
PolySurface(
2934
polygons=scale(
@@ -33,7 +38,7 @@ def define_polysurfaces(
3338
),
3439
model=model,
3540
resolution=resolutions.get(layername, None),
36-
mesh_order=layer_stack.layers.get(layername).mesh_order,
41+
mesh_order=layer_stack_.mesh_order,
3742
physical_name=layer_physical_map[layername]
3843
if layername in layer_physical_map
3944
else layername,

gplugins/gmsh/uz_xsection_mesh.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
# type: ignore
21
from __future__ import annotations
32

43
from collections.abc import Sequence
4+
from typing import Any
55

66
import gdsfactory as gf
77
import numpy as np
@@ -206,12 +206,12 @@ def uz_xsection_mesh(
206206
component: ComponentOrReference,
207207
xsection_bounds: tuple[tuple[float, float], tuple[float, float]],
208208
layer_stack: LayerStack,
209-
layer_physical_map: dict,
210-
layer_meshbool_map: dict,
211-
resolutions: dict | None = None,
209+
layer_physical_map: dict[str, Any],
210+
layer_meshbool_map: dict[str, Any],
211+
resolutions: dict[str, Any] | None = None,
212212
default_characteristic_length: float = 0.5,
213213
background_tag: str | None = None,
214-
background_padding: Sequence[float, float, float, float, float, float] = (2.0,) * 6,
214+
background_padding: Sequence[float] = (2.0,) * 6,
215215
background_mesh_order: int | float = 2**63 - 1,
216216
global_scaling: float = 1,
217217
global_scaling_premesh: float = 1,
@@ -225,9 +225,9 @@ def uz_xsection_mesh(
225225
n_threads: int = get_number_of_cores(),
226226
gmsh_version: float | None = None,
227227
interface_delimiter: str = "___",
228-
background_remeshing_file=None,
228+
background_remeshing_file: str | None = None,
229229
optimization_flags: tuple[tuple[str, int]] | None = None,
230-
**kwargs,
230+
**kwargs: Any,
231231
):
232232
"""Mesh uz cross-section of component along line u = [[x1,y1] , [x2,y2]].
233233

0 commit comments

Comments
 (0)