Skip to content

Commit af494df

Browse files
authored
fix: prevent mpl figure DPI from compounding on cell rerun (#9474)
1 parent 6b8d3d1 commit af494df

4 files changed

Lines changed: 159 additions & 5 deletions

File tree

marimo/_plugins/stateless/mpl/_mpl.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,17 @@ def interactive(figure: Figure | SubFigure | Axes) -> Html:
128128
if not ctx.virtual_files_supported:
129129
return NonInteractiveMplHtml(figure)
130130

131+
# Figure::figure returns self; SubFigure::figure returns the parent Figure
132+
is_subfigure = figure.figure is not figure
133+
if is_subfigure:
134+
warnings.warn(
135+
message="SubFigure is not supported in interactive mode; "
136+
"rendering as static PNG instead. "
137+
"Consider using a regular Figure instead.",
138+
stacklevel=2,
139+
)
140+
return NonInteractiveMplHtml(figure=figure)
141+
131142
from marimo._plugins.ui._impl.from_mpl_interactive import mpl_interactive
132143

133144
return mpl_interactive(figure)

marimo/_plugins/ui/_impl/from_mpl_interactive.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,12 @@ class mpl_interactive(UIElement[ModelIdRef, dict[str, Any]]):
163163
def __init__(self, figure: Figure | SubFigure) -> None:
164164
self._figure = figure
165165

166+
# SubFigure delegates dpi/size_inches to its parent Figure;
167+
# Figure.figure returns self, so this works for both.
168+
root = figure.figure
169+
self._original_dpi = root.get_dpi()
170+
self._original_size_inches = tuple(root.get_size_inches())
171+
166172
# Create FigureManagerWebAgg
167173
self._figure_manager = new_figure_manager_given_figure(
168174
id(figure), figure
@@ -233,7 +239,11 @@ def _initialize(
233239
ctx = get_context()
234240
ctx.cell_lifecycle_registry.add(
235241
_MplCleanupHandle(
236-
self._comm, self._figure_manager, self._sync_ws
242+
comm=self._comm,
243+
figure_manager=self._figure_manager,
244+
sync_ws=self._sync_ws,
245+
original_dpi=self._original_dpi,
246+
original_size_inches=self._original_size_inches,
237247
)
238248
)
239249
except ContextNotInitializedError:
@@ -302,12 +312,16 @@ class _MplCleanupHandle(CellLifecycleItem):
302312
def __init__(
303313
self,
304314
comm: MarimoComm,
315+
original_dpi: float,
316+
original_size_inches: tuple[float, float],
305317
figure_manager: Any = None,
306318
sync_ws: Any = None,
307319
) -> None:
308320
self._comm = comm
309321
self._figure_manager = figure_manager
310322
self._sync_ws = sync_ws
323+
self._original_dpi = original_dpi
324+
self._original_size_inches = original_size_inches
311325

312326
def create(self, context: RuntimeContext) -> None:
313327
del context
@@ -323,9 +337,18 @@ def dispose(self, context: RuntimeContext, deletion: bool) -> bool:
323337
except Exception:
324338
pass
325339
if self._figure_manager is not None:
340+
try:
341+
# get the root figure (in case of Subfigure) which handles dpi
342+
root = self._figure_manager.canvas.figure.figure
343+
root.set_dpi(self._original_dpi)
344+
root.set_size_inches(*self._original_size_inches)
345+
except Exception:
346+
pass
347+
326348
try:
327349
self._figure_manager.canvas.close()
328350
except Exception:
329351
pass
352+
330353
self._comm.close()
331354
return True

tests/_plugins/stateless/test_mpl.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,48 @@ def test_mpl_interactive_fallback_when_virtual_files_not_supported() -> None:
147147
plt.close(fig)
148148

149149

150+
@pytest.mark.requires("matplotlib")
151+
def test_mpl_interactive_subfigure_falls_back_to_png() -> None:
152+
"""SubFigure cannot be attached to a WebAgg canvas (matplotlib limitation),
153+
so mo.mpl.interactive should fall back to a static PNG and warn.
154+
"""
155+
import warnings
156+
157+
import matplotlib.pyplot as plt
158+
159+
from marimo._plugins.stateless.mpl._mpl import (
160+
NonInteractiveMplHtml,
161+
interactive,
162+
)
163+
from marimo._runtime.context.kernel_context import KernelRuntimeContext
164+
165+
parent = plt.figure(figsize=(8, 4), dpi=100)
166+
sub_left, _sub_right = parent.subfigures(1, 2)
167+
sub_left.subplots().plot([1, 2, 3])
168+
169+
mock_ctx = MagicMock(spec=KernelRuntimeContext)
170+
mock_ctx.virtual_files_supported = True
171+
172+
with (
173+
patch(
174+
"marimo._plugins.stateless.mpl._mpl.get_context",
175+
return_value=mock_ctx,
176+
),
177+
warnings.catch_warnings(record=True) as captured,
178+
):
179+
warnings.simplefilter("always")
180+
result = interactive(sub_left)
181+
182+
assert isinstance(result, NonInteractiveMplHtml)
183+
184+
# A UserWarning explaining the fallback should have been emitted.
185+
subfigure_warnings = [w for w in captured if "SubFigure" in str(w.message)]
186+
assert len(subfigure_warnings) == 1
187+
assert issubclass(subfigure_warnings[0].category, UserWarning)
188+
189+
plt.close(parent)
190+
191+
150192
@pytest.mark.requires("matplotlib")
151193
def test_new_figure_manager_suppresses_thread_warning() -> None:
152194
"""Regression test for https://github.com/marimo-team/marimo/issues/8747.

tests/_plugins/ui/_impl/test_from_mpl_interactive.py

Lines changed: 82 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66

77
import pytest
88

9+
mpl = pytest.importorskip("matplotlib")
10+
mpl.use(
11+
"Agg"
12+
) # Non-interactive backend; avoids DPR pre-inflation on HiDPI hosts.
13+
914

1015
@pytest.mark.requires("matplotlib")
1116
class TestSyncWebSocket:
@@ -180,6 +185,63 @@ def capture_send(data: Any, buffers: Any = None) -> None:
180185
plt.close(fig)
181186

182187

188+
@pytest.mark.requires("matplotlib")
189+
class TestDpiPreservationOnRerun:
190+
"""Re-running a cell that wraps the same figure should not compound DPI.
191+
192+
matplotlib's ``FigureCanvasBase.__init__`` unconditionally captures
193+
``figure._original_dpi = figure.dpi``. After a HiDPI client connects
194+
and scales ``figure.dpi`` up by the device pixel ratio, a subsequent
195+
canvas creation on the same figure would treat that scaled value as
196+
"original" and scale it again — making the resolution compound on
197+
every rerun (see issue #9466).
198+
"""
199+
200+
def test_dpi_does_not_compound_across_reruns(self) -> None:
201+
import matplotlib.pyplot as plt
202+
203+
from marimo._plugins.ui._impl.from_mpl_interactive import (
204+
_MplCleanupHandle,
205+
mpl_interactive,
206+
)
207+
208+
fig, ax = plt.subplots(figsize=(5, 5), dpi=100)
209+
ax.plot([1, 2, 3])
210+
211+
for _ in range(3):
212+
with patch("marimo._plugins.ui._impl.comm.broadcast_notification"):
213+
element = mpl_interactive(fig)
214+
215+
# Simulate the HiDPI handshake from the frontend.
216+
element._figure_manager.handle_json(
217+
{"type": "set_device_pixel_ratio", "device_pixel_ratio": 2}
218+
)
219+
element._figure_manager.handle_json(
220+
{"type": "resize", "width": 500, "height": 500}
221+
)
222+
# While the canvas is live, dpi reflects the device-scaled value.
223+
assert fig.dpi == 200
224+
assert tuple(fig.get_size_inches()) == (5.0, 5.0)
225+
226+
# Simulate cell teardown — the cleanup handle is what marimo
227+
# registers via cell_lifecycle_registry; running it directly
228+
# avoids needing a live runtime context in the test.
229+
cleanup = _MplCleanupHandle(
230+
comm=element._comm,
231+
figure_manager=element._figure_manager,
232+
sync_ws=element._sync_ws,
233+
original_dpi=element._original_dpi,
234+
original_size_inches=element._original_size_inches,
235+
)
236+
cleanup.dispose(context=MagicMock(), deletion=False)
237+
238+
# After dispose the figure is restored to the user's intent.
239+
assert fig.dpi == 100
240+
assert tuple(fig.get_size_inches()) == (5.0, 5.0)
241+
242+
plt.close(fig)
243+
244+
183245
@pytest.mark.requires("matplotlib")
184246
class TestMplCleanupHandle:
185247
"""Test that _MplCleanupHandle properly closes the comm."""
@@ -190,7 +252,9 @@ def test_dispose_closes_comm(self) -> None:
190252
)
191253

192254
mock_comm = MagicMock()
193-
handle = _MplCleanupHandle(mock_comm)
255+
handle = _MplCleanupHandle(
256+
mock_comm, original_dpi=100, original_size_inches=(5.0, 5.0)
257+
)
194258
result = handle.dispose(context=MagicMock(), deletion=False)
195259

196260
assert result is True
@@ -202,7 +266,9 @@ def test_dispose_on_deletion(self) -> None:
202266
)
203267

204268
mock_comm = MagicMock()
205-
handle = _MplCleanupHandle(mock_comm)
269+
handle = _MplCleanupHandle(
270+
mock_comm, original_dpi=100, original_size_inches=(5.0, 5.0)
271+
)
206272
result = handle.dispose(context=MagicMock(), deletion=True)
207273

208274
assert result is True
@@ -216,7 +282,13 @@ def test_dispose_cleans_up_figure_manager(self) -> None:
216282
mock_comm = MagicMock()
217283
mock_manager = MagicMock()
218284
mock_ws = MagicMock()
219-
handle = _MplCleanupHandle(mock_comm, mock_manager, mock_ws)
285+
handle = _MplCleanupHandle(
286+
mock_comm,
287+
figure_manager=mock_manager,
288+
sync_ws=mock_ws,
289+
original_dpi=100,
290+
original_size_inches=(5.0, 5.0),
291+
)
220292
result = handle.dispose(context=MagicMock(), deletion=False)
221293

222294
assert result is True
@@ -234,7 +306,13 @@ def test_dispose_tolerates_manager_errors(self) -> None:
234306
mock_manager.remove_web_socket.side_effect = RuntimeError("boom")
235307
mock_manager.canvas.close.side_effect = RuntimeError("boom")
236308
mock_ws = MagicMock()
237-
handle = _MplCleanupHandle(mock_comm, mock_manager, mock_ws)
309+
handle = _MplCleanupHandle(
310+
mock_comm,
311+
figure_manager=mock_manager,
312+
sync_ws=mock_ws,
313+
original_dpi=100,
314+
original_size_inches=(5.0, 5.0),
315+
)
238316
# Should not raise even if manager cleanup fails
239317
result = handle.dispose(context=MagicMock(), deletion=False)
240318

0 commit comments

Comments
 (0)