Skip to content

Commit 066d2c3

Browse files
authored
Merge pull request #25 from FluffyAIcode/AgentMemory/scheduler-orphan-session-fix-8e7f
Cancel scheduler session on HTTP disconnect / drain error
2 parents 9b69582 + c5acaa9 commit 066d2c3

3 files changed

Lines changed: 194 additions & 5 deletions

File tree

inference_engine/scheduler/scheduler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,9 +343,13 @@ def on_token(tok_id: int) -> bool:
343343
if session.state == SessionState.CANCELLED:
344344
return True
345345
session.output_token_ids.append(int(tok_id))
346-
asyncio.run_coroutine_threadsafe(
346+
enqueue = asyncio.run_coroutine_threadsafe(
347347
session.token_queue.put(int(tok_id)), loop
348348
)
349+
# Preserve token-before-sentinel ordering. The worker
350+
# runs in a thread, while the terminal sentinel is pushed
351+
# back on the event loop after generate() returns.
352+
enqueue.result()
349353
return False
350354

351355
result = await asyncio.to_thread(

inference_engine/server/app.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656

5757
from __future__ import annotations
5858

59+
import asyncio
5960
import time
6061
import uuid
6162
from contextlib import asynccontextmanager
@@ -370,12 +371,20 @@ async def chat_completions(req: ChatCompletionRequest, request: Request):
370371
media_type="text/event-stream",
371372
)
372373

373-
# Non-streaming: drain the session synchronously.
374-
output_token_ids: List[int] = []
375374
try:
376-
async for tok in scheduler.iter_tokens(session):
377-
output_token_ids.append(int(tok))
375+
output_token_ids = await _collect_non_streaming_tokens(
376+
scheduler=scheduler,
377+
session=session,
378+
request=request,
379+
)
380+
except asyncio.CancelledError:
381+
# Client timed out/disconnected while the JSON response was
382+
# draining. Without explicit cancellation the worker can keep
383+
# occupying the only slab, causing later queued requests to 429.
384+
await scheduler.cancel_session(session)
385+
raise
378386
except BaseException as exc:
387+
await scheduler.cancel_session(session)
379388
# Engine raised mid-generate; surface as 500.
380389
raise HTTPException(
381390
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -469,6 +478,32 @@ def _session_acceptance_rate(
469478
return None
470479

471480

481+
async def _collect_non_streaming_tokens(
482+
*,
483+
scheduler: Scheduler,
484+
session: Session,
485+
request: Request,
486+
disconnect_poll_interval_s: float = 0.05,
487+
) -> List[int]:
488+
"""Drain a non-streaming session while honoring client disconnects.
489+
490+
Streaming responses already poll ``request.is_disconnected()`` and
491+
cancel their scheduler session. JSON responses need the same cleanup:
492+
a timed-out client otherwise leaves the scheduler worker running until
493+
generation finishes, which can monopolize a single-slot server.
494+
"""
495+
output_token_ids: List[int] = []
496+
last_disconnect_check = time.monotonic()
497+
async for tok in scheduler.iter_tokens(session):
498+
output_token_ids.append(int(tok))
499+
now = time.monotonic()
500+
if (now - last_disconnect_check) >= disconnect_poll_interval_s:
501+
last_disconnect_check = now
502+
if await request.is_disconnected():
503+
await scheduler.cancel_session(session)
504+
return output_token_ids
505+
506+
472507
async def _stream_via_scheduler(
473508
*,
474509
scheduler: Scheduler,

tests/inference_engine/server/test_app_streaming.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from __future__ import annotations
1818

19+
import asyncio
1920
import json
2021
from typing import AsyncIterator, List
2122

@@ -276,6 +277,155 @@ async def test_stream_via_scheduler_finish_reason_stop_on_cancel(tokenizer):
276277
assert request.calls > 0
277278

278279

280+
async def test_collect_non_streaming_tokens_cancels_on_disconnect(tokenizer):
281+
"""The JSON response path must cancel its scheduler session when the
282+
client disconnects, otherwise one timed-out request can monopolize a
283+
single-slot server and turn later queued requests into 429s."""
284+
from tests.inference_engine.server.conftest import DeterministicEngine
285+
from inference_engine.scheduler.session import SessionState
286+
from inference_engine.server.app import _collect_non_streaming_tokens
287+
288+
ids = [tokenizer._intern(f"tok{i}") for i in range(20)]
289+
slow_engine = DeterministicEngine(
290+
fixed_tokens=ids, tokenizer=tokenizer,
291+
model_id_label="slow", per_token_delay_s=0.02,
292+
)
293+
scheduler = _build_scheduler_with_engine(slow_engine)
294+
session = await scheduler.submit(
295+
prompt_ids=[1], max_new_tokens=20, eos_token_ids=[0],
296+
)
297+
298+
request = _FakeRequest([True])
299+
tokens = await _collect_non_streaming_tokens(
300+
scheduler=scheduler,
301+
session=session,
302+
request=request,
303+
disconnect_poll_interval_s=0.0,
304+
)
305+
306+
assert tokens
307+
assert request.calls >= 1
308+
assert session.state is SessionState.CANCELLED
309+
assert scheduler.active_count == 0
310+
311+
312+
async def test_collect_non_streaming_tokens_propagates_cancel_and_cleans_up(
313+
tokenizer,
314+
):
315+
"""When the awaiting task is externally cancelled (e.g. uvicorn
316+
shutdown or a request transport timeout cancels the route handler
317+
task), the asyncio.CancelledError must propagate up so the route
318+
handler's ``except asyncio.CancelledError`` branch fires and
319+
cancels the scheduler session — otherwise the slab stays held
320+
forever and downstream requests 429.
321+
322+
This test exercises the cancellation propagation contract on the
323+
helper itself (the helper does NOT swallow CancelledError) plus
324+
verifies that calling ``cancel_session`` after a CancelledError
325+
correctly releases the slab — which is exactly what the route
326+
handler's CancelledError catch does at line 384 of app.py.
327+
"""
328+
from tests.inference_engine.server.conftest import DeterministicEngine
329+
from inference_engine.scheduler.session import SessionState
330+
from inference_engine.server.app import _collect_non_streaming_tokens
331+
332+
ids = [tokenizer._intern(f"tok{i}") for i in range(50)]
333+
slow_engine = DeterministicEngine(
334+
fixed_tokens=ids, tokenizer=tokenizer,
335+
model_id_label="slow", per_token_delay_s=0.05,
336+
)
337+
scheduler = _build_scheduler_with_engine(slow_engine)
338+
session = await scheduler.submit(
339+
prompt_ids=[1], max_new_tokens=50, eos_token_ids=[0],
340+
)
341+
342+
# never disconnects; we cancel via task.cancel() instead
343+
request = _FakeRequest([])
344+
345+
async def _drain():
346+
return await _collect_non_streaming_tokens(
347+
scheduler=scheduler,
348+
session=session,
349+
request=request,
350+
disconnect_poll_interval_s=10.0, # don't fire disconnect path
351+
)
352+
353+
task = asyncio.create_task(_drain())
354+
# Let the helper start awaiting the queue
355+
await asyncio.sleep(0.05)
356+
task.cancel()
357+
358+
with pytest.raises(asyncio.CancelledError):
359+
await task
360+
361+
# The route handler's `except asyncio.CancelledError` catch performs
362+
# exactly this sequence. We assert it works correctly: cancel_session
363+
# is idempotent for already-cancelled sessions, and the slab is
364+
# released afterwards.
365+
await scheduler.cancel_session(session)
366+
# Allow the worker to observe the CANCELLED state and tear down
367+
await asyncio.sleep(0.2)
368+
369+
assert session.state is SessionState.CANCELLED
370+
assert scheduler.active_count == 0
371+
372+
373+
async def test_route_handler_cancelled_error_branch_releases_slab(tokenizer):
374+
"""End-to-end via httpx.ASGITransport: cancel the in-flight POST
375+
request task and verify the slab is released so the next request
376+
is admitted (i.e. the route handler's CancelledError catch ran
377+
and released the slab via cancel_session).
378+
379+
This is the integration counterpart to
380+
``test_collect_non_streaming_tokens_propagates_cancel_and_cleans_up``
381+
— covers ``app.py`` line 384 (``await scheduler.cancel_session``
382+
inside the ``except asyncio.CancelledError`` branch) through the
383+
real route handler closure.
384+
"""
385+
from tests.inference_engine.server.conftest import DeterministicEngine
386+
from inference_engine.server.app import create_app
387+
from inference_engine.server.config import ServerConfig
388+
389+
ids = [tokenizer._intern(f"tok{i}") for i in range(50)]
390+
slow_engine = DeterministicEngine(
391+
fixed_tokens=ids, tokenizer=tokenizer,
392+
model_id_label="slow", per_token_delay_s=0.05,
393+
)
394+
app = create_app(slow_engine, ServerConfig(max_concurrent=1))
395+
async with AsyncClient(
396+
transport=ASGITransport(app=app), base_url="http://t", timeout=30.0,
397+
) as c:
398+
post_task = asyncio.create_task(c.post(
399+
"/v1/chat/completions", json={
400+
"model": "m",
401+
"messages": [{"role": "user", "content": "hi"}],
402+
"max_tokens": 50,
403+
},
404+
))
405+
# Let scheduler admit and start producing tokens
406+
await asyncio.sleep(0.1)
407+
post_task.cancel()
408+
with pytest.raises((asyncio.CancelledError, Exception)):
409+
await post_task
410+
# Allow worker to observe cancel and release slab
411+
await asyncio.sleep(0.3)
412+
# If the route handler's CancelledError branch ran, the slab
413+
# is back in the pool and a follow-up request gets admitted
414+
# (no 429). If it didn't run, the slab is still held.
415+
followup = await c.post(
416+
"/v1/chat/completions", json={
417+
"model": "m",
418+
"messages": [{"role": "user", "content": "hi2"}],
419+
"max_tokens": 5,
420+
},
421+
)
422+
assert followup.status_code == 200, (
423+
f"follow-up request got {followup.status_code}; "
424+
f"expected 200 because the cancelled request should have "
425+
f"released the slab via the CancelledError cleanup branch"
426+
)
427+
428+
279429
async def test_stream_via_scheduler_finish_reason_length_when_max_tokens(tokenizer):
280430
"""No disconnect, max_tokens cap → finish_reason='length'."""
281431
from tests.inference_engine.server.conftest import DeterministicEngine

0 commit comments

Comments
 (0)