Skip to content

Commit 5d1923c

Browse files
authored
refactor: DRY up code between wasm and native kernel (#9591)
1 parent 1b022c5 commit 5d1923c

5 files changed

Lines changed: 133 additions & 75 deletions

File tree

marimo/_pyodide/pyodide_session.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -116,13 +116,19 @@ def __init__(
116116
user_config: MarimoConfig,
117117
) -> None:
118118
"""Initialize kernel and client connection to it."""
119+
from marimo._runtime.kernel_lifecycle import make_control_enqueuer
120+
119121
self.app_manager = app
120122
self.mode = mode
121123
self.app_metadata = app_metadata
122124
self._queue_manager = AsyncQueueManager()
123125
self.session_consumer = on_write
124126
self.session_view = SessionView()
125127
self._initial_user_config = user_config
128+
self._enqueue_control_request = make_control_enqueuer(
129+
self._queue_manager.control_queue,
130+
self._queue_manager.set_ui_element_queue,
131+
)
126132

127133
self.consumers: list[Callable[[KernelMessage], None]] = [
128134
lambda msg: self.session_consumer(msg),
@@ -148,12 +154,7 @@ async def start(self) -> None:
148154
await self.kernel_task.start()
149155

150156
def put_control_request(self, request: commands.CommandMessage) -> None:
151-
self._queue_manager.control_queue.put_nowait(request)
152-
if isinstance(
153-
request,
154-
(commands.UpdateUIElementCommand, commands.ModelCommand),
155-
):
156-
self._queue_manager.set_ui_element_queue.put_nowait(request)
157+
self._enqueue_control_request(request)
157158

158159
def put_completion_request(
159160
self, request: commands.CodeCompletionCommand
@@ -442,6 +443,7 @@ def _launch_pyodide_kernel(
442443
KernelArgs,
443444
asyncio_queue_reader,
444445
create_kernel,
446+
drain_stale,
445447
listen_messages,
446448
teardown_kernel,
447449
)
@@ -485,14 +487,9 @@ def _launch_pyodide_kernel(
485487

486488
async def listen_completion() -> None:
487489
while True:
488-
request = await completion_queue.get()
489-
while not completion_queue.empty():
490-
# discard stale requests to avoid choking the runtime
491-
request = await completion_queue.get()
492-
LOGGER.debug("received completion request %s", request)
493-
# 5 is arbitrary, but is a good limit:
494-
# too high will cause long load times
495-
# too low can be not as useful
490+
request = drain_stale(
491+
completion_queue, latest=await completion_queue.get()
492+
)
496493
kernel.code_completion(request, docstrings_limit=5)
497494

498495
async def listen() -> None:

marimo/_runtime/complete.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from marimo._output.md import _md
2424
from marimo._runtime import dataflow
2525
from marimo._runtime.commands import CodeCompletionCommand
26-
from marimo._session.queue import QueueType
2726
from marimo._types.ids import RequestId
2827
from marimo._utils.docs import MarimoConverter
2928
from marimo._utils.format_signature import format_signature
@@ -360,17 +359,6 @@ def _write_no_completions(stream: Stream, completion_id: RequestId) -> None:
360359
_write_completion_result(stream, completion_id, 0, [])
361360

362361

363-
def _drain_queue(
364-
completion_queue: QueueType[CodeCompletionCommand],
365-
) -> CodeCompletionCommand:
366-
"""Drain the queue of completion requests, returning the most recent one"""
367-
368-
request = completion_queue.get()
369-
while not completion_queue.empty():
370-
request = completion_queue.get()
371-
return request
372-
373-
374362
def _get_completions_with_script(
375363
codes: list[str], document: str
376364
) -> tuple[jedi.Script, list[jedi.api.classes.Completion]]:
@@ -739,32 +727,3 @@ def complete(
739727
pass
740728
else:
741729
LOGGER.debug("Completion worker released globals lock.")
742-
743-
744-
def completion_worker(
745-
completion_queue: QueueType[CodeCompletionCommand],
746-
graph: dataflow.DirectedGraph,
747-
glbls: dict[str, Any],
748-
glbls_lock: threading.RLock,
749-
stream: Stream,
750-
) -> None:
751-
"""Code completion worker.
752-
753-
754-
Args:
755-
completion_queue: queue from which requests are pulled.
756-
graph: dataflow graph backing the marimo program
757-
glbls: dictionary of global variables in interpreter memory
758-
glbls_lock: lock protecting globals
759-
stream: stream used to communicate completion results
760-
"""
761-
762-
while True:
763-
request = _drain_queue(completion_queue)
764-
complete(
765-
request=request,
766-
graph=graph,
767-
glbls=glbls,
768-
glbls_lock=glbls_lock,
769-
stream=stream,
770-
)

marimo/_runtime/kernel_lifecycle.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import asyncio
1111
import contextlib
1212
import dataclasses
13+
import queue as _queue
1314
from typing import TYPE_CHECKING, Any, TypeVar
1415

1516
from marimo import _loggers
@@ -83,6 +84,20 @@ def is_edit_mode(self) -> bool:
8384
# Lets each caller pin listen_messages and its reader to the same queue type
8485
# (threading vs asyncio).
8586
_Q = TypeVar("_Q")
87+
_T = TypeVar("_T")
88+
89+
90+
def drain_stale(queue: Any, *, latest: _T) -> _T:
91+
"""Discard stale items queued behind ``latest`` and return the newest.
92+
93+
Drains via ``get_nowait()`` until exhausted; ``empty()`` is intentionally
94+
avoided because ``multiprocessing.Queue.empty()`` can lie.
95+
"""
96+
while True:
97+
try:
98+
latest = queue.get_nowait()
99+
except (asyncio.QueueEmpty, _queue.Empty):
100+
return latest
86101

87102

88103
def _build_hooks(
@@ -103,6 +118,21 @@ def _build_hooks(
103118
return hooks
104119

105120

121+
def make_control_enqueuer(
122+
control_queue: ControlQueue,
123+
set_ui_element_queue: UIElementQueue,
124+
) -> Callable[[CommandMessage], None]:
125+
"""Build a callable that routes control requests, mirroring UI-element
126+
commands onto the batching queue."""
127+
128+
def enqueue(req: CommandMessage) -> None:
129+
control_queue.put_nowait(req)
130+
if isinstance(req, (UpdateUIElementCommand, ModelCommand)):
131+
set_ui_element_queue.put_nowait(req)
132+
133+
return enqueue
134+
135+
106136
def create_kernel(
107137
args: KernelArgs,
108138
) -> tuple[Kernel, KernelRuntimeContext]:
@@ -113,11 +143,6 @@ def create_kernel(
113143
user_config["runtime"]["on_cell_change"] = "autorun"
114144
user_config["runtime"]["auto_reload"] = "off"
115145

116-
def _enqueue_control_request(req: CommandMessage) -> None:
117-
args.control_queue.put_nowait(req)
118-
if isinstance(req, (UpdateUIElementCommand, ModelCommand)):
119-
args.set_ui_element_queue.put_nowait(req)
120-
121146
# Deferred to break the runtime.py <-> kernel_lifecycle.py import cycle.
122147
from marimo._runtime.runtime import Kernel
123148

@@ -133,7 +158,10 @@ def _enqueue_control_request(req: CommandMessage) -> None:
133158
),
134159
debugger_override=args.debugger,
135160
user_config=user_config,
136-
enqueue_control_request=_enqueue_control_request,
161+
enqueue_control_request=make_control_enqueuer(
162+
args.control_queue,
163+
args.set_ui_element_queue,
164+
),
137165
hooks=_build_hooks(args.is_edit_mode, user_config),
138166
)
139167
ctx = initialize_kernel_context(

marimo/_runtime/runtime.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -702,19 +702,16 @@ def start_completion_worker(
702702
self, completion_queue: QueueType[CodeCompletionCommand]
703703
) -> None:
704704
"""Must be called after context is initialized"""
705-
from marimo._runtime.complete import completion_worker
705+
from marimo._runtime.kernel_lifecycle import drain_stale
706706

707-
threading.Thread(
708-
target=completion_worker,
709-
args=(
710-
completion_queue,
711-
self.graph,
712-
self.globals,
713-
self._globals_lock,
714-
get_context().stream,
715-
),
716-
daemon=True,
717-
).start()
707+
def _worker() -> None:
708+
while True:
709+
request = drain_stale(
710+
completion_queue, latest=completion_queue.get()
711+
)
712+
self.code_completion(request, docstrings_limit=80)
713+
714+
threading.Thread(target=_worker, daemon=True).start()
718715
self._completion_worker_started = True
719716

720717
@kernel_tracer.start_as_current_span("code_completion")
@@ -728,7 +725,7 @@ def code_completion(
728725
self.graph,
729726
self.globals,
730727
self._globals_lock,
731-
get_context().stream,
728+
self.stream,
732729
docstrings_limit,
733730
)
734731

tests/_runtime/test_kernel_lifecycle.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,25 @@
22
from __future__ import annotations
33

44
import asyncio
5+
import queue as _queue
56
from typing import Any
67
from unittest.mock import AsyncMock, MagicMock
78

89
import pytest
910

1011
from marimo._runtime.commands import (
1112
ExecuteCellsCommand,
13+
ModelCommand,
1214
StopKernelCommand,
1315
UpdateUIElementCommand,
1416
)
1517
from marimo._runtime.kernel_lifecycle import (
1618
asyncio_queue_reader,
19+
drain_stale,
1720
listen_messages,
21+
make_control_enqueuer,
1822
)
19-
from marimo._types.ids import CellId_t, UIElementId
23+
from marimo._types.ids import CellId_t, UIElementId, WidgetModelId
2024

2125

2226
@pytest.fixture
@@ -151,3 +155,76 @@ async def test_listen_messages_merges_ui_updates(
151155
dispatched = kernel.handle_message.await_args.args[0]
152156
assert isinstance(dispatched, UpdateUIElementCommand)
153157
assert dispatched.values == [2]
158+
159+
160+
@pytest.mark.parametrize(
161+
"queue_factory",
162+
[asyncio.Queue, _queue.Queue],
163+
ids=["asyncio", "threading"],
164+
)
165+
def test_drain_stale_returns_latest_when_queue_empty(
166+
queue_factory: Any,
167+
) -> None:
168+
q = queue_factory()
169+
latest = _execute("only")
170+
assert drain_stale(q, latest=latest) is latest
171+
172+
173+
@pytest.mark.parametrize(
174+
"queue_factory",
175+
[asyncio.Queue, _queue.Queue],
176+
ids=["asyncio", "threading"],
177+
)
178+
def test_drain_stale_returns_newest_pending(queue_factory: Any) -> None:
179+
q = queue_factory()
180+
initial = _execute("initial")
181+
newer = _execute("newer")
182+
newest = _execute("newest")
183+
q.put_nowait(newer)
184+
q.put_nowait(newest)
185+
186+
assert drain_stale(q, latest=initial) is newest
187+
# Drained: nothing else remains.
188+
assert q.empty()
189+
190+
191+
def test_make_control_enqueuer_routes_plain_command_to_control_only() -> None:
192+
control: asyncio.Queue[Any] = asyncio.Queue()
193+
ui: asyncio.Queue[Any] = asyncio.Queue()
194+
enqueue = make_control_enqueuer(control, ui)
195+
196+
cmd = _execute()
197+
enqueue(cmd)
198+
199+
assert control.get_nowait() is cmd
200+
assert ui.empty()
201+
202+
203+
def test_make_control_enqueuer_mirrors_ui_element_command() -> None:
204+
control: asyncio.Queue[Any] = asyncio.Queue()
205+
ui: asyncio.Queue[Any] = asyncio.Queue()
206+
enqueue = make_control_enqueuer(control, ui)
207+
208+
cmd = _ui_update("u", 1)
209+
enqueue(cmd)
210+
211+
assert control.get_nowait() is cmd
212+
assert ui.get_nowait() is cmd
213+
214+
215+
def test_make_control_enqueuer_mirrors_model_command() -> None:
216+
from marimo._runtime.commands import ModelUpdateMessage
217+
218+
control: asyncio.Queue[Any] = asyncio.Queue()
219+
ui: asyncio.Queue[Any] = asyncio.Queue()
220+
enqueue = make_control_enqueuer(control, ui)
221+
222+
cmd = ModelCommand(
223+
model_id=WidgetModelId("m1"),
224+
message=ModelUpdateMessage(state={"x": 1}, buffer_paths=[]),
225+
buffers=[],
226+
)
227+
enqueue(cmd)
228+
229+
assert control.get_nowait() is cmd
230+
assert ui.get_nowait() is cmd

0 commit comments

Comments
 (0)