|
16 | 16 |
|
17 | 17 | from __future__ import annotations |
18 | 18 |
|
| 19 | +import asyncio |
19 | 20 | import json |
20 | 21 | from typing import AsyncIterator, List |
21 | 22 |
|
@@ -276,6 +277,155 @@ async def test_stream_via_scheduler_finish_reason_stop_on_cancel(tokenizer): |
276 | 277 | assert request.calls > 0 |
277 | 278 |
|
278 | 279 |
|
async def test_collect_non_streaming_tokens_cancels_on_disconnect(tokenizer):
    """Disconnecting clients must not keep a scheduler session alive.

    The non-streaming (JSON) collection helper is expected to cancel its
    scheduler session once the client is reported as disconnected.
    Without that, a single timed-out request could occupy the only slot
    of a single-slot server and later queued requests would be answered
    with 429.

    The helper is driven with a slow engine (0.02 s per token) and a
    zero disconnect poll interval.
    """
    from tests.inference_engine.server.conftest import DeterministicEngine
    from inference_engine.scheduler.session import SessionState
    from inference_engine.server.app import _collect_non_streaming_tokens

    token_count = 20
    engine = DeterministicEngine(
        fixed_tokens=[tokenizer._intern(f"tok{i}") for i in range(token_count)],
        tokenizer=tokenizer,
        model_id_label="slow",
        per_token_delay_s=0.02,
    )
    scheduler = _build_scheduler_with_engine(engine)
    session = await scheduler.submit(
        prompt_ids=[1],
        max_new_tokens=token_count,
        eos_token_ids=[0],
    )

    # NOTE(review): presumably `[True]` scripts the fake to report
    # "disconnected" on its first poll -- confirm against _FakeRequest.
    fake_request = _FakeRequest([True])
    collected = await _collect_non_streaming_tokens(
        scheduler=scheduler,
        session=session,
        request=fake_request,
        disconnect_poll_interval_s=0.0,
    )

    # Something was collected and the disconnect check ran at least once.
    assert collected
    assert fake_request.calls >= 1
    # The session ended up cancelled and no longer counts as active.
    assert session.state is SessionState.CANCELLED
    assert scheduler.active_count == 0
| 310 | + |
| 311 | + |
async def test_collect_non_streaming_tokens_propagates_cancel_and_cleans_up(
    tokenizer,
):
    """External task cancellation must propagate out of the helper.

    When the task awaiting ``_collect_non_streaming_tokens`` is cancelled
    from outside (e.g. server shutdown or a transport timeout cancelling
    the route handler task), ``asyncio.CancelledError`` has to surface to
    the caller so the route handler's ``except asyncio.CancelledError``
    branch can cancel the scheduler session. If the helper swallowed the
    error, the slab would stay held and later requests would get 429.

    Two things are checked:

    1. The helper does not swallow ``CancelledError``.
    2. Calling ``scheduler.cancel_session`` afterwards -- the same call
       the route handler makes in its ``CancelledError`` branch -- leaves
       the session ``CANCELLED`` and the scheduler with no active
       sessions.
    """
    from tests.inference_engine.server.conftest import DeterministicEngine
    from inference_engine.scheduler.session import SessionState
    from inference_engine.server.app import _collect_non_streaming_tokens

    ids = [tokenizer._intern(f"tok{i}") for i in range(50)]
    slow_engine = DeterministicEngine(
        fixed_tokens=ids, tokenizer=tokenizer,
        model_id_label="slow", per_token_delay_s=0.05,
    )
    scheduler = _build_scheduler_with_engine(slow_engine)
    session = await scheduler.submit(
        prompt_ids=[1], max_new_tokens=50, eos_token_ids=[0],
    )

    # Intended never to report a disconnect; cancellation comes from
    # task.cancel() below instead.
    request = _FakeRequest([])

    task = asyncio.create_task(
        _collect_non_streaming_tokens(
            scheduler=scheduler,
            session=session,
            request=request,
            disconnect_poll_interval_s=10.0,  # don't fire disconnect path
        )
    )
    # Let the helper start awaiting before cancelling it.
    await asyncio.sleep(0.05)
    task.cancel()

    with pytest.raises(asyncio.CancelledError):
        await task

    # Mirror the route handler's CancelledError cleanup.
    await scheduler.cancel_session(session)

    # Wait for teardown with a bounded poll rather than one fixed sleep:
    # it returns as soon as the state settles and tolerates slow runners.
    # If cleanup never happens, the deadline expires and the assertions
    # below fail exactly as they would have after a fixed sleep.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + 2.0
    while (
        session.state is not SessionState.CANCELLED
        or scheduler.active_count != 0
    ) and loop.time() < deadline:
        await asyncio.sleep(0.01)

    assert session.state is SessionState.CANCELLED
    assert scheduler.active_count == 0
| 371 | + |
| 372 | + |
async def test_route_handler_cancelled_error_branch_releases_slab(tokenizer):
    """Cancelling an in-flight POST must free the server's only slot.

    Integration counterpart to
    ``test_collect_non_streaming_tokens_propagates_cancel_and_cleans_up``:
    the request goes through the real app built by ``create_app`` over
    ``httpx.ASGITransport`` with ``max_concurrent=1``. The in-flight
    request task is cancelled, and a follow-up request must then be
    admitted with a 200 rather than rejected. This indicates that the
    route handler's ``except asyncio.CancelledError`` branch ran and
    released the slab via ``cancel_session``.
    """
    from tests.inference_engine.server.conftest import DeterministicEngine
    from inference_engine.server.app import create_app
    from inference_engine.server.config import ServerConfig

    engine = DeterministicEngine(
        fixed_tokens=[tokenizer._intern(f"tok{i}") for i in range(50)],
        tokenizer=tokenizer,
        model_id_label="slow",
        per_token_delay_s=0.05,
    )
    app = create_app(engine, ServerConfig(max_concurrent=1))

    endpoint = "/v1/chat/completions"
    long_payload = {
        "model": "m",
        "messages": [{"role": "user", "content": "hi"}],
        "max_tokens": 50,
    }
    short_payload = {
        "model": "m",
        "messages": [{"role": "user", "content": "hi2"}],
        "max_tokens": 5,
    }

    transport = ASGITransport(app=app)
    async with AsyncClient(
        transport=transport, base_url="http://t", timeout=30.0,
    ) as client:
        inflight = asyncio.create_task(client.post(endpoint, json=long_payload))
        # Give the scheduler time to admit the request and start
        # producing tokens before pulling the plug.
        await asyncio.sleep(0.1)
        inflight.cancel()
        # Deliberately tolerant: accepts CancelledError or any other
        # exception raised while awaiting the cancelled request task.
        with pytest.raises((asyncio.CancelledError, Exception)):
            await inflight
        # Let the worker notice the cancellation and release the slab.
        await asyncio.sleep(0.3)
        # With the slab back in the pool the next request is admitted;
        # if cleanup did not run, the slot is still held (no 200).
        followup = await client.post(endpoint, json=short_payload)
        assert followup.status_code == 200, (
            f"follow-up request got {followup.status_code}; "
            f"expected 200 because the cancelled request should have "
            f"released the slab via the CancelledError cleanup branch"
        )
| 427 | + |
| 428 | + |
279 | 429 | async def test_stream_via_scheduler_finish_reason_length_when_max_tokens(tokenizer): |
280 | 430 | """No disconnect, max_tokens cap → finish_reason='length'.""" |
281 | 431 | from tests.inference_engine.server.conftest import DeterministicEngine |
|
0 commit comments