Skip to content

Commit 59f72c8

Browse files
committed
update
1 parent 7e50c19 commit 59f72c8

14 files changed

Lines changed: 98 additions & 125 deletions

FiberFusing/background.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,14 @@
33

44
from typing import Tuple, Optional
55
from pydantic.dataclasses import dataclass
6-
from pydantic import ConfigDict
76
import numpy as np
87

98
from FiberFusing.coordinate_system import CoordinateSystem
109
from FiberFusing.helper import OverlayStructureBaseClass
11-
from FiberFusing.utils import NameSpace
10+
from FiberFusing.utils import NameSpace, config_dict
11+
from FiberFusing
1212

13-
14-
@dataclass(config=ConfigDict(extra='forbid', kw_only=True))
13+
@dataclass(config=config_dict)
1514
class BackGround(OverlayStructureBaseClass):
1615
"""
1716
Represents a background structure overlayed on a mesh, characterized by a circular shape.

FiberFusing/coordinate_system.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
import numpy as np
55
from typing import Tuple, Optional
66
from pydantic.dataclasses import dataclass
7-
from pydantic import ConfigDict, field_validator
7+
from pydantic import field_validator
88
from dataclasses import field
99

10+
from FiberFusing.utils import config_dict
1011

11-
@dataclass(config=ConfigDict(extra='forbid', strict=True, kw_only=True, arbitrary_types_allowed=True, frozen=False))
12+
13+
@dataclass(config=config_dict)
1214
class CoordinateSystem:
1315
"""
1416
A 2D Cartesian coordinate system for fiber optics simulations.

FiberFusing/fiber/generic_fiber.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
import pprint
77
from copy import deepcopy
88
import matplotlib.pyplot as plt
9+
from MPSPlots import helper
910

1011
from FiberFusing import Circle, CircleOpticalStructure
1112
from FiberFusing.geometries.point import Point
1213
from FiberFusing.coordinate_system import CoordinateSystem
1314
from FiberFusing.plottings import plot_polygon
14-
from FiberFusing.helper import _plot_helper
1515
from FiberFusing.graded_index import GradedIndex
1616

1717
pp = pprint.PrettyPrinter(indent=4, sort_dicts=False, compact=True, width=1)
@@ -508,8 +508,8 @@ def overlay_structures_on_mesh(self, mesh: numpy.ndarray, coordinate_system: Coo
508508
coordinate_system=coordinate_system
509509
)
510510

511-
@_plot_helper
512-
def plot(self, ax: plt.Axes = None) -> None:
511+
@helper.pre_plot(nrows=1, ncols=1)
512+
def plot(self, axes: plt.Axes = None) -> None:
513513
"""
514514
Plot the fiber geometry representation including patch and raster-mesh.
515515
@@ -519,7 +519,7 @@ def plot(self, ax: plt.Axes = None) -> None:
519519
Resolution for rasterizing structures. Default is 300.
520520
"""
521521
for structure in self.fiber_structure:
522-
plot_polygon(ax=ax, polygon=structure.polygon._shapely_object)
522+
plot_polygon(ax=axes, polygon=structure.polygon._shapely_object)
523523

524524
def get_structures_boundaries(self) -> numpy.ndarray:
525525
"""
@@ -565,8 +565,8 @@ def boundaries(self):
565565
"""
566566
return self.get_structure_max_min_boundaries()
567567

568-
@_plot_helper
569-
def plot_raster(self, coordinate_system, ax: plt.Axes = None) -> None:
568+
@helper.pre_plot(nrows=1, ncols=1)
569+
def plot_raster(self, axes: plt.Axes, coordinate_system) -> None:
570570
"""
571571
Render the rasterized representation of the geometry onto a given matplotlib axis.
572572
@@ -582,7 +582,16 @@ def plot_raster(self, coordinate_system, ax: plt.Axes = None) -> None:
582582
None
583583
"""
584584
mesh = self.get_raster_mesh(coordinate_system=coordinate_system)
585-
ax.pcolormesh(coordinate_system.x_vector, coordinate_system.y_vector, mesh, cmap='Blues')
585+
axes.pcolormesh(
586+
coordinate_system.x_vector,
587+
coordinate_system.y_vector,
588+
mesh,
589+
cmap='Blues'
590+
)
591+
592+
axes.set(
593+
title='Fiber structure',
594+
xlabel=r'x-distance',
595+
ylabel=r'y-distance'
596+
)
586597

587-
ax.set(title='Fiber structure', xlabel=r'x-distance [m]', ylabel=r'y-distance [m]')
588-
ax.ticklabel_format(axis='both', style='sci', scilimits=(-6, -6), useOffset=False)

FiberFusing/geometries/__init__.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,3 @@
1-
from pydantic import ConfigDict
2-
3-
config_dict = ConfigDict(
4-
extra='forbid',
5-
arbitrary_types_allowed=True,
6-
kw_only=True
7-
)
8-
91
from .base_class import Alteration # noqa: F401
102
from .point import Point # noqa: F401
113
from .linestring import LineString # noqa: F401

FiberFusing/geometries/linestring.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
import shapely.geometry as geo
1010

1111
from FiberFusing.geometries.base_class import Alteration
12-
from FiberFusing import geometries
12+
from FiberFusing.utils import config_dict
13+
import geometries
1314

14-
@dataclass(config=geometries.config_dict)
15+
@dataclass(config=config_dict)
1516
class LineString(Alteration):
1617
coordinates: Optional[Tuple] = field(default=None, repr=False)
1718
instance: Optional[geo.LineString] = field(default=None, repr=False)

FiberFusing/geometries/point.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
from pydantic.dataclasses import dataclass
88
import shapely.geometry as geo
99
import matplotlib.pyplot as plt
10+
from MPSPlots import helper
1011

11-
from FiberFusing.helper import _plot_helper
1212
from FiberFusing import geometries
13+
from FiberFusing.utils import config_dict
1314

14-
@dataclass(config=geometries.config_dict)
15+
@dataclass(config=config_dict)
1516
class Point(geometries.base_class.Alteration):
1617
position: Optional[Tuple[float, float]] = None
1718
instance: Optional[geo.Point] = None
@@ -106,8 +107,8 @@ def distance(self, other: Self) -> float:
106107
"""
107108
return numpy.sqrt((self.x - other.x)**2 + (self.y - other.y)**2)
108109

109-
@_plot_helper
110-
def plot(self, ax: plt.Axes = None, marker: str = 'x', size: int = 20, label: str = None) -> None:
110+
@helper.pre_plot(nrows=1, ncols=1)
111+
def plot(self, axes: plt.Axes, marker: str = 'x', size: int = 20, label: str = None) -> None:
111112
"""
112113
Renders this point on the given axis, optionally with text.
113114
@@ -122,4 +123,4 @@ def plot(self, ax: plt.Axes = None, marker: str = 'x', size: int = 20, label: st
122123
label : str
123124
The label for the point, if any.
124125
"""
125-
ax.scatter(self.x, self.y, label=label, marker=marker, s=size)
126+
axes.scatter(self.x, self.y, label=label, marker=marker, s=size)

FiberFusing/geometries/polygon.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
import matplotlib.pyplot as plt
99
from shapely.ops import split
1010
from shapely import affinity
11+
from MPSPlots import helper
1112

1213
from FiberFusing.coordinate_system import CoordinateSystem
1314
from FiberFusing.plottings import plot_polygon
14-
from FiberFusing.helper import _plot_helper
1515
from FiberFusing import geometries
1616

1717
class Polygon(geometries.base_class.Alteration):
@@ -142,8 +142,8 @@ def _contain_points_with_holes(self, coordinates: np.ndarray, polygon: geo.Polyg
142142
# Points are inside if they are in the exterior but not in any hole
143143
return exterior_mask & ~hole_mask
144144

145-
@_plot_helper
146-
def plot(self, ax: plt.Axes, **kwargs) -> None:
145+
@helper.pre_plot(nrows=1, ncols=1)
146+
def plot(self, axes: plt.Axes, **kwargs) -> None:
147147
"""
148148
Plots the polygon on the given Matplotlib axis.
149149
@@ -154,7 +154,7 @@ def plot(self, ax: plt.Axes, **kwargs) -> None:
154154
**kwargs
155155
Additional keyword arguments passed to the plotting function.
156156
"""
157-
plot_polygon(ax=ax, polygon=self._shapely_object, **kwargs)
157+
plot_polygon(ax=axes, polygon=self._shapely_object, **kwargs)
158158

159159
def rasterize(self, coordinate_system: CoordinateSystem) -> np.ndarray:
160160
"""

FiberFusing/geometry.py

Lines changed: 25 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@
77
from scipy.ndimage import gaussian_filter
88
import FiberFusing
99
import matplotlib.pyplot as plt
10-
from FiberFusing.coordinate_system import CoordinateSystem
11-
from mpl_toolkits.axes_grid1 import make_axes_locatable
1210
import matplotlib.colors as colors
13-
from MPSPlots.styles import mps
14-
from FiberFusing.helper import _plot_helper
15-
from pydantic import field_validator, ConfigDict
11+
from mpl_toolkits.axes_grid1 import make_axes_locatable
12+
from pydantic import field_validator
1613
from pydantic.dataclasses import dataclass
14+
from MPSPlots import helper
15+
16+
17+
from FiberFusing.coordinate_system import CoordinateSystem
18+
from FiberFusing.utils import config_dict
1719

1820
class DomainAlignment(Enum):
1921
"""Boundary positioning modes."""
@@ -25,7 +27,7 @@ class DomainAlignment(Enum):
2527
CENTERING = "centering"
2628

2729

28-
@dataclass(config=ConfigDict(extra='forbid', kw_only=True, arbitrary_types_allowed=True))
30+
@dataclass(config=config_dict)
2931
class Geometry():
3032
"""
3133
Represents the refractive index (RI) geometric profile including background and fiber structures.
@@ -326,8 +328,8 @@ def generate_mesh(self) -> numpy.ndarray:
326328

327329
return mesh
328330

329-
@_plot_helper
330-
def plot_patch(self, ax: plt.Axes = None, show: bool = True) -> None:
331+
@helper.pre_plot(nrows=1, ncols=1)
332+
def plot_patch(self, axes: plt.Axes) -> None:
331333
"""
332334
Render the patch representation of the geometry onto a given matplotlib axis.
333335
@@ -347,16 +349,15 @@ def plot_patch(self, ax: plt.Axes = None, show: bool = True) -> None:
347349
continue
348350

349351
if isinstance(structure, FiberFusing.profile.Profile):
350-
structure.plot(ax=ax, show=False, show_added=False, show_removed=False, show_centers=False, show_fibers=True)
352+
structure.plot(axes=axes, show=False, show_added=False, show_removed=False, show_centers=False, show_fibers=True)
351353
continue
352354

353-
structure.plot(ax=ax, show=False)
355+
structure.plot(axes=axes, show=False)
354356

355-
ax.set(title='Fiber structure', xlabel=r'x-distance [m]', ylabel=r'y-distance [m]')
356-
ax.ticklabel_format(axis='both', style='sci', scilimits=(-6, -6), useOffset=False)
357+
axes.set(title='Fiber structure', xlabel=r'x-distance [m]', ylabel=r'y-distance [m]')
357358

358-
@_plot_helper
359-
def plot_raster(self, ax: plt.Axes = None, gamma: float = 5) -> None:
359+
@helper.pre_plot(nrows=1, ncols=1)
360+
def plot_raster(self, axes: plt.Axes, gamma: float = 5) -> None:
360361
"""
361362
Render the rasterized representation of the geometry onto a given matplotlib axis.
362363
@@ -373,31 +374,27 @@ def plot_raster(self, ax: plt.Axes = None, gamma: float = 5) -> None:
373374
-------
374375
None
375376
"""
376-
image = ax.pcolormesh(
377+
image = axes.pcolormesh(
377378
self.coordinate_system.x_vector,
378379
self.coordinate_system.y_vector,
379380
self.mesh,
380381
cmap='Blues',
381382
norm=colors.PowerNorm(gamma=gamma)
382383
)
383384

384-
divider = make_axes_locatable(ax)
385+
divider = make_axes_locatable(axes)
385386
cax = divider.append_axes('right', size='5%', pad=0.05)
386-
ax.get_figure().colorbar(image, cax=cax, orientation='vertical')
387+
axes.get_figure().colorbar(image, cax=cax, orientation='vertical')
387388

388-
ax.set(title='Fiber structure', xlabel=r'x-distance [m]', ylabel=r'y-distance [m]')
389-
ax.ticklabel_format(axis='both', style='sci', scilimits=(-6, -6), useOffset=False)
389+
axes.set(title='Fiber structure', xlabel=r'x-distance [m]', ylabel=r'y-distance [m]')
390390

391-
def plot(self, show_patch: bool = True, show_mesh: bool = True, show: bool = True, gamma: float = 5) -> plt.Figure:
391+
@helper.pre_plot(nrows=1, ncols=2, subplot_kw=dict(aspect='equal', xlabel='x-distance [m]', ylabel='y-distance [m]'))
392+
def plot(self, axes, gamma: float = 5) -> plt.Figure:
392393
"""
393394
Plot the different representations (patch and mesh) of the geometry.
394395
395396
Parameters
396397
----------
397-
show_patch : bool, optional
398-
Whether to display the patch representation of the geometry. Default is True.
399-
show_mesh : bool, optional
400-
Whether to display the mesh (rasterized) representation of the geometry. Default is True.
401398
show : bool, optional
402399
Whether to immediately show the plot. Default is True.
403400
gamma : float, optional
@@ -408,29 +405,10 @@ def plot(self, show_patch: bool = True, show_mesh: bool = True, show: bool = Tru
408405
plt.Figure
409406
The matplotlib figure encompassing all the axes used in the plot.
410407
"""
411-
n_ax = bool(show_patch) + bool(show_mesh)
412-
unit_size = numpy.array([1, n_ax])
413-
414-
with plt.style.context(mps):
415-
figure, axes = plt.subplots(
416-
*unit_size,
417-
figsize=5 * numpy.flip(unit_size),
418-
sharex=True,
419-
sharey=True,
420-
subplot_kw=dict(aspect='equal', xlabel='x-distance [m]', ylabel='y-distance [m]'),
421-
)
422-
423-
axes_iter = iter(axes.flatten())
424-
425-
if show_patch:
426-
ax = next(axes_iter)
427-
self.plot_patch(ax, show=False)
408+
axes[0].sharex(axes[1])
409+
axes[0].sharey(axes[1])
428410

429-
if show_mesh:
430-
ax = next(axes_iter)
431-
self.plot_raster(ax, show=False, gamma=gamma)
432411

433-
if show:
434-
plt.show()
412+
self.plot_patch(axes=axes[0], show=False)
435413

436-
return figure
414+
self.plot_raster(axes=axes[1], show=False, gamma=gamma)

FiberFusing/helper.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,4 @@
11
import numpy
2-
from MPSPlots.styles import mps
3-
import matplotlib.pyplot as plt
4-
5-
6-
def _plot_helper(function):
7-
def wrapper(self, ax: plt.Axes = None, show: bool = True, **kwargs):
8-
if ax is None:
9-
with plt.style.context(mps):
10-
_, ax = plt.subplots(1, 1)
11-
ax.set_aspect('equal')
12-
ax.set(title='Fiber structure', xlabel=r'x-distance [m]', ylabel=r'y-distance [m]')
13-
ax.ticklabel_format(axis='both', style='sci') # , scilimits=(-6, -6), useOffset=False)
14-
15-
function(self, ax=ax, **kwargs)
16-
17-
_, labels = ax.get_legend_handles_labels()
18-
19-
# Only add a legend if there are labels
20-
if labels:
21-
ax.legend()
22-
23-
if show:
24-
plt.show()
25-
26-
return ax
27-
28-
return wrapper
29-
302

313
class OverlayStructureBaseClass:
324
def _overlay_structure_on_mesh_(self, structure_list: dict, mesh: numpy.ndarray, coordinate_system: object) -> numpy.ndarray:

FiberFusing/optical_structure.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,17 @@
22
from .geometries import Point, LineString, Polygon, EmptyPolygon # noqa: F401
33
from .shapes import Circle, Square, Ellipse # noqa: F401
44
from .background import BackGround # noqa: F401
5-
from pydantic.dataclasses import dataclass
6-
from pydantic import ConfigDict
7-
from FiberFusing.graded_index import GradedIndex
85

9-
from typing import Optional, Tuple
6+
from pydantic.dataclasses import dataclass
7+
from typing import Tuple
108
import numpy as np
119

1210

13-
@dataclass(config=ConfigDict(extra='forbid', kw_only=True))
11+
from FiberFusing.graded_index import GradedIndex
12+
from FiberFusing.utils import config_dict
13+
14+
15+
@dataclass(config=config_dict)
1416
class CircleOpticalStructure:
1517
"""
1618
Represents a circular optical structure.

0 commit comments

Comments
 (0)