44import uuid
55import asyncio
66
7+ import httpx
8+
79from asyncio import Queue
810from typing import (
911 Dict ,
2628 OutputType ,
2729 UnexpectedEndOfExecution ,
2830)
31+ from consts import JUPYTER_BASE_URL
2932from errors import ExecutionError
3033from envs import get_envs
3134
3235logger = logging .getLogger (__name__ )
3336
3437MAX_RECONNECT_RETRIES = 3
3538PING_TIMEOUT = 30
39+ KEEPALIVE_INTERVAL = 5 # seconds between keepalive pings during streaming
3640
3741
3842class Execution :
@@ -97,6 +101,22 @@ async def connect(self):
97101 name = "receive_message" ,
98102 )
99103
async def interrupt(self):
    """Interrupt the current kernel execution via the Jupyter REST API.

    Sends a POST request to the ``/api/kernels/<context_id>/interrupt``
    endpoint under ``JUPYTER_BASE_URL`` and logs the outcome.

    This is a best-effort operation: a non-success HTTP status or any
    ``Exception`` raised while sending the request is logged and not
    propagated, so callers never need to guard the call themselves.
    """
    try:
        url = f"{JUPYTER_BASE_URL}/api/kernels/{self.context_id}/interrupt"
        # A short-lived client is enough here; the request has no body.
        async with httpx.AsyncClient() as client:
            response = await client.post(url)

        if response.is_success:
            logger.info(f"Kernel {self.context_id} interrupted successfully")
        else:
            logger.error(
                f"Failed to interrupt kernel {self.context_id}: {response.status_code}"
            )
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback,
        # so the root cause of a failed interrupt is not lost.
        logger.exception(f"Error interrupting kernel {self.context_id}: {e}")
100120 def _get_execute_request (
101121 self , msg_id : str , code : Union [str , StrictStr ], background : bool
102122 ) -> str :
@@ -239,7 +259,18 @@ async def _wait_for_result(self, message_id: str):
239259 queue = self ._executions [message_id ].queue
240260
241261 while True :
242- output = await queue .get ()
262+ try :
263+ output = await asyncio .wait_for (
264+ queue .get (), timeout = KEEPALIVE_INTERVAL
265+ )
266+ except asyncio .TimeoutError :
267+ # Yield a keepalive so Starlette writes to the socket.
268+ # If the client has disconnected, the write fails and
269+ # uvicorn delivers http.disconnect, which cancels this
270+ # generator via CancelledError.
271+ yield {"type" : "keepalive" }
272+ continue
273+
243274 if output .type == OutputType .END_OF_EXECUTION :
244275 break
245276
@@ -294,11 +325,6 @@ async def execute(
294325 if self ._ws is None :
295326 raise Exception ("WebSocket not connected" )
296327
297- # Lock only the setup + send phase, not the streaming phase.
298- # Results are read from a per-execution queue (keyed by message_id)
299- # so streaming doesn't need serialization. Releasing before streaming
300- # prevents client disconnect from holding the lock until the kernel
301- # finishes execution (see #213).
302328 async with self ._lock :
303329 # Wait for any pending cleanup task to complete
304330 if self ._cleanup_task and not self ._cleanup_task .done ():
@@ -367,22 +393,35 @@ async def execute(
367393 )
368394 await execution .queue .put (UnexpectedEndOfExecution ())
369395
370- # Schedule env var cleanup inside the lock so the next execution
371- # can wait for it. The task sends reset code to the kernel, which
372- # queues it after the current execution's code.
373- if env_vars :
396+ # Stream the results.
397+ # If the client disconnects (Starlette cancels the task), we
398+ # interrupt the kernel so the next execution isn't blocked (#213).
399+ client_disconnected = False
400+ try :
401+ async for item in self ._wait_for_result (message_id ):
402+ yield item
403+ except (asyncio .CancelledError , GeneratorExit ):
404+ client_disconnected = True
405+ logger .warning (
406+ f"Client disconnected during execution ({ message_id } ), interrupting kernel"
407+ )
408+ # Shield the interrupt from the ongoing cancellation so
409+ # the HTTP request to the kernel actually completes.
410+ try :
411+ await asyncio .shield (self .interrupt ())
412+ except asyncio .CancelledError :
413+ pass
414+ raise
415+ finally :
416+ if message_id in self ._executions :
417+ del self ._executions [message_id ]
418+
419+ # Clean up env vars in a separate request after the main code has run
420+ if env_vars and not client_disconnected :
374421 self ._cleanup_task = asyncio .create_task (
375422 self ._cleanup_env_vars (env_vars )
376423 )
377424
378- # Stream the results without holding the lock
379- try :
380- async for item in self ._wait_for_result (message_id ):
381- yield item
382- finally :
383- if message_id in self ._executions :
384- del self ._executions [message_id ]
385-
386425 async def _receive_message (self ):
387426 if not self ._ws :
388427 logger .error ("No WebSocket connection" )
@@ -394,7 +433,8 @@ async def _receive_message(self):
394433 except Exception as e :
395434 logger .error (f"WebSocket received error while receiving messages: { str (e )} " )
396435 finally :
397- # To prevent infinite hang, cancel all ongoing executions as results may be lost during reconnect.
436+ # To prevent an infinite hang, we need to cancel all ongoing executions, as we could lose results during the reconnect.
437+ # Thanks to the locking, there can be either no ongoing execution or just one.
398438 for key , execution in self ._executions .items ():
399439 await execution .queue .put (
400440 Error (
0 commit comments