@@ -354,6 +354,7 @@ class PGFUtilsConfig(TypedDict):
354354 axes_background : Color
355355 extra_tracking : str
356356 environment : str
357+ separate_legend : bool
357358
358359
359360class PostProcessingConfig (TypedDict ):
@@ -413,6 +414,7 @@ def __init__(self, load_file: bool = True) -> None:
413414 axes_background = parse_color ("white" ),
414415 extra_tracking = "" ,
415416 environment = "" ,
417+ separate_legend = False ,
416418 )
417419 self .post_processing = dict (fix_raster_paths = True , tikzpicture = False )
418420 self .rcparams = {}
@@ -874,7 +876,7 @@ def setup_figure(
874876 matplotlib .rcParams .update (config .rcparams )
875877
876878
877- def save (figure = None ):
879+ def save (figure : "matplotlib.figure.Figure|None" = None ):
878880 """Save the figure.
879881
880882 The filename is based on the name of the script which calls this function. For
@@ -890,18 +892,55 @@ def save(figure=None):
890892 """
891893 global _interactive
892894
893- from matplotlib import __version__ as mpl_version , pyplot as plt
895+ from matplotlib import pyplot as plt
894896
895897 # Get the current figure if needed.
896898 if figure is None :
897899 figure = plt .gcf ()
898900
901+ # Figures to save.
902+ to_save = [("main_figure" , figure , None )]
903+
899904 # Go through and fix up a few little quirks on the axes within this figure.
900905 for axes in figure .get_axes ():
901- # There are no rcParams for the legend properties. Go through and set these
902- # directly before we save.
906+ # Check if these axes have a legend.
903907 legend = axes .get_legend ()
904908 if legend :
909+ # Want to save the legend as a separate figure.
910+ if config .pgfutils ["separate_legend" ]:
911+ # Create a new figure to hold the legend with empty axes.
912+ legend_fig = plt .figure ()
913+ legend_ax = legend_fig .add_subplot ()
914+ legend_ax .axis ("off" )
915+
916+ # Regenerate the legend on the new axes, allowing it to use the whole
917+ # figure. Removing it from the original axes and then using add_artist
918+ # on these axes seems to get the bounding box wrong.
919+ legend_standalone = legend_ax .legend (
920+ * legend .axes .get_legend_handles_labels (),
921+ bbox_to_anchor = (0 , 0 , 1 , 1 ),
922+ bbox_transform = legend_fig .transFigure ,
923+ ncols = legend ._ncols ,
924+ numpoints = legend .numpoints ,
925+ scatterpoints = legend .scatterpoints ,
926+ )
927+
928+ # Measure its bounding box.
929+ bbox = legend_standalone .get_tightbbox ()
930+ if not bbox :
931+ raise RuntimeError ("could not determine legend bounding box" )
932+
933+ # This is in pixels; convert to inches as savefig() will need.
934+ bbox = bbox .transformed (legend_fig .dpi_scale_trans .inverted ())
935+
936+ # Remove the original legend and replace the reference.
937+ legend .remove ()
938+ legend = legend_standalone
939+
940+ # And save the standalone figure as a legend.
941+ to_save .append (("legend" , legend_fig , bbox ))
942+
943+ # There are no rcParams for the legend properties; set them directly.
905944 frame = legend .get_frame ()
906945 frame .set_linewidth (config .pgfutils ["legend_border_width" ])
907946 frame .set_alpha (config .pgfutils ["legend_opacity" ])
@@ -948,14 +987,24 @@ def save(figure=None):
948987 return
949988
950989 # Look at the next frame up for the name of the calling script.
951- script = Path (inspect .getfile (sys ._getframe (1 )))
990+ script_fn = Path (inspect .getfile (sys ._getframe (1 )))
991+
992+ # And save each figure object.
993+ legend_id = 0
994+ for figtype , fig , bbox in to_save :
995+ # Main figure object for the plot. Use the same path and stem.
996+ if figtype == "main_figure" :
997+ pypgf_fn = script_fn .with_suffix (".pypgf" )
952998
953- # The initial Matplotlib output file, and the final figure file.
954- mpname = script .with_suffix (".pgf" )
955- figname = script .with_suffix (".pypgf" )
999+ # Legend being saved to a separate file. Add a unique suffix to the filename.
1000+ elif figtype == "legend" :
1001+ pypgf_fn = script_fn .with_name (f"{ script_fn .stem } _legend{ legend_id } .pypgf" )
1002+ legend_id += 1
9561003
957- # Get Matplotlib to save it.
958- figure .savefig (mpname )
1004+ else :
1005+ raise RuntimeError ("unknown figure type" )
1006+
1007+ _save_and_postprocess (fig , pypgf_fn , script_fn , bbox )
9591008
9601009 # We want to output tracked files.
9611010 if "PGFUTILS_TRACK_FILES" in os .environ :
@@ -966,8 +1015,11 @@ def save(figure=None):
9661015 if in_tracking_dir ("import" , fn ):
9671016 files .append (f"r:{ _relative_if_subdir (fn )} " )
9681017
969- # Include read files if in a tracked directory.
1018+ # Include read files if in a tracked directory, and they don't correspond to the
1019+ # initial PGF file written by Matplotlib that we read in for post-processing.
9701020 for fn in tracker .read :
1021+ if fn .suffix == ".pgf" :
1022+ continue
9711023 if in_tracking_dir ("data" , fn ):
9721024 files .append (f"r:{ _relative_if_subdir (fn )} " )
9731025
@@ -999,10 +1051,40 @@ def save(figure=None):
9991051 with open (dest , "w" ) as f :
10001052 f .write (filestr )
10011053
1054+
1055+ def _save_and_postprocess (
1056+ figure : "matplotlib.figure.Figure" ,
1057+ pypgf_fn : Path ,
1058+ script_fn : Path ,
1059+ bbox : "matplotlib.transforms.Bbox|None" ,
1060+ ):
1061+ """Save and postprocess a figure.
1062+
1063+ Parameters
1064+ ----------
1065+ figure
1066+ The figure instance to save.
1067+ pypgf_fn
1068+ The filename to write the final post-processed figure to.
1069+ script_fn
1070+ The filename of the script which produced the figure.
1071+ bbox
1072+ A bounding box (in inches) giving the portion of the figure to save. If None,
1073+ the entire figure is saved.
1074+
1075+ """
1076+ from matplotlib import __version__ as mpl_version
1077+
1078+ mpl_fn = pypgf_fn .with_suffix (".pgf" )
1079+ if bbox is None :
1080+ figure .savefig (mpl_fn )
1081+ else :
1082+ figure .savefig (mpl_fn , bbox_inches = bbox )
1083+
10021084 # List of all postprocessing functions we are running on this figure. Each should
10031085 # take in a single line as a string, and return the line with any required
10041086 # modifications.
1005- pp_funcs = []
1087+ pp_funcs : list [ Callable [[ str ], str ]] = []
10061088
10071089 # Local cache of postprocessing options.
10081090 fix_raster_paths = config .post_processing ["fix_raster_paths" ]
@@ -1011,7 +1093,7 @@ def save(figure=None):
10111093 # Add the appropriate directory prefix to all raster images
10121094 # included via \pgfimage.
10131095 if fix_raster_paths :
1014- figdir = figname .parent
1096+ figdir = pypgf_fn .parent
10151097
10161098 # Only apply this if the figure is not in the top-level directory.
10171099 if not figdir .samefile ("." ):
@@ -1027,7 +1109,7 @@ def save(figure=None):
10271109 pp_funcs .append (lambda s : re .sub (expr , repl , s ))
10281110
10291111 # Postprocess the figure, moving it into the final destination.
1030- with open (mpname , "r" ) as infile , open (figname , "w" ) as outfile :
1112+ with open (mpl_fn , "r" ) as infile , open (pypgf_fn , "w" ) as outfile :
10311113 # Make some modifications to the header.
10321114 line = infile .readline ()
10331115 while line [0 ] == "%" :
@@ -1040,13 +1122,13 @@ def save(figure=None):
10401122 outfile .write (", matplotlib-pgfutils v" )
10411123 outfile .write (__version__ )
10421124 outfile .write ("\n %% Script: " )
1043- outfile .write (str (script .resolve ()))
1125+ outfile .write (str (script_fn .resolve ()))
10441126 outfile .write ("\n " )
10451127
10461128 # Update the \input instructions.
10471129 elif r"\input{<filename>.pgf}" in line :
10481130 outfile .write ("%% \\ input{" )
1049- outfile .write (str (figname ))
1131+ outfile .write (str (pypgf_fn ))
10501132 outfile .write ("}\n " )
10511133
10521134 # If we're changing the figure to use the tikzpicture environment, we also
@@ -1084,4 +1166,4 @@ def save(figure=None):
10841166 outfile .write (line )
10851167
10861168 # Delete the original file.
1087- os .remove (mpname )
1169+ os .remove (mpl_fn )
0 commit comments