77from scipy .ndimage import gaussian_filter
88import FiberFusing
99import matplotlib .pyplot as plt
10- from FiberFusing .coordinate_system import CoordinateSystem
11- from mpl_toolkits .axes_grid1 import make_axes_locatable
1210import matplotlib .colors as colors
13- from MPSPlots .styles import mps
14- from FiberFusing .helper import _plot_helper
15- from pydantic import field_validator , ConfigDict
11+ from mpl_toolkits .axes_grid1 import make_axes_locatable
12+ from pydantic import field_validator
1613from pydantic .dataclasses import dataclass
14+ from MPSPlots import helper
15+
16+
17+ from FiberFusing .coordinate_system import CoordinateSystem
18+ from FiberFusing .utils import config_dict
1719
1820class DomainAlignment (Enum ):
1921 """Boundary positioning modes."""
@@ -25,7 +27,7 @@ class DomainAlignment(Enum):
2527 CENTERING = "centering"
2628
2729
28- @dataclass (config = ConfigDict ( extra = 'forbid' , kw_only = True , arbitrary_types_allowed = True ) )
30+ @dataclass (config = config_dict )
2931class Geometry ():
3032 """
3133 Represents the refractive index (RI) geometric profile including background and fiber structures.
@@ -326,8 +328,8 @@ def generate_mesh(self) -> numpy.ndarray:
326328
327329 return mesh
328330
329- @_plot_helper
330- def plot_patch (self , ax : plt .Axes = None , show : bool = True ) -> None :
331+ @helper . pre_plot ( nrows = 1 , ncols = 1 )
332+ def plot_patch (self , axes : plt .Axes ) -> None :
331333 """
332334 Render the patch representation of the geometry onto a given matplotlib axis.
333335
@@ -347,16 +349,15 @@ def plot_patch(self, ax: plt.Axes = None, show: bool = True) -> None:
347349 continue
348350
349351 if isinstance (structure , FiberFusing .profile .Profile ):
350- structure .plot (ax = ax , show = False , show_added = False , show_removed = False , show_centers = False , show_fibers = True )
352+ structure .plot (axes = axes , show = False , show_added = False , show_removed = False , show_centers = False , show_fibers = True )
351353 continue
352354
353- structure .plot (ax = ax , show = False )
355+ structure .plot (axes = axes , show = False )
354356
355- ax .set (title = 'Fiber structure' , xlabel = r'x-distance [m]' , ylabel = r'y-distance [m]' )
356- ax .ticklabel_format (axis = 'both' , style = 'sci' , scilimits = (- 6 , - 6 ), useOffset = False )
357+ axes .set (title = 'Fiber structure' , xlabel = r'x-distance [m]' , ylabel = r'y-distance [m]' )
357358
358- @_plot_helper
359- def plot_raster (self , ax : plt .Axes = None , gamma : float = 5 ) -> None :
359+ @helper . pre_plot ( nrows = 1 , ncols = 1 )
360+ def plot_raster (self , axes : plt .Axes , gamma : float = 5 ) -> None :
360361 """
361362 Render the rasterized representation of the geometry onto a given matplotlib axis.
362363
@@ -373,31 +374,27 @@ def plot_raster(self, ax: plt.Axes = None, gamma: float = 5) -> None:
373374 -------
374375 None
375376 """
376- image = ax .pcolormesh (
377+ image = axes .pcolormesh (
377378 self .coordinate_system .x_vector ,
378379 self .coordinate_system .y_vector ,
379380 self .mesh ,
380381 cmap = 'Blues' ,
381382 norm = colors .PowerNorm (gamma = gamma )
382383 )
383384
384- divider = make_axes_locatable (ax )
385+ divider = make_axes_locatable (axes )
385386 cax = divider .append_axes ('right' , size = '5%' , pad = 0.05 )
386- ax .get_figure ().colorbar (image , cax = cax , orientation = 'vertical' )
387+ axes .get_figure ().colorbar (image , cax = cax , orientation = 'vertical' )
387388
388- ax .set (title = 'Fiber structure' , xlabel = r'x-distance [m]' , ylabel = r'y-distance [m]' )
389- ax .ticklabel_format (axis = 'both' , style = 'sci' , scilimits = (- 6 , - 6 ), useOffset = False )
389+ axes .set (title = 'Fiber structure' , xlabel = r'x-distance [m]' , ylabel = r'y-distance [m]' )
390390
391- def plot (self , show_patch : bool = True , show_mesh : bool = True , show : bool = True , gamma : float = 5 ) -> plt .Figure :
391+ @helper .pre_plot (nrows = 1 , ncols = 2 , subplot_kw = dict (aspect = 'equal' , xlabel = 'x-distance [m]' , ylabel = 'y-distance [m]' ))
392+ def plot (self , axes , gamma : float = 5 ) -> plt .Figure :
392393 """
393394 Plot the different representations (patch and mesh) of the geometry.
394395
395396 Parameters
396397 ----------
397- show_patch : bool, optional
398- Whether to display the patch representation of the geometry. Default is True.
399- show_mesh : bool, optional
400- Whether to display the mesh (rasterized) representation of the geometry. Default is True.
401398 show : bool, optional
402399 Whether to immediately show the plot. Default is True.
403400 gamma : float, optional
@@ -408,29 +405,10 @@ def plot(self, show_patch: bool = True, show_mesh: bool = True, show: bool = Tru
408405 plt.Figure
409406 The matplotlib figure encompassing all the axes used in the plot.
410407 """
411- n_ax = bool (show_patch ) + bool (show_mesh )
412- unit_size = numpy .array ([1 , n_ax ])
413-
414- with plt .style .context (mps ):
415- figure , axes = plt .subplots (
416- * unit_size ,
417- figsize = 5 * numpy .flip (unit_size ),
418- sharex = True ,
419- sharey = True ,
420- subplot_kw = dict (aspect = 'equal' , xlabel = 'x-distance [m]' , ylabel = 'y-distance [m]' ),
421- )
422-
423- axes_iter = iter (axes .flatten ())
424-
425- if show_patch :
426- ax = next (axes_iter )
427- self .plot_patch (ax , show = False )
408+ axes [0 ].sharex (axes [1 ])
409+ axes [0 ].sharey (axes [1 ])
428410
429- if show_mesh :
430- ax = next (axes_iter )
431- self .plot_raster (ax , show = False , gamma = gamma )
432411
433- if show :
434- plt .show ()
412+ self .plot_patch (axes = axes [0 ], show = False )
435413
436- return figure
414+ self . plot_raster ( axes = axes [ 1 ], show = False , gamma = gamma )
0 commit comments