Skip to content

Commit cdf6025

Browse files
authored
refactor(runtime): extract ModuleReloader/ModuleWatcher into AutoreloadManager (#9590)
1 parent 5d1923c commit cdf6025

3 files changed

Lines changed: 195 additions & 64 deletions

File tree

marimo/_runtime/reload/manager.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2026 Marimo. All rights reserved.
2+
"""Autoreload manager: owns `ModuleReloader` and `ModuleWatcher` on behalf of the kernel."""
3+
4+
from __future__ import annotations
5+
6+
import contextlib
7+
import sys
8+
from typing import TYPE_CHECKING, Literal
9+
10+
from marimo._runtime.reload.autoreload import ModuleReloader
11+
from marimo._runtime.reload.module_watcher import ModuleWatcher
12+
from marimo._utils.platform import is_pyodide
13+
14+
if TYPE_CHECKING:
15+
from collections.abc import Iterator
16+
17+
from marimo._ast.cell import CellImpl
18+
from marimo._runtime.runner.hook_context import OnFinishHookContext
19+
from marimo._runtime.runtime import Kernel
20+
21+
AutoReloadMode = Literal["off", "lazy", "autorun"]
22+
23+
24+
class AutoreloadManager:
25+
"""Owns ModuleReloader + ModuleWatcher and reacts to config changes."""
26+
27+
def __init__(self, kernel: Kernel) -> None:
28+
self._kernel = kernel
29+
self._reloader: ModuleReloader | None = None
30+
self._watcher: ModuleWatcher | None = None
31+
32+
# Re-arm the watcher after every kernel run, regardless of trigger.
33+
kernel._hooks.add_on_finish(self._on_finish_hook)
34+
35+
@property
36+
def reloader(self) -> ModuleReloader | None:
37+
return self._reloader
38+
39+
@property
40+
def watcher(self) -> ModuleWatcher | None:
41+
return self._watcher
42+
43+
def update_from_config(self, mode: AutoReloadMode) -> None:
44+
"""Start, stop, or swap the watcher/reloader to match `runtime.auto_reload`."""
45+
# Pyodide doesn't support hot module reloading.
46+
if (mode == "lazy" or mode == "autorun") and not is_pyodide():
47+
if self._reloader is None:
48+
self._reloader = ModuleReloader()
49+
50+
if self._watcher is not None and self._watcher.mode != mode:
51+
self._watcher.stop()
52+
self._watcher = None
53+
54+
if self._watcher is None:
55+
self._watcher = ModuleWatcher(
56+
self._kernel.graph,
57+
reloader=self._reloader,
58+
enqueue_run_stale_cells=self._kernel._execute_stale_cells_callback,
59+
mode=mode,
60+
stream=self._kernel.stream,
61+
)
62+
else:
63+
self._reloader = None
64+
if self._watcher is not None:
65+
self._watcher.stop()
66+
self._watcher = None
67+
68+
def teardown(self) -> None:
69+
if self._watcher is not None:
70+
self._watcher.stop()
71+
self._watcher = None
72+
self._reloader = None
73+
74+
def flag_if_imports_stale(self, cell: CellImpl) -> None:
75+
reloader = self._reloader
76+
if reloader is None:
77+
return
78+
if reloader.cell_uses_stale_modules(cell):
79+
self._kernel.graph.set_stale({cell.cell_id}, prune_imports=True)
80+
81+
@contextlib.contextmanager
82+
def cell_scope(self) -> Iterator[None]:
83+
"""Reload modified modules on entry; record mtimes for newly-imported modules on exit."""
84+
if self._reloader is None:
85+
yield
86+
return
87+
snapshot = set(sys.modules)
88+
self._reloader.check(modules=sys.modules, reload=True)
89+
try:
90+
yield
91+
finally:
92+
new_modules = set(sys.modules) - snapshot
93+
self._reloader.check(
94+
modules={m: sys.modules[m] for m in new_modules},
95+
reload=False,
96+
)
97+
98+
def _on_finish_hook(self, ctx: OnFinishHookContext) -> None:
99+
del ctx
100+
if self._watcher is not None:
101+
self._watcher.run_is_processed.set()

marimo/_runtime/runtime.py

Lines changed: 6 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,7 @@
129129
start_parent_poller,
130130
)
131131
from marimo._runtime.redirect_streams import redirect_streams
132-
from marimo._runtime.reload.autoreload import ModuleReloader
133-
from marimo._runtime.reload.module_watcher import ModuleWatcher
132+
from marimo._runtime.reload.manager import AutoreloadManager
134133
from marimo._runtime.request_router import RequestRouter
135134
from marimo._runtime.runner import cell_runner, hook_context
136135
from marimo._runtime.runner.hooks import (
@@ -556,8 +555,7 @@ def __init__(
556555
self.module_registry = ModuleRegistry(
557556
self.graph, excluded_modules=set()
558557
)
559-
self.module_reloader: ModuleReloader | None = None
560-
self.module_watcher: ModuleWatcher | None = None
558+
self.autoreload_manager = AutoreloadManager(self)
561559

562560
# Load runtime settings from user config
563561
self.user_config = user_config
@@ -627,8 +625,7 @@ def teardown(self) -> None:
627625
self.stdin._stop()
628626
self.stream.stop()
629627

630-
if self.module_watcher is not None:
631-
self.module_watcher.stop()
628+
self.autoreload_manager.teardown()
632629

633630
# TODO(akshayka): There's a memory leak in run mode, with memory
634631
# usage increasing with each session creation. Somehow the kernel
@@ -653,35 +650,7 @@ def _update_runtime_from_user_config(self, config: MarimoConfig) -> None:
653650
self.user_config = config
654651

655652
self.packages_callbacks.update_package_manager(package_manager)
656-
657-
if (
658-
(autoreload_mode == "lazy" or autoreload_mode == "autorun")
659-
# Pyodide doesn't support hot module reloading
660-
and not is_pyodide()
661-
):
662-
if self.module_reloader is None:
663-
self.module_reloader = ModuleReloader()
664-
665-
if (
666-
self.module_watcher is not None
667-
and self.module_watcher.mode != autoreload_mode
668-
):
669-
self.module_watcher.stop()
670-
self.module_watcher = None
671-
672-
if self.module_watcher is None:
673-
self.module_watcher = ModuleWatcher(
674-
self.graph,
675-
reloader=self.module_reloader,
676-
enqueue_run_stale_cells=self._execute_stale_cells_callback,
677-
mode=autoreload_mode,
678-
stream=self.stream,
679-
)
680-
else:
681-
self.module_reloader = None
682-
if self.module_watcher is not None:
683-
self.module_watcher.stop()
684-
self.module_watcher = None
653+
self.autoreload_manager.update_from_config(autoreload_mode)
685654

686655
@property
687656
def globals(self) -> dict[Any, Any]:
@@ -756,25 +725,12 @@ def _install_execution_context(
756725
stderr=self.stderr,
757726
stdin=self.stdin,
758727
),
728+
self.autoreload_manager.cell_scope(),
759729
):
760-
modules = None
761730
try:
762-
if self.module_reloader is not None:
763-
# Reload modules if they have changed
764-
modules = set(sys.modules)
765-
self.module_reloader.check(
766-
modules=sys.modules, reload=True
767-
)
768731
yield exec_ctx
769732
finally:
770733
ctx.execution_context = None
771-
if self.module_reloader is not None and modules is not None:
772-
# Note timestamps for newly loaded modules
773-
new_modules = set(sys.modules) - modules
774-
self.module_reloader.check(
775-
modules={m: sys.modules[m] for m in new_modules},
776-
reload=False,
777-
)
778734

779735
def _register_cell(
780736
self,
@@ -796,12 +752,7 @@ def _register_cell(
796752
self.graph.cells[cell_id].set_stale(stale=True, broadcast=False)
797753
# leaky abstraction: the graph doesn't know about stale modules, so
798754
# we have to check for them here.
799-
module_reloader = self.module_reloader
800-
if (
801-
module_reloader is not None
802-
and module_reloader.cell_uses_stale_modules(cell)
803-
):
804-
self.graph.set_stale({cell.cell_id}, prune_imports=True)
755+
self.autoreload_manager.flag_if_imports_stale(cell)
805756
LOGGER.debug("registered cell %s", cell_id)
806757
LOGGER.debug("parents: %s", self.graph.parents[cell_id])
807758
LOGGER.debug("children: %s", self.graph.children[cell_id])
@@ -1805,9 +1756,6 @@ async def run_stale_cells(self) -> None:
18051756
)
18061757
)
18071758

1808-
if self.module_watcher is not None:
1809-
self.module_watcher.run_is_processed.set()
1810-
18111759
@kernel_tracer.start_as_current_span("set_cell_config")
18121760
async def set_cell_config(self, request: UpdateCellConfigCommand) -> None:
18131761
"""Update cell configs.

tests/_runtime/reload/test_module_watcher.py

Lines changed: 88 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -944,13 +944,14 @@ async def test_module_watcher_stop(
944944
await asyncio.sleep(INTERVAL)
945945

946946
# Stop the watcher
947-
assert k.module_watcher is not None
948-
assert not k.module_watcher.should_exit.is_set()
947+
watcher = k.autoreload_manager.watcher
948+
assert watcher is not None
949+
assert not watcher.should_exit.is_set()
949950

950-
k.module_watcher.stop()
951+
watcher.stop()
951952

952953
# should_exit should be set
953-
assert k.module_watcher.should_exit.is_set()
954+
assert watcher.should_exit.is_set()
954955

955956
async def test_module_watcher_processes_flag(
956957
self, execution_kernel: Kernel, exec_req: ExecReqProvider
@@ -965,9 +966,10 @@ async def test_module_watcher_processes_flag(
965966
# Give watcher time to start
966967
await asyncio.sleep(INTERVAL)
967968

968-
assert k.module_watcher is not None
969+
watcher = k.autoreload_manager.watcher
970+
assert watcher is not None
969971
# Initially should be set (no run in flight)
970-
assert k.module_watcher.run_is_processed.is_set()
972+
assert watcher.run_is_processed.is_set()
971973

972974

973975
class TestModuleWatcherEdgeCases:
@@ -1075,3 +1077,83 @@ async def test_module_watcher_cache_invalidation(
10751077

10761078
# The cache should handle the modified imports correctly
10771079
assert not k.graph.cells[er_1.cell_id].stale
1080+
1081+
1082+
class TestAutoreloadManagerLifecycle:
1083+
"""Tests for AutoreloadManager responding to runtime.auto_reload changes."""
1084+
1085+
async def test_mode_swap_replaces_watcher(
1086+
self, execution_kernel: Kernel, exec_req: ExecReqProvider
1087+
):
1088+
del exec_req
1089+
k = execution_kernel
1090+
config = copy.deepcopy(DEFAULT_CONFIG)
1091+
1092+
config["runtime"]["auto_reload"] = "lazy"
1093+
k.set_user_config(UpdateUserConfigCommand(config=config))
1094+
lazy_watcher = k.autoreload_manager.watcher
1095+
assert lazy_watcher is not None
1096+
assert lazy_watcher.mode == "lazy"
1097+
1098+
config["runtime"]["auto_reload"] = "autorun"
1099+
k.set_user_config(UpdateUserConfigCommand(config=config))
1100+
autorun_watcher = k.autoreload_manager.watcher
1101+
assert autorun_watcher is not None
1102+
assert autorun_watcher.mode == "autorun"
1103+
assert autorun_watcher is not lazy_watcher
1104+
assert lazy_watcher.should_exit.is_set()
1105+
1106+
async def test_reloader_persists_across_mode_swap(
1107+
self, execution_kernel: Kernel, exec_req: ExecReqProvider
1108+
):
1109+
"""Switching between lazy and autorun must not throw away the
1110+
reloader's mtime state — otherwise every swap would force a reload
1111+
of all already-tracked modules."""
1112+
del exec_req
1113+
k = execution_kernel
1114+
config = copy.deepcopy(DEFAULT_CONFIG)
1115+
1116+
config["runtime"]["auto_reload"] = "lazy"
1117+
k.set_user_config(UpdateUserConfigCommand(config=config))
1118+
reloader = k.autoreload_manager.reloader
1119+
assert reloader is not None
1120+
1121+
config["runtime"]["auto_reload"] = "autorun"
1122+
k.set_user_config(UpdateUserConfigCommand(config=config))
1123+
assert k.autoreload_manager.reloader is reloader
1124+
1125+
async def test_disable_clears_manager_state(
1126+
self, execution_kernel: Kernel, exec_req: ExecReqProvider
1127+
):
1128+
del exec_req
1129+
k = execution_kernel
1130+
config = copy.deepcopy(DEFAULT_CONFIG)
1131+
1132+
config["runtime"]["auto_reload"] = "lazy"
1133+
k.set_user_config(UpdateUserConfigCommand(config=config))
1134+
assert k.autoreload_manager.watcher is not None
1135+
assert k.autoreload_manager.reloader is not None
1136+
1137+
config["runtime"]["auto_reload"] = "off"
1138+
k.set_user_config(UpdateUserConfigCommand(config=config))
1139+
assert k.autoreload_manager.watcher is None
1140+
assert k.autoreload_manager.reloader is None
1141+
1142+
async def test_run_rearms_watcher_run_is_processed(
1143+
self, execution_kernel: Kernel, exec_req: ExecReqProvider
1144+
):
1145+
"""The on_finish hook fires for every kernel run, not just
1146+
`_run_stale_cells`. This is a deliberate semantic change from the
1147+
pre-extraction code, which only set the flag at the end of stale-cell
1148+
runs."""
1149+
k = execution_kernel
1150+
config = copy.deepcopy(DEFAULT_CONFIG)
1151+
config["runtime"]["auto_reload"] = "lazy"
1152+
k.set_user_config(UpdateUserConfigCommand(config=config))
1153+
1154+
watcher = k.autoreload_manager.watcher
1155+
assert watcher is not None
1156+
1157+
watcher.run_is_processed.clear()
1158+
await k.run([exec_req.get("x = 1")])
1159+
assert watcher.run_is_processed.is_set()

0 commit comments

Comments
 (0)