Skip to content

Commit fa387dd

Browse files
committed
Allow legends to be saved to separate files.
If configured with the separate_legend parameter set to true, all legends in the figure being saved will be removed from the main figure and save in separate files with the suffix _legend<N> added to the filename. This allows for placing the legends elsewhere within TeX, for example in the margins of the page next to the main figure. To achieve this, the internal saving and postprocessing code is split out to a function so it can be called on multiple files.
1 parent 2783c6a commit fa387dd

6 files changed

Lines changed: 306 additions & 97 deletions

File tree

pgfutils.py

Lines changed: 99 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ class PGFUtilsConfig(TypedDict):
354354
axes_background: Color
355355
extra_tracking: str
356356
environment: str
357+
separate_legend: bool
357358

358359

359360
class PostProcessingConfig(TypedDict):
@@ -413,6 +414,7 @@ def __init__(self, load_file: bool = True) -> None:
413414
axes_background=parse_color("white"),
414415
extra_tracking="",
415416
environment="",
417+
separate_legend=False,
416418
)
417419
self.post_processing = dict(fix_raster_paths=True, tikzpicture=False)
418420
self.rcparams = {}
@@ -874,7 +876,7 @@ def setup_figure(
874876
matplotlib.rcParams.update(config.rcparams)
875877

876878

877-
def save(figure=None):
879+
def save(figure: "matplotlib.figure.Figure|None" = None):
878880
"""Save the figure.
879881
880882
The filename is based on the name of the script which calls this function. For
@@ -890,18 +892,55 @@ def save(figure=None):
890892
"""
891893
global _interactive
892894

893-
from matplotlib import __version__ as mpl_version, pyplot as plt
895+
from matplotlib import pyplot as plt
894896

895897
# Get the current figure if needed.
896898
if figure is None:
897899
figure = plt.gcf()
898900

901+
# Figures to save.
902+
to_save = [("main_figure", figure, None)]
903+
899904
# Go through and fix up a few little quirks on the axes within this figure.
900905
for axes in figure.get_axes():
901-
# There are no rcParams for the legend properties. Go through and set these
902-
# directly before we save.
906+
# Check if these axes have a legend.
903907
legend = axes.get_legend()
904908
if legend:
909+
# Want to save the legend as a separate figure.
910+
if config.pgfutils["separate_legend"]:
911+
# Create a new figure to hold the legend with empty axes.
912+
legend_fig = plt.figure()
913+
legend_ax = legend_fig.add_subplot()
914+
legend_ax.axis("off")
915+
916+
# Regenerate the legend on the new axes, allowing it to use the whole
917+
# figure. Removing it from the original axes and then using add_artist
918+
# on these axes seems to get the bounding box wrong.
919+
legend_standalone = legend_ax.legend(
920+
*legend.axes.get_legend_handles_labels(),
921+
bbox_to_anchor=(0, 0, 1, 1),
922+
bbox_transform=legend_fig.transFigure,
923+
ncols=legend._ncols,
924+
numpoints=legend.numpoints,
925+
scatterpoints=legend.scatterpoints,
926+
)
927+
928+
# Measure its bounding box.
929+
bbox = legend_standalone.get_tightbbox()
930+
if not bbox:
931+
raise RuntimeError("could not determine legend bounding box")
932+
933+
# This is in pixels; convert to inches as savefig() will need.
934+
bbox = bbox.transformed(legend_fig.dpi_scale_trans.inverted())
935+
936+
# Remove the original legend and replace the reference.
937+
legend.remove()
938+
legend = legend_standalone
939+
940+
# And save the standalone figure as a legend.
941+
to_save.append(("legend", legend_fig, bbox))
942+
943+
# There are no rcParams for the legend properties; set them directly.
905944
frame = legend.get_frame()
906945
frame.set_linewidth(config.pgfutils["legend_border_width"])
907946
frame.set_alpha(config.pgfutils["legend_opacity"])
@@ -948,14 +987,24 @@ def save(figure=None):
948987
return
949988

950989
# Look at the next frame up for the name of the calling script.
951-
script = Path(inspect.getfile(sys._getframe(1)))
990+
script_fn = Path(inspect.getfile(sys._getframe(1)))
991+
992+
# And save each figure object.
993+
legend_id = 0
994+
for figtype, fig, bbox in to_save:
995+
# Main figure object for the plot. Use the same path and stem.
996+
if figtype == "main_figure":
997+
pypgf_fn = script_fn.with_suffix(".pypgf")
952998

953-
# The initial Matplotlib output file, and the final figure file.
954-
mpname = script.with_suffix(".pgf")
955-
figname = script.with_suffix(".pypgf")
999+
# Legend being saved to a separate file. Add a unique suffix to the filename.
1000+
elif figtype == "legend":
1001+
pypgf_fn = script_fn.with_name(f"{script_fn.stem}_legend{legend_id}.pypgf")
1002+
legend_id += 1
9561003

957-
# Get Matplotlib to save it.
958-
figure.savefig(mpname)
1004+
else:
1005+
raise RuntimeError("unknown figure type")
1006+
1007+
_save_and_postprocess(fig, pypgf_fn, script_fn, bbox)
9591008

9601009
# We want to output tracked files.
9611010
if "PGFUTILS_TRACK_FILES" in os.environ:
@@ -966,8 +1015,11 @@ def save(figure=None):
9661015
if in_tracking_dir("import", fn):
9671016
files.append(f"r:{_relative_if_subdir(fn)}")
9681017

969-
# Include read files if in a tracked directory.
1018+
# Include read files if in a tracked directory, and they don't correspond to the
1019+
# initial PGF file written by Matplotlib that we read in for post-processing.
9701020
for fn in tracker.read:
1021+
if fn.suffix == ".pgf":
1022+
continue
9711023
if in_tracking_dir("data", fn):
9721024
files.append(f"r:{_relative_if_subdir(fn)}")
9731025

@@ -999,10 +1051,40 @@ def save(figure=None):
9991051
with open(dest, "w") as f:
10001052
f.write(filestr)
10011053

1054+
1055+
def _save_and_postprocess(
1056+
figure: "matplotlib.figure.Figure",
1057+
pypgf_fn: Path,
1058+
script_fn: Path,
1059+
bbox: "matplotlib.transforms.Bbox|None",
1060+
):
1061+
"""Save and postprocess a figure.
1062+
1063+
Parameters
1064+
----------
1065+
figure
1066+
The figure instance to save.
1067+
pypgf_fn
1068+
The filename to write the final post-processed figure to.
1069+
script_fn
1070+
The filename of the script which produced the figure.
1071+
bbox
1072+
A bounding box (in inches) giving the portion of the figure to save. If None,
1073+
the entire figure is saved.
1074+
1075+
"""
1076+
from matplotlib import __version__ as mpl_version
1077+
1078+
mpl_fn = pypgf_fn.with_suffix(".pgf")
1079+
if bbox is None:
1080+
figure.savefig(mpl_fn)
1081+
else:
1082+
figure.savefig(mpl_fn, bbox_inches=bbox)
1083+
10021084
# List of all postprocessing functions we are running on this figure. Each should
10031085
# take in a single line as a string, and return the line with any required
10041086
# modifications.
1005-
pp_funcs = []
1087+
pp_funcs: list[Callable[[str], str]] = []
10061088

10071089
# Local cache of postprocessing options.
10081090
fix_raster_paths = config.post_processing["fix_raster_paths"]
@@ -1011,7 +1093,7 @@ def save(figure=None):
10111093
# Add the appropriate directory prefix to all raster images
10121094
# included via \pgfimage.
10131095
if fix_raster_paths:
1014-
figdir = figname.parent
1096+
figdir = pypgf_fn.parent
10151097

10161098
# Only apply this if the figure is not in the top-level directory.
10171099
if not figdir.samefile("."):
@@ -1027,7 +1109,7 @@ def save(figure=None):
10271109
pp_funcs.append(lambda s: re.sub(expr, repl, s))
10281110

10291111
# Postprocess the figure, moving it into the final destination.
1030-
with open(mpname, "r") as infile, open(figname, "w") as outfile:
1112+
with open(mpl_fn, "r") as infile, open(pypgf_fn, "w") as outfile:
10311113
# Make some modifications to the header.
10321114
line = infile.readline()
10331115
while line[0] == "%":
@@ -1040,13 +1122,13 @@ def save(figure=None):
10401122
outfile.write(", matplotlib-pgfutils v")
10411123
outfile.write(__version__)
10421124
outfile.write("\n%% Script: ")
1043-
outfile.write(str(script.resolve()))
1125+
outfile.write(str(script_fn.resolve()))
10441126
outfile.write("\n")
10451127

10461128
# Update the \input instructions.
10471129
elif r"\input{<filename>.pgf}" in line:
10481130
outfile.write("%% \\input{")
1049-
outfile.write(str(figname))
1131+
outfile.write(str(pypgf_fn))
10501132
outfile.write("}\n")
10511133

10521134
# If we're changing the figure to use the tikzpicture environment, we also
@@ -1084,4 +1166,4 @@ def save(figure=None):
10841166
outfile.write(line)
10851167

10861168
# Delete the original file.
1087-
os.remove(mpname)
1169+
os.remove(mpl_fn)

tests/sources/legend/legend_only.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
Line2D([0], [0], color=cmap(1), lw=4),
1616
]
1717

18+
# Disable all other frames and sources of text.
1819
fig = plt.figure(frameon=False)
1920
ax = fig.add_axes([0, 0, 1, 1])
2021
ax.axis("off")
22+
2123
plt.legend(custom_lines, ("One", "Two", "Three"), loc="center")
2224

2325
save()
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# SPDX-FileCopyrightText: Blair Bonnett
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
from pgfutils import save, setup_figure
5+
6+
setup_figure(width=1, height=1)
7+
8+
from matplotlib import pyplot as plt
9+
import numpy as np
10+
11+
t = np.arange(0, 1, 0.005)
12+
s1 = np.sin(2 * np.pi * t)
13+
s2 = np.sin(4 * np.pi * t)
14+
s3 = np.sin(6 * np.pi * t)
15+
16+
# Disable all other frames and sources of text.
17+
fig = plt.figure(frameon=False)
18+
ax = fig.add_axes([0, 0, 1, 1])
19+
ax.axis("off")
20+
21+
ax.plot(t, s1, label="One")
22+
ax.plot(t, s2, label="Two")
23+
ax.plot(t, s3, label="Three")
24+
ax.legend()
25+
26+
save()
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# SPDX-FileCopyrightText: Blair Bonnett
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
from pgfutils import save, setup_figure
5+
6+
setup_figure(width=1, height=1, separate_legend=True)
7+
8+
from matplotlib import pyplot as plt
9+
import numpy as np
10+
11+
t = np.arange(0, 1, 0.005)
12+
s1 = np.sin(2 * np.pi * t)
13+
s2 = np.sin(4 * np.pi * t)
14+
s3 = np.sin(6 * np.pi * t)
15+
16+
# Disable all other frames and sources of text.
17+
fig = plt.figure(frameon=False)
18+
ax = fig.add_axes([0, 0, 1, 1])
19+
ax.axis("off")
20+
21+
ax.plot(t, s1, label="One")
22+
ax.plot(t, s2, label="Two")
23+
ax.plot(t, s3, label="Three")
24+
ax.legend()
25+
26+
save()

0 commit comments

Comments
 (0)