Skip to content

Commit 8f0ebe0

Browse files
authored
Refactor session extensions to use EventAwareExtension base class (#8678)
1 parent 609a3f8 commit 8f0ebe0

11 files changed

Lines changed: 331 additions & 204 deletions

File tree

marimo/_internal/session/extensions.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,16 @@
1010
ReplayExtension,
1111
SessionViewExtension,
1212
)
13-
from marimo._session.extensions.types import SessionExtension
13+
from marimo._session.extensions.types import (
14+
EventAwareExtension,
15+
ExtensionRegistry,
16+
SessionExtension,
17+
)
1418

1519
__all__ = [
1620
"CachingExtension",
21+
"EventAwareExtension",
22+
"ExtensionRegistry",
1723
"HeartbeatExtension",
1824
"LoggingExtension",
1925
"NotificationListenerExtension",

marimo/_server/scratchpad.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from marimo._messaging.notification import CellNotification
1919
from marimo._messaging.serde import deserialize_kernel_message
2020
from marimo._runtime.scratch import SCRATCH_CELL_ID
21-
from marimo._session.events import SessionEventBus, SessionEventListener
21+
from marimo._session.extensions.types import EventAwareExtension
2222

2323
if TYPE_CHECKING:
2424
from collections.abc import AsyncGenerator
@@ -73,7 +73,7 @@ def _format_sse(event: str, data: Any) -> str:
7373
# -- Listeners ----------------------------------------------------------------
7474

7575

76-
class ScratchCellListener(SessionEventListener):
76+
class ScratchCellListener(EventAwareExtension):
7777
"""Listens for scratch cell notifications via an asyncio.Queue.
7878
7979
Supports both SSE streaming (via ``stream()``) and simple blocking
@@ -82,19 +82,10 @@ class ScratchCellListener(SessionEventListener):
8282
"""
8383

8484
def __init__(self) -> None:
85+
super().__init__()
8586
self._queue: asyncio.Queue[CellNotification | None] = asyncio.Queue()
8687
self.timed_out = False
8788

88-
def on_attach(self, session: Session, event_bus: SessionEventBus) -> None:
89-
del session
90-
self._event_bus = event_bus
91-
event_bus.subscribe(self)
92-
93-
def on_detach(self) -> None:
94-
if hasattr(self, "_event_bus"):
95-
self._event_bus.unsubscribe(self)
96-
del self._event_bus
97-
9889
def on_notification_sent(
9990
self, session: Session, notification: KernelMessage
10091
) -> None:

marimo/_server/session/listeners.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,12 @@
88
from __future__ import annotations
99

1010
from marimo._server.recents import RecentFilesManager
11-
from marimo._session.events import (
12-
SessionEventBus,
13-
SessionEventListener,
14-
)
15-
from marimo._session.extensions.types import SessionExtension
11+
from marimo._session.extensions.types import EventAwareExtension
1612
from marimo._session.session import Session
1713
from marimo._types.ids import SessionId
1814

1915

20-
class RecentsTrackerListener(SessionExtension, SessionEventListener):
16+
class RecentsTrackerListener(EventAwareExtension):
2117
"""Event listener that tracks recently accessed files."""
2218

2319
def __init__(self, recents_manager: RecentFilesManager) -> None:
@@ -26,20 +22,8 @@ def __init__(self, recents_manager: RecentFilesManager) -> None:
2622
Args:
2723
recents_manager: Manager for recent files
2824
"""
25+
super().__init__()
2926
self._recents = recents_manager
30-
self._event_bus: SessionEventBus | None = None
31-
32-
def on_attach(self, session: Session, event_bus: SessionEventBus) -> None:
33-
"""Attach the recents tracker listener to a session."""
34-
del session
35-
self._event_bus = event_bus
36-
self._event_bus.subscribe(self)
37-
38-
def on_detach(self) -> None:
39-
"""Detach the recents tracker listener from a session."""
40-
if self._event_bus:
41-
self._event_bus.unsubscribe(self)
42-
self._event_bus = None
4327

4428
async def on_session_created(self, session: Session) -> None:
4529
"""Update recent files when a session is created."""

marimo/_session/events.py

Lines changed: 62 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from marimo._types.ids import ConsumerId, SessionId
1414

1515
if TYPE_CHECKING:
16+
from collections.abc import Awaitable, Callable
17+
1618
from marimo._runtime import commands
1719
from marimo._session.session import Session
1820

@@ -96,76 +98,84 @@ def unsubscribe(self, listener: SessionEventListener) -> None:
9698
if listener in self._listeners:
9799
self._listeners.remove(listener)
98100

99-
async def emit_session_created(self, session: Session) -> None:
100-
"""Emit a session created event."""
101-
for listener in self._listeners:
101+
def _emit(
102+
self,
103+
event_name: str,
104+
call: Callable[[SessionEventListener], None],
105+
) -> None:
106+
"""Dispatch a synchronous event to all listeners."""
107+
for listener in list(self._listeners):
102108
try:
103-
await listener.on_session_created(session)
109+
call(listener)
104110
except Exception as e:
105111
LOGGER.error(
106-
"Error handling session created event for listener %s: %s",
112+
"Error handling %s for listener %s: %s",
113+
event_name,
107114
listener,
108115
e,
109116
)
110-
continue
111117

112-
async def emit_session_closed(self, session: Session) -> None:
113-
"""Emit a session closed event."""
114-
for listener in self._listeners:
118+
async def _emit_async(
119+
self,
120+
event_name: str,
121+
call: Callable[[SessionEventListener], Awaitable[None]],
122+
) -> None:
123+
"""Dispatch an async event to all listeners."""
124+
for listener in list(self._listeners):
115125
try:
116-
await listener.on_session_closed(session)
126+
await call(listener)
117127
except Exception as e:
118128
LOGGER.error(
119-
"Error handling session closed event for listener %s: %s",
129+
"Error handling %s for listener %s: %s",
130+
event_name,
120131
listener,
121132
e,
122133
)
123-
continue
134+
135+
async def emit_session_created(self, session: Session) -> None:
136+
"""Emit a session created event."""
137+
await self._emit_async(
138+
"session_created",
139+
lambda listener: listener.on_session_created(session),
140+
)
141+
142+
async def emit_session_closed(self, session: Session) -> None:
143+
"""Emit a session closed event."""
144+
await self._emit_async(
145+
"session_closed",
146+
lambda listener: listener.on_session_closed(session),
147+
)
124148

125149
async def emit_session_resumed(
126150
self, session: Session, old_id: SessionId
127151
) -> None:
128152
"""Emit a session resumed event."""
129-
for listener in self._listeners:
130-
try:
131-
await listener.on_session_resumed(session, old_id)
132-
except Exception as e:
133-
LOGGER.error(
134-
"Error handling session resumed event for listener %s: %s",
135-
listener,
136-
e,
137-
)
138-
continue
153+
await self._emit_async(
154+
"session_resumed",
155+
lambda listener: listener.on_session_resumed(session, old_id),
156+
)
139157

140158
async def emit_session_notebook_renamed(
141159
self, session: Session, old_path: str | None
142160
) -> None:
143161
"""Emit a session renamed event."""
144-
for listener in self._listeners:
145-
try:
146-
await listener.on_session_notebook_renamed(session, old_path)
147-
except Exception as e:
148-
LOGGER.error(
149-
"Error handling session notebook renamed event for listener %s: %s",
150-
listener,
151-
e,
152-
)
153-
continue
162+
await self._emit_async(
163+
"session_notebook_renamed",
164+
lambda listener: listener.on_session_notebook_renamed(
165+
session, old_path
166+
),
167+
)
154168

155169
def emit_notification_sent(
156170
self, session: Session, notification: KernelMessage
157171
) -> None:
158172
"""Emit a notification sent event."""
159-
for listener in self._listeners:
160-
try:
161-
listener.on_notification_sent(session, notification)
162-
except Exception as e:
163-
LOGGER.error(
164-
"Error handling notification sent event for listener %s: %s",
165-
listener,
166-
e,
167-
)
168-
continue
173+
self._emit(
174+
"notification_sent",
175+
lambda listener: listener.on_notification_sent(
176+
session, notification
177+
),
178+
)
169179

170180
def emit_received_command(
171181
self,
@@ -174,28 +184,16 @@ def emit_received_command(
174184
from_consumer_id: Optional[ConsumerId],
175185
) -> None:
176186
"""Emit a received command event."""
177-
for listener in self._listeners:
178-
try:
179-
listener.on_received_command(
180-
session, request, from_consumer_id
181-
)
182-
except Exception as e:
183-
LOGGER.error(
184-
"Error handling received command event for listener %s: %s",
185-
listener,
186-
e,
187-
)
188-
continue
187+
self._emit(
188+
"received_command",
189+
lambda listener: listener.on_received_command(
190+
session, request, from_consumer_id
191+
),
192+
)
189193

190194
def emit_received_stdin(self, session: Session, stdin: str) -> None:
191195
"""Emit a received stdin event."""
192-
for listener in self._listeners:
193-
try:
194-
listener.on_received_stdin(session, stdin)
195-
except Exception as e:
196-
LOGGER.error(
197-
"Error handling received stdin event for listener %s: %s",
198-
listener,
199-
e,
200-
)
201-
continue
196+
self._emit(
197+
"received_stdin",
198+
lambda listener: listener.on_received_stdin(session, stdin),
199+
)

0 commit comments

Comments
 (0)