diff --git a/marimo/_messaging/streams.py b/marimo/_messaging/streams.py index 4af3700e4e7..dac6b69f751 100644 --- a/marimo/_messaging/streams.py +++ b/marimo/_messaging/streams.py @@ -68,7 +68,7 @@ def output_max_bytes() -> int: try: return get_context().marimo_config["runtime"]["output_max_bytes"] except ContextNotInitializedError: - return 5_000_000 + return 5_000_000 # 5MB def std_stream_max_bytes() -> int: @@ -77,7 +77,7 @@ def std_stream_max_bytes() -> int: try: return get_context().marimo_config["runtime"]["std_stream_max_bytes"] except ContextNotInitializedError: - return 1_000_000 + return 1_000_000 # 1MB class PipeProtocol(Protocol): @@ -85,7 +85,7 @@ def send(self, obj: KernelMessage) -> None: pass -class QueuePipe: +class QueuePipe(PipeProtocol): def __init__(self, queue: QueueType[KernelMessage]): self._queue = queue diff --git a/marimo/_pyodide/pyodide_session.py b/marimo/_pyodide/pyodide_session.py index 4fb4d776fa4..4c39f169086 100644 --- a/marimo/_pyodide/pyodide_session.py +++ b/marimo/_pyodide/pyodide_session.py @@ -30,18 +30,9 @@ BatchableCommand, CodeCompletionCommand, CommandMessage, - ModelCommand, - UpdateUIElementCommand, UpdateUserConfigCommand, ) -from marimo._runtime.context.kernel_context import initialize_kernel_context -from marimo._runtime.input_override import input_override from marimo._runtime.marimo_pdb import MarimoPdb -from marimo._runtime.runner.hooks import Priority, create_default_hooks -from marimo._runtime.runtime import Kernel -from marimo._runtime.utils.set_ui_element_request_manager import ( - SetUIElementRequestManager, -) from marimo._server.export.exporter import Exporter from marimo._server.files.os_file_system import OSFileSystem from marimo._server.models.export import ExportAsHTMLRequest @@ -446,94 +437,50 @@ def _launch_pyodide_kernel( user_config: MarimoConfig, ) -> RestartableTask: from marimo._output.formatters.formatters import register_formatters + from marimo._runtime.kernel_lifecycle import ( + asyncio_queue_reader, + create_kernel, + listen_messages, + teardown_kernel, + ) register_formatters() - - LOGGER.debug("Launching kernel") + LOGGER.debug("Launching pyodide kernel") # Patches for pyodide compatibility patches.patch_pyodide_networking() - # Some libraries mess with Python's default recursion limit, which becomes # a problem when running with Pyodide. patches.patch_recursion_limit(limit=1000) is_edit_mode = session_mode == SessionMode.EDIT - # Create communication channels stream = PyodideStream(on_message, input_queue) stdout = PyodideStdout(stream) stderr = PyodideStderr(stream) stdin = PyodideStdin(stream) if is_edit_mode else None debugger = MarimoPdb(stdout=stdout, stdin=stdin) if is_edit_mode else None - def _enqueue_control_request(req: CommandMessage) -> None: - control_queue.put_nowait(req) - if isinstance(req, (UpdateUIElementCommand, ModelCommand)): - set_ui_element_queue.put_nowait(req) - - # Create hooks with mode-specific configuration - from marimo._runtime.runner.hooks_post_execution import ( - attempt_pytest, - broadcast_storage_backends, - render_toplevel_defs, - ) - - hooks = create_default_hooks() - if is_edit_mode and user_config["runtime"].get("reactive_tests", False): - hooks.add_post_execution(attempt_pytest, Priority.LATE) - if is_edit_mode: - hooks.add_post_execution(render_toplevel_defs, Priority.LATE) - hooks.add_post_execution(broadcast_storage_backends, Priority.LATE) - - kernel = Kernel( - cell_configs=configs, - app_metadata=app_metadata, + kernel, ctx = create_kernel( stream=stream, stdout=stdout, stderr=stderr, stdin=stdin, - module=patches.patch_main_module( - file=app_metadata.filename, - input_override=input_override, - print_override=None, - doc=app_metadata.docstring, - ), - enqueue_control_request=_enqueue_control_request, - debugger_override=debugger, + debugger=debugger, + configs=configs, + app_metadata=app_metadata, user_config=user_config, - hooks=hooks, - ) - ctx = initialize_kernel_context( - kernel=kernel, - stream=stream, - stdout=stdout, - stderr=stderr, + is_edit_mode=is_edit_mode, + control_queue=control_queue, + set_ui_element_queue=set_ui_element_queue, virtual_file_storage=None, mode=session_mode, + print_override_fn=None, ) if is_edit_mode: signal.signal(signal.SIGINT, handlers.construct_interrupt_handler(ctx)) - ui_element_request_mgr = SetUIElementRequestManager(set_ui_element_queue) - - async def listen_messages() -> None: - while True: - request: CommandMessage | None = await control_queue.get() - LOGGER.debug("received request %s", request) - if isinstance( - request, - (commands.UpdateUIElementCommand, commands.ModelCommand), - ): - merged = ui_element_request_mgr.process_request(request) - for r in merged: - await kernel.handle_message(r) - continue - - if request is not None: - await kernel.handle_message(request) - async def listen_completion() -> None: while True: request = await completion_queue.get() @@ -547,6 +494,17 @@ async def listen_completion() -> None: kernel.code_completion(request, docstrings_limit=5) async def listen() -> None: - await asyncio.gather(listen_messages(), listen_completion()) + try: + await asyncio.gather( + listen_messages( + kernel, + control_queue, + set_ui_element_queue, + asyncio_queue_reader, + ), + listen_completion(), + ) + finally: + teardown_kernel(kernel, ctx) return RestartableTask(listen) diff --git a/marimo/_runtime/kernel_lifecycle.py b/marimo/_runtime/kernel_lifecycle.py new file mode 100644 index 00000000000..dfd47258715 --- /dev/null +++ b/marimo/_runtime/kernel_lifecycle.py @@ -0,0 +1,206 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Kernel-startup primitives that don't depend on the hosting environment. + +Environment-specific concerns (stream construction, signal handlers, +subprocess bootstrap, the outer task driver) stay at the call site. +""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Any, TypeVar + +from marimo import _loggers +from marimo._runtime import patches +from marimo._runtime.commands import ( + ModelCommand, + StopKernelCommand, + UpdateUIElementCommand, +) +from marimo._runtime.context.kernel_context import ( + KernelRuntimeContext, + initialize_kernel_context, +) +from marimo._runtime.context.types import teardown_context +from marimo._runtime.input_override import input_override +from marimo._runtime.runner.hooks import ( + NotebookCellHooks, + Priority, + create_default_hooks, +) +from marimo._runtime.utils.set_ui_element_request_manager import ( + SetUIElementRequestManager, +) + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from marimo._ast.cell import CellConfig + from marimo._config.config import MarimoConfig + from marimo._messaging.types import Stderr, Stdin, Stdout, Stream + from marimo._runtime import marimo_pdb + from marimo._runtime.commands import ( + AppMetadata, + BatchableCommand, + CommandMessage, + ) + from marimo._runtime.runtime import Kernel + from marimo._runtime.virtual_file import VirtualFileStorageType + from marimo._session.model import SessionMode + from marimo._session.queue import QueueType + from marimo._types.ids import CellId_t + + ControlQueue = QueueType[CommandMessage] | asyncio.Queue[CommandMessage] + UIElementQueue = ( + QueueType[BatchableCommand] | asyncio.Queue[BatchableCommand] + ) + +LOGGER = _loggers.marimo_logger() + +# Lets each caller pin listen_messages and its reader to the same queue type +# (threading vs asyncio). +_Q = TypeVar("_Q") + + +def _build_hooks( + is_edit_mode: bool, user_config: MarimoConfig +) -> NotebookCellHooks: + from marimo._runtime.runner.hooks_post_execution import ( + attempt_pytest, + broadcast_storage_backends, + render_toplevel_defs, + ) + + hooks = create_default_hooks() + if is_edit_mode and user_config["runtime"].get("reactive_tests", False): + hooks.add_post_execution(attempt_pytest, Priority.LATE) + if is_edit_mode: + hooks.add_post_execution(render_toplevel_defs, Priority.LATE) + hooks.add_post_execution(broadcast_storage_backends, Priority.LATE) + return hooks + + +def create_kernel( + *, + stream: Stream, + stdout: Stdout | None, + stderr: Stderr | None, + stdin: Stdin | None, + debugger: marimo_pdb.MarimoPdb | None, + configs: dict[CellId_t, CellConfig], + app_metadata: AppMetadata, + user_config: MarimoConfig, + is_edit_mode: bool, + control_queue: ControlQueue, + set_ui_element_queue: UIElementQueue, + virtual_file_storage: VirtualFileStorageType | None, + mode: SessionMode, + print_override_fn: Callable[[Any], None] | None, +) -> tuple[Kernel, KernelRuntimeContext]: + # Run mode forces autorun and disables the module autoreloader. + if not is_edit_mode: + user_config = user_config.copy() + user_config["runtime"]["on_cell_change"] = "autorun" + user_config["runtime"]["auto_reload"] = "off" + + def _enqueue_control_request(req: CommandMessage) -> None: + control_queue.put_nowait(req) + if isinstance(req, (UpdateUIElementCommand, ModelCommand)): + set_ui_element_queue.put_nowait(req) + + # Deferred to break the runtime.py <-> kernel_lifecycle.py import cycle. + from marimo._runtime.runtime import Kernel + + kernel = Kernel( + cell_configs=configs, + app_metadata=app_metadata, + stream=stream, + stdout=stdout, + stderr=stderr, + stdin=stdin, + module=patches.patch_main_module( + file=app_metadata.filename, + input_override=input_override, + print_override=print_override_fn, + doc=app_metadata.docstring, + ), + debugger_override=debugger, + user_config=user_config, + enqueue_control_request=_enqueue_control_request, + hooks=_build_hooks(is_edit_mode, user_config), + ) + ctx = initialize_kernel_context( + kernel=kernel, + stream=stream, + stdout=stdout, + stderr=stderr, + virtual_file_storage=virtual_file_storage, + mode=mode, + ) + return kernel, ctx + + +async def threaded_queue_reader( + queue: QueueType[CommandMessage], +) -> CommandMessage | None: + # Offload the blocking get() so background asyncio tasks aren't starved. + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, queue.get) + + +async def asyncio_queue_reader( + queue: asyncio.Queue[CommandMessage], +) -> CommandMessage | None: + return await queue.get() + + +async def listen_messages( + kernel: Kernel, + control_queue: _Q, + set_ui_element_queue: UIElementQueue, + get_request: Callable[[_Q], Awaitable[CommandMessage | None]], +) -> None: + """Run the kernel's control loop until `StopKernelCommand` is received. + + `get_request` adapts the queue-read mechanism so this loop can drive + either a threading/multiprocessing queue or an `asyncio.Queue`. + """ + ui_request_mgr = SetUIElementRequestManager(set_ui_element_queue) + + while True: + try: + request = await get_request(control_queue) + except Exception as e: + # triggered on Windows when quit with Ctrl+C + LOGGER.debug("kernel queue.get() failed %s", e) + return + + if request is None: + continue + LOGGER.debug("Received control request: %s", type(request).__name__) + if isinstance(request, StopKernelCommand): + return + + merged: list[CommandMessage] + if isinstance(request, (UpdateUIElementCommand, ModelCommand)): + merged = list(ui_request_mgr.process_request(request)) + else: + merged = [request] + + for r in merged: + try: + await kernel.handle_message(r) + except Exception: + LOGGER.exception( + "Failed to handle control request: %s", + type(r).__name__, + ) + + +def teardown_kernel(kernel: Kernel, ctx: KernelRuntimeContext) -> None: + # Defensively shut down registries in case a leak prevents context + # destruction from cleaning them up. + ctx.virtual_file_registry.shutdown() + ctx.app_kernel_runner_registry.shutdown() + teardown_context() + kernel.teardown() diff --git a/marimo/_runtime/runtime.py b/marimo/_runtime/runtime.py index 163c3e47192..d7dcb9d873e 100644 --- a/marimo/_runtime/runtime.py +++ b/marimo/_runtime/runtime.py @@ -161,12 +161,10 @@ ) from marimo._runtime.context.kernel_context import ( KernelRuntimeContext, - initialize_kernel_context, ) -from marimo._runtime.context.types import teardown_context from marimo._runtime.context.utils import get_mode from marimo._runtime.control_flow import MarimoInterrupt -from marimo._runtime.input_override import getpass_override, input_override +from marimo._runtime.input_override import getpass_override from marimo._runtime.packages.import_error_extractors import ( extract_missing_module_from_cause_chain, try_extract_packages_from_import_error_message, @@ -192,13 +190,9 @@ from marimo._runtime.runner.hooks import ( NotebookCellHooks, Priority, - create_default_hooks, ) from marimo._runtime.scratch import SCRATCH_CELL_ID from marimo._runtime.state import State -from marimo._runtime.utils.set_ui_element_request_manager import ( - SetUIElementRequestManager, -) from marimo._runtime.virtual_file.virtual_file import VirtualFile from marimo._runtime.win32_interrupt_handler import Win32InterruptHandler from marimo._secrets.load_dotenv import ( @@ -3540,68 +3534,87 @@ async def handle(self, request: CommandMessage) -> None: raise ValueError(f"Unknown request {request}") -def launch_kernel( - control_queue: QueueType[CommandMessage], - set_ui_element_queue: QueueType[BatchableCommand], - completion_queue: QueueType[CodeCompletionCommand], - input_queue: QueueType[str], - stream_queue: QueueType[KernelMessage] | None, - socket_addr: tuple[str, int] | None, - is_edit_mode: bool, - configs: dict[CellId_t, CellConfig], - app_metadata: AppMetadata, - user_config: MarimoConfig, - virtual_file_storage: VirtualFileStorageType | None, - redirect_console_to_browser: bool, - interrupt_queue: QueueType[bool] | None = None, - profile_path: str | None = None, - log_level: int | None = None, - is_ipc: bool = False, - parent_pid: int | None = None, -) -> None: +@dataclasses.dataclass +class _KernelStreams: + stream: ThreadSafeStream + stdout: ThreadSafeStdout | None + stderr: ThreadSafeStderr | None + stdin: ThreadSafeStdin | None + debugger: marimo_pdb.MarimoPdb | None + pipe: TypedConnection[KernelMessage] | None + + def close(self, use_fd_redirect: bool) -> None: + if not use_fd_redirect: + from marimo._messaging.thread_local_streams import ( + clear_thread_local_streams, + ) + + clear_thread_local_streams() + + if isinstance(self.pipe, connection.Connection): + self.pipe.close() + + +def _bootstrap_subprocess( + parent_pid: int | None, + log_level: int | None, + is_subprocess: bool, +) -> Callable[[], asyncio.AbstractEventLoop] | None: + # Returns a loop factory only on Windows 3.14+; elsewhere either mutates + # the loop policy or does nothing. if log_level is not None: _loggers.set_level(log_level) - LOGGER.debug("Launching kernel") - is_subprocess = is_edit_mode or is_ipc - loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None - if is_subprocess: - restore_signals() - - # Become the leader of a new session/process group before connecting - # back to the parent, to avoid race conditions with the parent - # process (which assumes its child is in another process group). - if sys.platform != "win32": - os.setsid() - start_parent_poller(parent_pid) - - # The runtime process inherits the server's loop policy. On Windows, we - # restore the event loop policy to the default ProactorEventLoop, so - # user code can use asyncio.create_subprocess_exec and other APIs that - # the SelectorEventLoop does not implement. - if sys.platform == "win32": - if sys.version_info >= (3, 14): - # Event loop policies are deprecated in Python 3.14 - loop_factory = asyncio.ProactorEventLoop - else: - asyncio.set_event_loop_policy( - asyncio.WindowsProactorEventLoopPolicy() - ) + if not is_subprocess: + return None + + restore_signals() + + # Become the leader of a new session/process group before connecting + # back to the parent, to avoid race conditions with the parent + # process (which assumes its child is in another process group). + if sys.platform != "win32": + os.setsid() + start_parent_poller(parent_pid) + + # The runtime process inherits the server's loop policy. On Windows, we + # restore the event loop policy to the default ProactorEventLoop, so + # user code can use asyncio.create_subprocess_exec and other APIs that + # the SelectorEventLoop does not implement. + if sys.platform == "win32": + if sys.version_info >= (3, 14): + # Event loop policies are deprecated in Python 3.14 + return asyncio.ProactorEventLoop + asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) + return None - profiler = None - if profile_path is not None: - import cProfile - profiler = cProfile.Profile() - profiler.enable() +@contextlib.contextmanager +def _maybe_profile(profile_path: str | None) -> Iterator[None]: + if profile_path is None: + yield + return - should_redirect_stdio = is_edit_mode or redirect_console_to_browser - # Only use os.dup2-based fd redirection in process-based modes - # (edit mode / IPC). Thread-based run mode uses the lighter-weight - # thread-local proxy instead to avoid process-global fd mutations. - use_fd_redirect = is_subprocess + import cProfile - # Create communication channels + profiler = cProfile.Profile() + profiler.enable() + try: + yield + finally: + profiler.disable() + profiler.dump_stats(profile_path) + + +def _create_streams( + socket_addr: tuple[str, int] | None, + stream_queue: QueueType[KernelMessage] | None, + input_queue: QueueType[str], + is_edit_mode: bool, + should_redirect_stdio: bool, + use_fd_redirect: bool, +) -> _KernelStreams | None: + # Returns None when the socket fails to connect; callers should bail out. pipe: TypedConnection[KernelMessage] | None = None if socket_addr is not None: n_tries = 0 @@ -3629,7 +3642,7 @@ def launch_kernel( n_tries, exc_info=last_error, ) - return + return None stream = ThreadSafeStream( pipe=pipe, @@ -3665,160 +3678,134 @@ def launch_kernel( if is_edit_mode and not bool(os.getenv("DEBUGPY_RUNNING")) else None ) - - # In run mode, the kernel should always be in autorun, and the module - # autoreloader is disabled - if not is_edit_mode: - user_config = user_config.copy() - user_config["runtime"]["on_cell_change"] = "autorun" - user_config["runtime"]["auto_reload"] = "off" - - def _enqueue_control_request(req: CommandMessage) -> None: - control_queue.put_nowait(req) - if isinstance(req, (UpdateUIElementCommand, ModelCommand)): - set_ui_element_queue.put_nowait(req) - - # Create hooks with mode-specific configuration - from marimo._runtime.runner.hooks_post_execution import ( - attempt_pytest, - broadcast_storage_backends, - render_toplevel_defs, - ) - - hooks = create_default_hooks() - if is_edit_mode and user_config["runtime"].get("reactive_tests", False): - hooks.add_post_execution(attempt_pytest, Priority.LATE) - if is_edit_mode: - hooks.add_post_execution(render_toplevel_defs, Priority.LATE) - hooks.add_post_execution(broadcast_storage_backends, Priority.LATE) - - kernel = Kernel( - cell_configs=configs, - app_metadata=app_metadata, + return _KernelStreams( stream=stream, stdout=stdout, stderr=stderr, stdin=stdin, - module=patches.patch_main_module( - file=app_metadata.filename, - input_override=input_override, - print_override=print_override, - doc=app_metadata.docstring, - ), - debugger_override=debugger, - user_config=user_config, - enqueue_control_request=_enqueue_control_request, - hooks=hooks, - ) - ctx = initialize_kernel_context( - kernel=kernel, - stream=stream, - stdout=stdout, - stderr=stderr, - virtual_file_storage=virtual_file_storage, - mode=SessionMode.EDIT if is_edit_mode else SessionMode.RUN, + debugger=debugger, + pipe=pipe, ) - if is_edit_mode: - # completions only provided in edit mode - kernel.start_completion_worker(completion_queue) - - if is_subprocess: - # Subprocess kernels (EDIT and IPC_RUN) can receive signals and need - # their own formatter registration since they don't share state with - # the host process. - # - # Each subprocess kernel needs to install the formatter import hooks - from marimo._output.formatters.formatters import register_formatters - - register_formatters(theme=user_config["display"]["theme"]) - signal.signal(signal.SIGINT, handlers.construct_interrupt_handler(ctx)) - - if sys.platform == "win32": - if interrupt_queue is not None: - Win32InterruptHandler(interrupt_queue).start() - # windows doesn't handle SIGTERM - signal.signal( - signal.SIGBREAK, handlers.construct_sigterm_handler(kernel) - ) - else: - signal.signal( - signal.SIGTERM, handlers.construct_sigterm_handler(kernel) - ) +def _install_subprocess_handlers( + kernel: Kernel, + ctx: KernelRuntimeContext, + user_config: MarimoConfig, + interrupt_queue: QueueType[bool] | None, +) -> None: + # Subprocess kernels don't share state with the host, so they need + # their own formatter import hooks and signal handlers. + from marimo._output.formatters.formatters import register_formatters - ui_element_request_mgr = SetUIElementRequestManager(set_ui_element_queue) + register_formatters(theme=user_config["display"]["theme"]) - async def control_loop(kernel: Kernel) -> None: - loop = asyncio.get_running_loop() + signal.signal(signal.SIGINT, handlers.construct_interrupt_handler(ctx)) - while True: - try: - # Offload the blocking queue.get() to a thread so the event - # loop stays free to service background asyncio tasks (e.g. - # user-created tasks via create_task / ensure_future). - request: CommandMessage | None = await loop.run_in_executor( - None, - control_queue.get, - ) - except Exception as e: - # triggered on Windows when quit with Ctrl+C - LOGGER.debug("kernel queue.get() failed %s", e) - break - LOGGER.debug( - "Received control request: %s", type(request).__name__ - ) - if isinstance(request, StopKernelCommand): - break - elif isinstance(request, (UpdateUIElementCommand, ModelCommand)): - # Drain the shared queue and merge pending requests: - # - UI element updates: last-write-wins per element ID - # - Model commands: last-write-wins per model ID - merged = ui_element_request_mgr.process_request(request) - for r in merged: - await kernel.handle_message(r) - continue + if sys.platform == "win32": + if interrupt_queue is not None: + Win32InterruptHandler(interrupt_queue).start() + # windows doesn't handle SIGTERM + signal.signal( + signal.SIGBREAK, handlers.construct_sigterm_handler(kernel) + ) + else: + signal.signal( + signal.SIGTERM, handlers.construct_sigterm_handler(kernel) + ) - if request is not None: - try: - await kernel.handle_message(request) - except Exception: - LOGGER.exception( - "Failed to handle control request: %s", - type(request).__name__, - ) - # The control loop is asynchronous so that (a) user code can use - # top-level await, and (b) background asyncio tasks created by user code - # (via create_task / ensure_future) are not starved by a blocking - # queue.get(). The queue read is offloaded to a thread via - # run_in_executor; avoid adding further async primitives elsewhere in the - # runtime unless there is a very good reason. - if loop_factory is not None: - asyncio.run(control_loop(kernel), loop_factory=loop_factory) - else: - asyncio.run(control_loop(kernel)) +def launch_kernel( + control_queue: QueueType[CommandMessage], + set_ui_element_queue: QueueType[BatchableCommand], + completion_queue: QueueType[CodeCompletionCommand], + input_queue: QueueType[str], + stream_queue: QueueType[KernelMessage] | None, + socket_addr: tuple[str, int] | None, + is_edit_mode: bool, + configs: dict[CellId_t, CellConfig], + app_metadata: AppMetadata, + user_config: MarimoConfig, + virtual_file_storage: VirtualFileStorageType | None, + redirect_console_to_browser: bool, + interrupt_queue: QueueType[bool] | None = None, + profile_path: str | None = None, + log_level: int | None = None, + is_ipc: bool = False, + parent_pid: int | None = None, +) -> None: + from marimo._runtime.kernel_lifecycle import ( + create_kernel, + listen_messages, + teardown_kernel, + threaded_queue_reader, + ) - if not use_fd_redirect: - from marimo._messaging.thread_local_streams import ( - clear_thread_local_streams, + LOGGER.debug("Launching kernel") + is_subprocess = is_edit_mode or is_ipc + loop_factory = _bootstrap_subprocess(parent_pid, log_level, is_subprocess) + + with _maybe_profile(profile_path): + should_redirect_stdio = is_edit_mode or redirect_console_to_browser + # Only use os.dup2-based fd redirection in process-based modes + # (edit mode / IPC). Thread-based run mode uses the lighter-weight + # thread-local proxy instead to avoid process-global fd mutations. + use_fd_redirect = is_subprocess + streams = _create_streams( + socket_addr, + stream_queue, + input_queue, + is_edit_mode, + should_redirect_stdio, + use_fd_redirect, ) + if streams is None: + return - clear_thread_local_streams() + kernel, ctx = create_kernel( + stream=streams.stream, + stdout=streams.stdout, + stderr=streams.stderr, + stdin=streams.stdin, + debugger=streams.debugger, + configs=configs, + app_metadata=app_metadata, + user_config=user_config, + is_edit_mode=is_edit_mode, + control_queue=control_queue, + set_ui_element_queue=set_ui_element_queue, + virtual_file_storage=virtual_file_storage, + mode=SessionMode.EDIT if is_edit_mode else SessionMode.RUN, + print_override_fn=print_override, + ) - if profiler is not None and profile_path is not None: - profiler.disable() - profiler.dump_stats(profile_path) + if is_edit_mode: + # completions only provided in edit mode + kernel.start_completion_worker(completion_queue) + + if is_subprocess: + # Read theme from kernel.user_config — create_kernel may have + # mutated it for run mode (autorun + auto_reload off). + _install_subprocess_handlers( + kernel, ctx, kernel.user_config, interrupt_queue + ) + + # The control loop is asynchronous so that (a) user code can use + # top-level await, and (b) background asyncio tasks created by user + # code (via create_task / ensure_future) are not starved by a + # blocking queue.get(). The queue read is offloaded to a thread via + # run_in_executor; avoid adding further async primitives elsewhere + # in the runtime unless there is a very good reason. + coro = listen_messages( + kernel, + control_queue, + set_ui_element_queue, + threaded_queue_reader, + ) + if loop_factory is not None: + asyncio.run(coro, loop_factory=loop_factory) + else: + asyncio.run(coro) - # Defensively clear context data structures, in case a leak prevents - # the context from being destroyed. - # - # TODO(akshayka): define ownership semantics for contexts, so the - # context knows how to shut itself down. The virtual file registry - # is shared between the main thread and mo.Thread's right now ... - get_context().virtual_file_registry.shutdown() - get_context().app_kernel_runner_registry.shutdown() - teardown_context() - kernel.teardown() - if isinstance(pipe, connection.Connection): - pipe.close() + teardown_kernel(kernel, ctx) + streams.close(use_fd_redirect) diff --git a/tests/_pyodide/test_pyodide_session.py b/tests/_pyodide/test_pyodide_session.py index 07ab96fb9a3..7f5ec0197db 100644 --- a/tests/_pyodide/test_pyodide_session.py +++ b/tests/_pyodide/test_pyodide_session.py @@ -6,7 +6,7 @@ import unittest.mock from textwrap import dedent from typing import TYPE_CHECKING, Any -from unittest.mock import MagicMock, Mock +from unittest.mock import MagicMock, Mock, patch import msgspec import pytest @@ -19,6 +19,7 @@ AsyncQueueManager, PyodideBridge, PyodideSession, + _launch_pyodide_kernel, parse_command, ) from marimo._runtime.commands import ( @@ -172,6 +173,68 @@ async def test_pyodide_session_start( pass +async def test_pyodide_kernel_teardown_runs_on_stop( + pyodide_app_file: Path, +) -> None: + """Stopping the kernel task must trigger teardown_kernel via the listen() + finally block (previously absent for the pyodide path).""" + fake_kernel = MagicMock() + fake_ctx = MagicMock() + teardown_calls: list[tuple[Any, Any]] = [] + + async def block_until_cancelled(*_args: Any, **_kwargs: Any) -> None: + await asyncio.Event().wait() + + with ( + patch( + "marimo._runtime.kernel_lifecycle.create_kernel", + return_value=(fake_kernel, fake_ctx), + ), + patch( + "marimo._runtime.kernel_lifecycle.listen_messages", + side_effect=block_until_cancelled, + ), + patch( + "marimo._runtime.kernel_lifecycle.teardown_kernel", + side_effect=lambda k, c: teardown_calls.append((k, c)), + ), + patch("marimo._pyodide.pyodide_session.signal"), + patch("marimo._output.formatters.formatters.register_formatters"), + patch( + "marimo._pyodide.pyodide_session.patches.patch_pyodide_networking" + ), + patch("marimo._pyodide.pyodide_session.patches.patch_recursion_limit"), + ): + kernel_task = _launch_pyodide_kernel( + control_queue=asyncio.Queue(), + set_ui_element_queue=asyncio.Queue(), + completion_queue=asyncio.Queue(), + input_queue=asyncio.Queue(), + on_message=lambda _msg: None, + session_mode=SessionMode.EDIT, + configs={}, + app_metadata=AppMetadata( + query_params={}, + cli_args={}, + app_config=_AppConfig(), + filename=str(pyodide_app_file), + ), + user_config=DEFAULT_CONFIG, + ) + start_task = asyncio.create_task(kernel_task.start()) + # Yield enough times for: outer task → RestartableTask.start → inner + # task creation → listen() → asyncio.gather → child tasks suspended. + for _ in range(5): + await asyncio.sleep(0) + kernel_task.stop() + try: + await start_task + except asyncio.CancelledError: + pass + + assert teardown_calls == [(fake_kernel, fake_ctx)] + + async def test_pyodide_session_put_control_request( pyodide_session: PyodideSession, ) -> None: diff --git a/tests/_runtime/test_kernel_lifecycle.py b/tests/_runtime/test_kernel_lifecycle.py new file mode 100644 index 00000000000..fa00dfa174d --- /dev/null +++ b/tests/_runtime/test_kernel_lifecycle.py @@ -0,0 +1,153 @@ +# Copyright 2026 Marimo. All rights reserved. +from __future__ import annotations + +import asyncio +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from marimo._runtime.commands import ( + ExecuteCellsCommand, + StopKernelCommand, + UpdateUIElementCommand, +) +from marimo._runtime.kernel_lifecycle import ( + asyncio_queue_reader, + listen_messages, +) +from marimo._types.ids import CellId_t, UIElementId + + +@pytest.fixture +def kernel() -> Any: + k = MagicMock() + k.handle_message = AsyncMock() + return k + + +@pytest.fixture +def control() -> asyncio.Queue[Any]: + return asyncio.Queue() + + +@pytest.fixture +def ui() -> asyncio.Queue[Any]: + return asyncio.Queue() + + +def _execute(cell_id: str = "c1") -> ExecuteCellsCommand: + return ExecuteCellsCommand(cell_ids=[CellId_t(cell_id)], codes=["x = 1"]) + + +def _ui_update( + elem_id: str = "u1", value: Any = None +) -> UpdateUIElementCommand: + return UpdateUIElementCommand( + object_ids=[UIElementId(elem_id)], values=[value] + ) + + +async def test_listen_messages_exits_on_stop_command( + kernel: Any, + control: asyncio.Queue[Any], + ui: asyncio.Queue[Any], +) -> None: + cmd = _execute() + control.put_nowait(cmd) + control.put_nowait(StopKernelCommand()) + # A request enqueued *after* StopKernel must not be dispatched. + control.put_nowait(_execute("after-stop")) + + await listen_messages(kernel, control, ui, asyncio_queue_reader) + + assert kernel.handle_message.await_count == 1 + assert kernel.handle_message.await_args.args == (cmd,) + + +async def test_listen_messages_skips_none_requests( + kernel: Any, + control: asyncio.Queue[Any], + ui: asyncio.Queue[Any], +) -> None: + control.put_nowait(None) + cmd = _execute() + control.put_nowait(cmd) + control.put_nowait(StopKernelCommand()) + + await listen_messages(kernel, control, ui, asyncio_queue_reader) + + assert kernel.handle_message.await_count == 1 + assert kernel.handle_message.await_args.args == (cmd,) + + +async def test_listen_messages_swallows_handle_message_exception_non_ui( + kernel: Any, + control: asyncio.Queue[Any], + ui: asyncio.Queue[Any], +) -> None: + kernel.handle_message.side_effect = [RuntimeError("boom"), None] + control.put_nowait(_execute("first")) + control.put_nowait(_execute("second")) + control.put_nowait(StopKernelCommand()) + + await listen_messages(kernel, control, ui, asyncio_queue_reader) + + # Second dispatch proves the loop survived the first raise. + assert kernel.handle_message.await_count == 2 + + +async def test_listen_messages_swallows_handle_message_exception_ui_branch( + kernel: Any, + control: asyncio.Queue[Any], + ui: asyncio.Queue[Any], +) -> None: + """Regression test: the UI-merge branch used to let `handle_message` + exceptions propagate while the non-UI branch caught them.""" + kernel.handle_message.side_effect = [RuntimeError("boom"), None] + ui_cmd = _ui_update() + control.put_nowait(ui_cmd) + ui.put_nowait(ui_cmd) + control.put_nowait(_execute("after-ui")) + control.put_nowait(StopKernelCommand()) + + await listen_messages(kernel, control, ui, asyncio_queue_reader) + + assert kernel.handle_message.await_count == 2 + + +async def test_listen_messages_exits_when_reader_raises( + kernel: Any, + control: asyncio.Queue[Any], + ui: asyncio.Queue[Any], +) -> None: + async def failing_reader(_queue: Any) -> Any: + raise OSError("queue closed") + + await listen_messages(kernel, control, ui, failing_reader) + + kernel.handle_message.assert_not_called() + + +async def test_listen_messages_merges_ui_updates( + kernel: Any, + control: asyncio.Queue[Any], + ui: asyncio.Queue[Any], +) -> None: + """Contiguous UI updates against the same element collapse to one dispatch.""" + first = _ui_update("u", 1) + second = _ui_update("u", 2) + # Enqueue both on control + ui (matches what _enqueue_control_request does). + control.put_nowait(first) + ui.put_nowait(first) + control.put_nowait(second) + ui.put_nowait(second) + control.put_nowait(StopKernelCommand()) + + await listen_messages(kernel, control, ui, asyncio_queue_reader) + + # Both UI updates merge into a single batched dispatch. + assert kernel.handle_message.await_count == 1 + dispatched = kernel.handle_message.await_args.args[0] + assert isinstance(dispatched, UpdateUIElementCommand) + assert dispatched.values == [2] diff --git a/tests/_runtime/test_runtime.py b/tests/_runtime/test_runtime.py index 2e419a5db91..6dd9e752303 100644 --- a/tests/_runtime/test_runtime.py +++ b/tests/_runtime/test_runtime.py @@ -4207,7 +4207,7 @@ class TestLaunchKernelEventLoop: "marimo._runtime.runtime.ThreadSafeStdin", "marimo._runtime.runtime.marimo_pdb.MarimoPdb", "marimo._runtime.runtime.Kernel", - "marimo._runtime.runtime.initialize_kernel_context", + "marimo._runtime.kernel_lifecycle.initialize_kernel_context", "marimo._runtime.runtime.patches.patch_main_module", "marimo._output.formatters.formatters.register_formatters", ]