44from collections .abc import Sequence
55from functools import partial
66from itertools import combinations
7+ from typing import Any
78
8- import gdsfactory as gf
99import matplotlib .pyplot as plt
1010import numpy as np
11+ import numpy .typing as npt
12+ from matplotlib .axes import Axes
1113
1214
13- def _check_ports (sp : dict [str , np .ndarray ], ports : Sequence [str ]) -> None :
15+ def _check_ports (
16+ sp : dict [str , npt .NDArray [np .floating [Any ]]], ports : Sequence [str ]
17+ ) -> None :
1418 """Ensure ports exist in Sparameters."""
1519 for port in ports :
1620 if port not in sp :
1721 raise ValueError (f"Did not find port { port !r} in { list (sp .keys ())} " )
1822
1923
2024def plot_sparameters (
21- sp : dict [str , np .ndarray ],
25+ sp : dict [str , npt . NDArray [ np .floating [ Any ]] ],
2226 logscale : bool = True ,
2327 plot_phase : bool = False ,
2428 keys : tuple [str , ...] | None = None ,
@@ -39,7 +43,7 @@ def plot_sparameters(
3943
4044 """
4145 w = sp ["wavelengths" ] * units
42- keys = keys or [ key for key in sp if not key .lower ().startswith ("wav" )]
46+ keys = keys or tuple ( key for key in sp if not key .lower ().startswith ("wav" ))
4347
4448 for key in keys :
4549 if with_simpler_input_keys :
@@ -74,9 +78,12 @@ def plot_sparameters(
7478
7579
7680def plot_imbalance (
77- sp : dict [str , np .ndarray ], ports : Sequence [str ], ax : plt .Axes | None = None
81+ sp : dict [str , npt .NDArray [np .floating [Any ]]],
82+ ports : Sequence [str ],
83+ ax : Axes | None = None ,
7884) -> None :
7985 """Plots imbalance in dB for coupler.
86+
8087 The imbalance is always defined between two ports, so this function plots the
8188 imbalance between all unique port combinations.
8289
@@ -107,7 +114,9 @@ def plot_imbalance(
107114
108115
109116def plot_loss (
110- sp : dict [str , np .ndarray ], ports : Sequence [str ], ax : plt .Axes | None = None
117+ sp : dict [str , npt .NDArray [np .floating [Any ]]],
118+ ports : Sequence [str ],
119+ ax : Axes | None = None ,
111120) -> None :
112121 """Plots loss dB for coupler.
113122
@@ -137,7 +146,9 @@ def plot_loss(
137146
138147
139148def plot_reflection (
140- sp : dict [str , np .ndarray ], ports : Sequence [str ], ax : plt .Axes | None = None
149+ sp : dict [str , npt .NDArray [np .floating [Any ]]],
150+ ports : Sequence [str ],
151+ ax : Axes | None = None ,
141152) -> None :
142153 """Plots reflection in dB for coupler.
143154
@@ -172,11 +183,3 @@ def plot_reflection(
172183plot_imbalance2x2 = partial (plot_imbalance , ports = ["o1@0,o3@0" , "o1@0,o4@0" ])
173184plot_reflection1x2 = partial (plot_reflection , ports = ["o1@0,o1@0" ])
174185plot_reflection2x2 = partial (plot_reflection , ports = ["o1@0,o1@0" , "o2@0,o1@0" ])
175-
176- if __name__ == "__main__" :
177- import gplugins as sim
178-
179- sp = sim .get_sparameters_data_tidy3d (component = gf .components .mmi1x2 )
180- # plot_sparameters(sp, logscale=False, keys=["o1@0,o2@0"])
181- # plot_sparameters(sp, logscale=False, keys=["S21"])
182- # plt.show()
0 commit comments