Skip to content

[Data][LLM] Fix max_pending_requests default to track vLLM's GPU-dependent max_num_seqs#62918

Merged
kouroshHakha merged 2 commits into
ray-project:masterfrom
Aydin-ab:fix-vllm-max-pending-requests-fallback
May 14, 2026
Merged

[Data][LLM] Fix max_pending_requests default to track vLLM's GPU-dependent max_num_seqs#62918
kouroshHakha merged 2 commits into
ray-project:masterfrom
Aydin-ab:fix-vllm-max-pending-requests-fallback

Conversation

@Aydin-ab
Copy link
Copy Markdown
Contributor

Summary

vLLMEngineStageUDF computed its default max_pending_requests from engine_kwargs.get("max_num_seqs", 128) * pp_size * 1.1. The hardcoded 128 fallback is stale — vLLM's actual default for max_num_seqs is GPU-dependent via AsyncEngineArgs.get_batch_defaults (vllm/engine/arg_utils.py):

GPU vLLM default Ray assumed
A10G / other <70 GiB 256 128
A100 80 GiB (explicitly excluded from large-GPU path) 256 128
H100 / MI300x (≥70 GiB, non-A100) 1024 128
CPU 256 × world_size 128

When users don't set max_num_seqs explicitly (the common case), the semaphore silently caps inflight requests far below vLLM's real capacity — e.g. ~141 vs 1024 on H100, ~14% utilization.

Fix

Move the default resolution into vLLMEngineWrapper, which already calls AsyncEngineArgs.create_engine_config() and has access to the resolved scheduler_config.max_num_seqs and parallel_config.pipeline_parallel_size. The UDF passes max_pending_requests=None through and reads the resolved value back.

Semantics:

  • max_pending_requests=None (default): auto-resolve from vLLM's resolved engine config
  • positive int: explicit limit (unchanged)
  • non-positive (e.g. -1): disable semaphore (unchanged)

This aligns with ProcessorConfig.max_pending_requests's stated intent: "If not specified, will use the default value from the backend engine."

Test plan

  • Existing test_vllm_engine_udf_basic updated to reflect that the UDF now reads max_pending_requests from the (mocked) wrapper and passes None through when the caller didn't supply a value.
  • Other tests that pass explicit max_pending_requests=10 are unaffected (positive int path unchanged).
  • test_vllm_wrapper_semaphore exercises max_pending_requests=2 (positive int) — unaffected.
  • Local pytest run not possible without a ray C-extension build; CI will exercise the GPU tests.

🤖 Generated with Claude Code

…ndent max_num_seqs

`vLLMEngineStageUDF` computed the default `max_pending_requests` as
`ceil(1.1 * engine_kwargs.get("max_num_seqs", 128) * pp_size)`. The
hardcoded `128` fallback does not match vLLM's actual default, which is
GPU-dependent via `AsyncEngineArgs.get_batch_defaults`:

- A10G (<70 GiB) / A100: 256
- H100 / MI300x (>=70 GiB, non-A100): 1024
- CPU: 256 * world_size

When users don't set `max_num_seqs` explicitly (the common case), the
semaphore silently caps inflight requests far below vLLM's real capacity
(e.g. ~141 vs 1024 on H100, ~14% utilization).

Move the default resolution into `vLLMEngineWrapper`, which already calls
`AsyncEngineArgs.create_engine_config()` and has access to the resolved
`scheduler_config.max_num_seqs` and `parallel_config.pipeline_parallel_size`.
The UDF passes `max_pending_requests=None` through as-is and reads the
resolved value back from the wrapper.

Behavior:
- `max_pending_requests=None` (default): auto-resolve from vLLM config
- positive int: explicit limit (unchanged)
- non-positive (e.g. -1): disable semaphore (unchanged)

This aligns with the `ProcessorConfig.max_pending_requests` field's
stated intent: "If not specified, will use the default value from the
backend engine."

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Aydin Abiar <aydin@anyscale.com>
@Aydin-ab Aydin-ab requested a review from a team as a code owner April 24, 2026 20:09
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the calculation of max_pending_requests within the vLLMEngineStage. The logic for resolving the default value has been moved from the UDF to the vLLMEngineWrapper, where it now dynamically calculates the limit based on vLLM's resolved engine configuration (specifically max_num_seqs and pipeline_parallel_size). This change ensures that the request concurrency limit correctly tracks GPU-dependent capacities rather than relying on hardcoded defaults. I have no feedback to provide as there were no review comments.

@ray-gardener ray-gardener Bot added the data Ray Data-related issues label Apr 25, 2026
@Aydin-ab Aydin-ab added the go add ONLY when ready to merge, run all tests label Apr 28, 2026
Copy link
Copy Markdown
Contributor

@jeffreywang-anyscale jeffreywang-anyscale left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚢

Signed-off-by: Aydin Abiar <aydin@anyscale.com>
@Aydin-ab Aydin-ab requested a review from kouroshHakha May 14, 2026 18:59
@kouroshHakha kouroshHakha merged commit 8f85308 into ray-project:master May 14, 2026
6 checks passed
@Aydin-ab Aydin-ab deleted the fix-vllm-max-pending-requests-fallback branch May 14, 2026 19:32
TruongQuangPhat pushed a commit to cyhapun/ray-fix-issue that referenced this pull request May 27, 2026
…ndent max_num_seqs (ray-project#62918)

Signed-off-by: Aydin Abiar <aydin@anyscale.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: phattruong <23120318@student.hcmus.edu.vn>
alexandrplashchinsky pushed a commit to alexandrplashchinsky/ray-alex that referenced this pull request May 29, 2026
…ndent max_num_seqs (ray-project#62918)

Signed-off-by: Aydin Abiar <aydin@anyscale.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Alexandr Plashchinsky <alexandr.plashchinsky@alexandrplashchinsky-H765G66H9V.local>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

data Ray Data-related issues go add ONLY when ready to merge, run all tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants