55
66import matplotlib .patches as patches
77import matplotlib .pyplot as plt
8+ import numpy as np
89from plotly .subplots import make_subplots
910from tikzfigure import TikzFigure
1011
@@ -579,6 +580,41 @@ def text(
579580 """Add a text label at (x, y) on a subplot."""
580581 self ._get_or_create_subplot (row , col ).text (x , y , s , layer = layer , ** kwargs )
581582
583+ def imshow (
584+ self ,
585+ data ,
586+ layer = 0 ,
587+ row : int | None = None ,
588+ col : int | None = None ,
589+ ** kwargs ,
590+ ):
591+ """Add an image/matrix plot to a subplot."""
592+ self ._get_or_create_subplot (row , col ).add_imshow (data , layer = layer , ** kwargs )
593+
594+ def add_patch (
595+ self ,
596+ patch ,
597+ layer = 0 ,
598+ row : int | None = None ,
599+ col : int | None = None ,
600+ ** kwargs ,
601+ ):
602+ """Add a Matplotlib patch to a subplot."""
603+ self ._get_or_create_subplot (row , col ).add_patch (patch , layer = layer , ** kwargs )
604+
605+ def colorbar (
606+ self ,
607+ label : str = "" ,
608+ layer = 0 ,
609+ row : int | None = None ,
610+ col : int | None = None ,
611+ ** kwargs ,
612+ ):
613+ """Add a colorbar to the most recent imshow() on a subplot (matplotlib backend)."""
614+ self ._get_or_create_subplot (row , col ).add_colorbar (
615+ label = label , layer = layer , ** kwargs
616+ )
617+
582618 # ------------------------------------------------------------------
583619 # Multi-subplot helpers
584620 # ------------------------------------------------------------------
@@ -773,6 +809,34 @@ def savefig(
773809 figure .savefig (full_filepath )
774810 if verbose :
775811 print (f"Saved { full_filepath } " )
812+ elif backend == "plotly" :
813+ if layer_by_layer :
814+ layers = []
815+ for layer in self .layers :
816+ layers .append (layer )
817+ full_filepath = f"{ filename_no_extension } _{ layers } { extension } "
818+ fig = self .plot (
819+ backend = "plotly" ,
820+ savefig = False ,
821+ layers = layers ,
822+ )
823+ self ._save_plotly (fig , full_filepath )
824+ if verbose :
825+ print (f"Saved { full_filepath } " )
826+ else :
827+ if layers is None :
828+ layers = self .layers
829+ full_filepath = filename
830+ else :
831+ full_filepath = f"{ filename_no_extension } _{ layers } { extension } "
832+ fig = self .plot (
833+ backend = "plotly" ,
834+ savefig = False ,
835+ layers = layers ,
836+ )
837+ self ._save_plotly (fig , full_filepath )
838+ if verbose :
839+ print (f"Saved { full_filepath } " )
776840
777841 def plot (
778842 self ,
@@ -797,6 +861,7 @@ def plot(
797861 elif backend == "plotly" :
798862 return self .plot_plotly (
799863 savefig = savefig ,
864+ layers = layers ,
800865 usetex = resolved_usetex ,
801866 verbose = verbose ,
802867 )
@@ -832,7 +897,11 @@ def show(
832897 # self._matplotlib_fig.show()
833898 elif backend == "plotly" :
834899 resolved_usetex = self ._usetex if usetex is None else usetex
835- self .plot_plotly (savefig = False , usetex = resolved_usetex )
900+ fig = self .plot_plotly (
901+ savefig = False , layers = layers , usetex = resolved_usetex , verbose = verbose
902+ )
903+ fig .show ()
904+ return fig
836905 elif backend == "plotext" :
837906 figure = self .plot_plotext (
838907 savefig = False ,
@@ -1034,6 +1103,7 @@ def plot_plotly(
10341103 self ,
10351104 show = True ,
10361105 savefig = None ,
1106+ layers : list | None = None ,
10371107 usetex : bool | None = None ,
10381108 verbose : bool = False ,
10391109 ):
@@ -1063,38 +1133,123 @@ def plot_plotly(
10631133 ratio = self ._ratio ,
10641134 )
10651135 # print(self._width, fig_width, fig_height)
1066- # Create subplots
1136+ # Create subplot titles in row-major order (Plotly expects rows*cols entries)
1137+ subplot_titles = ["" ] * (self .nrows * self .ncols )
1138+ for (row , col ), sp in self ._subplot_dict .items ():
1139+ index = row * self .ncols + col
1140+ subplot_titles [index ] = sp ._title or f"({ row } , { col } )"
1141+
10671142 fig = make_subplots (
10681143 rows = self .nrows ,
10691144 cols = self .ncols ,
1070- subplot_titles = [
1071- sp ._title or f"({ row } , { col } )"
1072- for (row , col ), sp in self ._subplot_dict .items ()
1073- ],
1145+ subplot_titles = subplot_titles ,
10741146 )
10751147
10761148 # Plot each subplot and propagate axis labels/scale
1077- axis_index = 1
10781149 for (row , col ), line_plot in self ._subplot_dict .items ():
1079- traces = line_plot .plot_plotly ()
1150+ traces , shapes , annotations = line_plot .plot_plotly (layers = layers )
10801151 for trace in traces :
10811152 fig .add_trace (trace , row = row + 1 , col = col + 1 )
10821153
1083- # Axis label keys are "xaxis", "xaxis2", "xaxis3", ...
1084- xkey = "xaxis" if axis_index == 1 else f"xaxis{ axis_index } "
1085- ykey = "yaxis" if axis_index == 1 else f"yaxis{ axis_index } "
1086- layout_patch = {}
1087- if line_plot ._xlabel :
1088- layout_patch [xkey ] = {"title" : {"text" : line_plot ._xlabel }}
1089- if line_plot ._ylabel :
1090- layout_patch [ykey ] = {"title" : {"text" : line_plot ._ylabel }}
1154+ # Axis indices are row-major: (row*ncols + col + 1)
1155+ axis_index = row * self .ncols + col + 1
1156+ xref = "x" if axis_index == 1 else f"x{ axis_index } "
1157+ yref = "y" if axis_index == 1 else f"y{ axis_index } "
1158+
1159+ for shape in shapes :
1160+ shape = dict (shape )
1161+ if shape .get ("xref" ) not in {"paper" }:
1162+ shape ["xref" ] = xref
1163+ if shape .get ("yref" ) not in {"paper" }:
1164+ shape ["yref" ] = yref
1165+ fig .add_shape (shape )
1166+
1167+ for annotation in annotations :
1168+ annotation = dict (annotation )
1169+ annotation .setdefault ("xref" , xref )
1170+ annotation .setdefault ("yref" , yref )
1171+ fig .add_annotation (annotation )
1172+
1173+ # Apply per-axis config in a row/col-safe way
1174+ xaxis_kwargs = dict (
1175+ title_text = line_plot ._xlabel or None ,
1176+ showgrid = bool (line_plot ._grid ),
1177+ row = row + 1 ,
1178+ col = col + 1 ,
1179+ )
10911180 if line_plot ._xaxis_scale == "log" :
1092- layout_patch .setdefault (xkey , {})["type" ] = "log"
1181+ xaxis_kwargs ["type" ] = "log"
1182+ fig .update_xaxes (** xaxis_kwargs )
1183+
1184+ yaxis_kwargs = dict (
1185+ title_text = line_plot ._ylabel or None ,
1186+ showgrid = bool (line_plot ._grid ),
1187+ row = row + 1 ,
1188+ col = col + 1 ,
1189+ )
10931190 if line_plot ._yaxis_scale == "log" :
1094- layout_patch .setdefault (ykey , {})["type" ] = "log"
1095- if layout_patch :
1096- fig .update_layout (** layout_patch )
1097- axis_index += 1
1191+ yaxis_kwargs ["type" ] = "log"
1192+ fig .update_yaxes (** yaxis_kwargs )
1193+
1194+ # Axis limits
1195+ if line_plot ._xmin is not None or line_plot ._xmax is not None :
1196+ x_range = [
1197+ line_plot ._xmin if line_plot ._xmin is not None else None ,
1198+ line_plot ._xmax if line_plot ._xmax is not None else None ,
1199+ ]
1200+ if (
1201+ line_plot ._xaxis_scale == "log"
1202+ and x_range [0 ] is not None
1203+ and x_range [1 ] is not None
1204+ and x_range [0 ] > 0
1205+ and x_range [1 ] > 0
1206+ ):
1207+ x_range = [np .log10 (x_range [0 ]), np .log10 (x_range [1 ])]
1208+ fig .update_xaxes (
1209+ range = x_range ,
1210+ row = row + 1 ,
1211+ col = col + 1 ,
1212+ )
1213+ if line_plot ._ymin is not None or line_plot ._ymax is not None :
1214+ y_range = [
1215+ line_plot ._ymin if line_plot ._ymin is not None else None ,
1216+ line_plot ._ymax if line_plot ._ymax is not None else None ,
1217+ ]
1218+ if (
1219+ line_plot ._yaxis_scale == "log"
1220+ and y_range [0 ] is not None
1221+ and y_range [1 ] is not None
1222+ and y_range [0 ] > 0
1223+ and y_range [1 ] > 0
1224+ ):
1225+ y_range = [np .log10 (y_range [0 ]), np .log10 (y_range [1 ])]
1226+ fig .update_yaxes (
1227+ range = y_range ,
1228+ row = row + 1 ,
1229+ col = col + 1 ,
1230+ )
1231+
1232+ # Custom ticks (positions + optional labels)
1233+ if line_plot ._xticks is not None :
1234+ fig .update_xaxes (
1235+ tickmode = "array" ,
1236+ tickvals = line_plot ._xticks ,
1237+ ticktext = line_plot ._xticklabels ,
1238+ row = row + 1 ,
1239+ col = col + 1 ,
1240+ )
1241+ if line_plot ._yticks is not None :
1242+ fig .update_yaxes (
1243+ tickmode = "array" ,
1244+ tickvals = line_plot ._yticks ,
1245+ ticktext = line_plot ._yticklabels ,
1246+ row = row + 1 ,
1247+ col = col + 1 ,
1248+ )
1249+
1250+ # Aspect ratio
1251+ if line_plot ._aspect == "equal" :
1252+ fig .update_yaxes (scaleanchor = xref , row = row + 1 , col = col + 1 )
10981253
10991254 # Update layout settings
11001255 fig .update_layout (
@@ -1105,10 +1260,30 @@ def plot_plotly(
11051260 fig .update_layout (title = dict (text = self ._suptitle , x = 0.5 ))
11061261
11071262 if savefig :
1108- fig .write_image (savefig )
1263+ try :
1264+ fig .write_image (savefig )
1265+ except Exception as exc :
1266+ raise RuntimeError (
1267+ "Plotly image export failed. If you are exporting to PNG/PDF/SVG, "
1268+ "install kaleido (e.g., `pip install -U kaleido`)."
1269+ ) from exc
11091270
11101271 return fig
11111272
1273+ def _save_plotly (self , fig , filename : str ) -> None :
1274+ _ , extension = os .path .splitext (filename )
1275+ extension = extension .lower ()
1276+ if extension in {".html" , ".htm" }:
1277+ fig .write_html (filename )
1278+ return
1279+ try :
1280+ fig .write_image (filename )
1281+ except Exception as exc :
1282+ raise RuntimeError (
1283+ "Plotly image export failed. For PNG/PDF/SVG export, install kaleido "
1284+ "(e.g., `pip install -U kaleido`), or export to HTML instead."
1285+ ) from exc
1286+
11121287 # Property getters
11131288
11141289 @property
0 commit comments