Skip to content

Commit afee413

Browse files
[data][llm] Promote max_tasks_in_flight_per_actor to a first-class config field and adjust defaults (ray-project#63214)
Signed-off-by: Jeffrey Wang <jeffreywang@anyscale.com>
1 parent ac571c9 commit afee413

7 files changed

Lines changed: 173 additions & 96 deletions

File tree

python/ray/data/llm.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,16 @@ class vLLMEngineProcessorConfig(_vLLMEngineProcessorConfig):
125125
This is to overlap the batch processing to avoid the tail latency of
126126
each batch. The default value may not be optimal when the batch size
127127
or the batch processing latency is too small, but it should be good
128-
enough for batch size >= 64.
128+
enough for batch size >= 64. Sets the engine actor's Ray Core
129+
``max_concurrency``.
130+
max_tasks_in_flight_per_actor: Max tasks Ray Data submits concurrently to
131+
each engine actor. Passed through to ``ray.data.ActorPoolStrategy``.
132+
If unset, Ray Data uses
133+
``ray.data.DataContext.max_tasks_in_flight_per_actor`` when that is set globally.
134+
Otherwise, it defaults to ``2 * max_concurrent_batches``; the factor
135+
can be overridden via the
136+
``RAY_DATA_ACTOR_DEFAULT_MAX_TASKS_IN_FLIGHT_TO_MAX_CONCURRENCY_FACTOR``
137+
env var.
129138
should_continue_on_error: If True, continue processing when inference fails for a row
130139
instead of raising an exception. Failed rows will have a non-empty
131140
``__inference_error__`` column containing the error message, and other
@@ -233,7 +242,16 @@ class SGLangEngineProcessorConfig(_SGLangEngineProcessorConfig):
233242
This is to overlap the batch processing to avoid the tail latency of
234243
each batch. The default value may not be optimal when the batch size
235244
or the batch processing latency is too small, but it should be good
236-
enough for batch size >= 64.
245+
enough for batch size >= 64. Sets the engine actor's Ray Core
246+
``max_concurrency``.
247+
max_tasks_in_flight_per_actor: Max tasks Ray Data submits concurrently to
248+
each engine actor. Passed through to ``ray.data.ActorPoolStrategy``.
249+
If unset, Ray Data uses
250+
``ray.data.DataContext.max_tasks_in_flight_per_actor`` when that is set globally.
251+
Otherwise, it defaults to ``2 * max_concurrent_batches``; the factor
252+
can be overridden via the
253+
``RAY_DATA_ACTOR_DEFAULT_MAX_TASKS_IN_FLIGHT_TO_MAX_CONCURRENCY_FACTOR``
254+
env var.
237255
chat_template_stage: Chat templating stage config (bool | dict | ChatTemplateStageConfig).
238256
Defaults to True. Use nested config for per-stage control over batch_size,
239257
concurrency, runtime_env, num_cpus, and memory. Legacy ``apply_chat_template``

python/ray/llm/_internal/batch/processor/base.py

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,6 @@
1717
logger = logging.getLogger(__name__)
1818

1919

20-
# Higher values here are better for prefetching and locality. It's ok for this to be
21-
# fairly high since streaming backpressure prevents us from overloading actors.
22-
DEFAULT_MAX_TASKS_IN_FLIGHT = 16
23-
24-
2520
class ProcessorConfig(BaseModelExtended):
2621
"""The processor configuration."""
2722

@@ -55,9 +50,12 @@ class ProcessorConfig(BaseModelExtended):
5550

5651
experimental: Dict[str, Any] = Field(
5752
default_factory=dict,
58-
description="[Experimental] Experimental configurations."
53+
description="[Experimental] Experimental configurations. "
5954
"Supported keys:\n"
60-
"`max_tasks_in_flight_per_actor`: The maximum number of tasks in flight per actor. Default to 16.",
55+
"`max_tasks_in_flight_per_actor`: [DEPRECATED] Prefer the top-level "
56+
"`max_tasks_in_flight_per_actor` field on `OfflineProcessorConfig`. "
57+
"Setting it here is still respected (and overridden by the top-level "
58+
"field if both are set), but logs a deprecation warning.",
6159
)
6260

6361
@field_validator("concurrency")
@@ -156,7 +154,21 @@ class OfflineProcessorConfig(ProcessorConfig):
156154
"This is to overlap the batch processing to avoid the tail latency of "
157155
"each batch. The default value may not be optimal when the batch size "
158156
"or the batch processing latency is too small, but it should be good "
159-
"enough for batch size >= 32.",
157+
"enough for batch size >= 32. Sets the engine actor's Ray Core "
158+
"`max_concurrency`.",
159+
)
160+
max_tasks_in_flight_per_actor: Optional[int] = Field(
161+
default=None,
162+
description="Max tasks Ray Data submits concurrently to each engine "
163+
"actor. Passed through to `ray.data.ActorPoolStrategy`. If unset, Ray "
164+
"Data uses `ray.data.DataContext.max_tasks_in_flight_per_actor` if set "
165+
"globally. Otherwise, it defaults to `2 * max_concurrent_batches`; the "
166+
"factor can be overridden via the "
167+
"`RAY_DATA_ACTOR_DEFAULT_MAX_TASKS_IN_FLIGHT_TO_MAX_CONCURRENCY_FACTOR` "
168+
"env var. "
169+
"Setting this lower than `max_concurrent_batches` can underutilize the "
170+
"engine actor because Ray Data submits fewer tasks than the actor can "
171+
"process concurrently.",
160172
)
161173
should_continue_on_error: bool = Field(
162174
default=False,
@@ -260,6 +272,44 @@ def _coerce_legacy_to_stage_config(cls, values: Dict[str, Any]) -> Dict[str, Any
260272

261273
return values
262274

275+
@model_validator(mode="before")
276+
def _migrate_experimental_max_tasks_in_flight_per_actor(
277+
cls, values: Dict[str, Any]
278+
) -> Dict[str, Any]:
279+
"""Migrate deprecated `experimental[max_tasks_in_flight_per_actor]` to
280+
the top-level field; top-level wins if both are set."""
281+
experimental = values.get("experimental") or {}
282+
if "max_tasks_in_flight_per_actor" in experimental:
283+
logger.warning(
284+
"Setting `max_tasks_in_flight_per_actor` via `experimental` is "
285+
"deprecated; use the top-level `max_tasks_in_flight_per_actor` "
286+
"field on `OfflineProcessorConfig` instead. The value in "
287+
"`experimental` is still respected for now (and overridden by "
288+
"the top-level field if both are set), but will be removed in "
289+
"a future version."
290+
)
291+
if values.get("max_tasks_in_flight_per_actor") is None:
292+
values["max_tasks_in_flight_per_actor"] = experimental[
293+
"max_tasks_in_flight_per_actor"
294+
]
295+
return values
296+
297+
@model_validator(mode="after")
298+
def _warn_if_max_tasks_in_flight_underutilizes_actor(self):
299+
if (
300+
self.max_tasks_in_flight_per_actor is not None
301+
and self.max_tasks_in_flight_per_actor < self.max_concurrent_batches
302+
):
303+
logger.warning(
304+
"Setting `max_tasks_in_flight_per_actor` (%s) lower than "
305+
"`max_concurrent_batches` (%s) can underutilize each engine "
306+
"actor because Ray Data will submit fewer tasks than the actor "
307+
"can process concurrently.",
308+
self.max_tasks_in_flight_per_actor,
309+
self.max_concurrent_batches,
310+
)
311+
return self
312+
263313
@model_validator(mode="before")
264314
def _warn_prepare_image_stage_deprecation(
265315
cls, values: Dict[str, Any]

python/ray/llm/_internal/batch/processor/sglang_engine_proc.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
get_or_create_telemetry_agent,
1616
)
1717
from ray.llm._internal.batch.processor.base import (
18-
DEFAULT_MAX_TASKS_IN_FLIGHT,
1918
OfflineProcessorConfig,
2019
Processor,
2120
ProcessorBuilder,
@@ -184,9 +183,7 @@ def build_sglang_engine_processor(
184183
# saturate `max_concurrency`.
185184
compute=ray.data.ActorPoolStrategy(
186185
**config.get_concurrency(autoscaling_enabled=True),
187-
max_tasks_in_flight_per_actor=config.experimental.get(
188-
"max_tasks_in_flight_per_actor", DEFAULT_MAX_TASKS_IN_FLIGHT
189-
),
186+
max_tasks_in_flight_per_actor=config.max_tasks_in_flight_per_actor,
190187
),
191188
# The number of running batches "per actor" in Ray Core level.
192189
# This is used to make sure we overlap batches to avoid the tail

python/ray/llm/_internal/batch/processor/vllm_engine_proc.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
get_or_create_telemetry_agent,
1616
)
1717
from ray.llm._internal.batch.processor.base import (
18-
DEFAULT_MAX_TASKS_IN_FLIGHT,
1918
OfflineProcessorConfig,
2019
Processor,
2120
ProcessorBuilder,
@@ -284,9 +283,7 @@ def build_vllm_engine_processor(
284283
# saturate `max_concurrency`.
285284
compute=ray.data.ActorPoolStrategy(
286285
**config.get_concurrency(autoscaling_enabled=True),
287-
max_tasks_in_flight_per_actor=config.experimental.get(
288-
"max_tasks_in_flight_per_actor", DEFAULT_MAX_TASKS_IN_FLIGHT
289-
),
286+
max_tasks_in_flight_per_actor=config.max_tasks_in_flight_per_actor,
290287
),
291288
# The number of running batches "per actor" in Ray Core level.
292289
# This is used to make sure we overlap batches to avoid the tail

python/ray/llm/tests/batch/cpu/processor/test_processor_base.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import sys
22
from typing import Any, AsyncIterator, Dict, List, Type
3+
from unittest.mock import patch
34

45
import pydantic
56
import pytest
67

78
import ray
89
from ray.data.llm import build_processor
9-
from ray.llm._internal.batch.processor import vLLMEngineProcessorConfig
10+
from ray.llm._internal.batch.processor import (
11+
base as processor_base,
12+
vLLMEngineProcessorConfig,
13+
)
1014
from ray.llm._internal.batch.processor.base import (
1115
Processor,
1216
ProcessorBuilder,
@@ -386,6 +390,93 @@ def test_with_tuple_concurrency(self, pair, expected):
386390
assert conf.get_concurrency() == expected
387391

388392

393+
class TestOfflineProcessorConfig:
394+
@pytest.mark.parametrize(
395+
"kwargs, expected",
396+
[
397+
({"max_tasks_in_flight_per_actor": 10}, 10),
398+
({}, None),
399+
# Field stays None; the formula runs in Ray Data, not here.
400+
({"max_concurrent_batches": 4}, None),
401+
],
402+
)
403+
def test_max_tasks_in_flight_per_actor_passthrough(self, kwargs, expected):
404+
"""Field passes through to ActorPoolStrategy; None defers resolution."""
405+
config = vLLMEngineProcessorConfig(
406+
model_source="unsloth/Llama-3.2-1B-Instruct",
407+
**kwargs,
408+
)
409+
assert config.max_tasks_in_flight_per_actor == expected
410+
assert config.max_concurrent_batches == kwargs.get("max_concurrent_batches", 8)
411+
412+
def test_experimental_max_tasks_in_flight_per_actor_deprecated(self):
413+
"""Setting `experimental['max_tasks_in_flight_per_actor']` migrates to
414+
the top-level field with a deprecation log; the explicit top-level
415+
field overrides it but the warning still fires."""
416+
417+
def has_deprecation_log(warning_mock):
418+
return any(
419+
"max_tasks_in_flight_per_actor" in call.args[0]
420+
and "deprecated" in call.args[0]
421+
for call in warning_mock.call_args_list
422+
)
423+
424+
# Migration: experimental → top-level field.
425+
with patch.object(processor_base.logger, "warning") as warning_mock:
426+
cfg = vLLMEngineProcessorConfig(
427+
model_source="unsloth/Llama-3.2-1B-Instruct",
428+
experimental={"max_tasks_in_flight_per_actor": 10},
429+
)
430+
assert cfg.max_tasks_in_flight_per_actor == 10
431+
assert has_deprecation_log(warning_mock)
432+
433+
# Explicit top-level beats experimental, but warning still fires.
434+
with patch.object(processor_base.logger, "warning") as warning_mock:
435+
cfg = vLLMEngineProcessorConfig(
436+
model_source="unsloth/Llama-3.2-1B-Instruct",
437+
max_tasks_in_flight_per_actor=20,
438+
experimental={"max_tasks_in_flight_per_actor": 10},
439+
)
440+
assert cfg.max_tasks_in_flight_per_actor == 20
441+
assert has_deprecation_log(warning_mock)
442+
443+
def test_max_tasks_in_flight_under_max_concurrent_batches_warns(self):
444+
with patch.object(processor_base.logger, "warning") as warning_mock:
445+
cfg = vLLMEngineProcessorConfig(
446+
model_source="unsloth/Llama-3.2-1B-Instruct",
447+
max_tasks_in_flight_per_actor=1,
448+
max_concurrent_batches=8,
449+
)
450+
451+
assert cfg.max_tasks_in_flight_per_actor == 1
452+
assert cfg.max_concurrent_batches == 8
453+
warning_messages = [call.args[0] for call in warning_mock.call_args_list]
454+
assert any(
455+
"max_tasks_in_flight_per_actor" in message
456+
and "max_concurrent_batches" in message
457+
and "underutilize" in message
458+
for message in warning_messages
459+
)
460+
461+
@pytest.mark.parametrize(
462+
"kwargs",
463+
[
464+
{},
465+
{"max_tasks_in_flight_per_actor": 8, "max_concurrent_batches": 8},
466+
{"max_tasks_in_flight_per_actor": 16, "max_concurrent_batches": 8},
467+
],
468+
)
469+
def test_max_tasks_in_flight_does_not_warn_when_not_underutilized(self, kwargs):
470+
with patch.object(processor_base.logger, "warning") as warning_mock:
471+
vLLMEngineProcessorConfig(
472+
model_source="unsloth/Llama-3.2-1B-Instruct",
473+
**kwargs,
474+
)
475+
476+
warning_messages = [call.args[0] for call in warning_mock.call_args_list]
477+
assert not any("underutilize" in message for message in warning_messages)
478+
479+
389480
class TestMapKwargs:
390481
"""Tests for preprocess_map_kwargs and postprocess_map_kwargs."""
391482

python/ray/llm/tests/batch/gpu/processor/test_sglang_engine_proc.py

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
"""This test suite does not need sglang to be installed."""
22

33
import sys
4-
from unittest.mock import MagicMock, patch
4+
from unittest.mock import patch
55

66
import pytest
77

88
import ray
99
from ray.data.llm import SGLangEngineProcessorConfig
1010
from ray.llm._internal.batch.constants import SGLangTaskType
1111
from ray.llm._internal.batch.processor import ProcessorBuilder
12-
from ray.llm._internal.batch.processor.base import DEFAULT_MAX_TASKS_IN_FLIGHT
1312
from ray.llm._internal.batch.processor.sglang_engine_proc import (
1413
build_sglang_engine_processor,
1514
)
@@ -76,40 +75,6 @@ def test_sglang_engine_processor(gpu_type, model_llama_3_2_216M):
7675

7776

7877
class TestSGLangEngineProcessorConfig:
79-
@pytest.mark.parametrize(
80-
"experimental_config",
81-
[
82-
{"max_tasks_in_flight_per_actor": 10},
83-
{},
84-
],
85-
)
86-
def test_experimental_max_tasks_in_flight_per_actor_usage(
87-
self, experimental_config
88-
):
89-
"""Tests that max_tasks_in_flight_per_actor is set properly in the ActorPoolStrategy."""
90-
91-
with patch("ray.data.ActorPoolStrategy") as mock_actor_pool:
92-
mock_actor_pool.return_value = MagicMock()
93-
94-
config = SGLangEngineProcessorConfig(
95-
model_source="unsloth/Llama-3.2-1B-Instruct",
96-
experimental=experimental_config,
97-
)
98-
build_sglang_engine_processor(config)
99-
100-
mock_actor_pool.assert_called()
101-
call_kwargs = mock_actor_pool.call_args[1]
102-
if experimental_config:
103-
assert (
104-
call_kwargs["max_tasks_in_flight_per_actor"]
105-
== experimental_config["max_tasks_in_flight_per_actor"]
106-
)
107-
else:
108-
assert (
109-
call_kwargs["max_tasks_in_flight_per_actor"]
110-
== DEFAULT_MAX_TASKS_IN_FLIGHT
111-
)
112-
11378
def test_build_processor_autoconfig_failure_with_trust_remote_code(self):
11479
config = SGLangEngineProcessorConfig(
11580
model_source="nonexistent-org/nonexistent-model",

python/ray/llm/tests/batch/gpu/processor/test_vllm_engine_proc.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import sys
2-
from unittest.mock import MagicMock, patch
32

43
import pydantic
54
import pytest
@@ -692,46 +691,6 @@ def test_build_processor_autoconfig_failure(self):
692691
processor = build_processor(config)
693692
assert processor is not None
694693

695-
@pytest.mark.parametrize(
696-
"experimental_config",
697-
[
698-
{"max_tasks_in_flight_per_actor": 10},
699-
{},
700-
],
701-
)
702-
def test_experimental_max_tasks_in_flight_per_actor_usage(
703-
self, experimental_config
704-
):
705-
"""Tests that max_tasks_in_flight_per_actor is set properly in the ActorPoolStrategy."""
706-
707-
from ray.llm._internal.batch.processor.base import DEFAULT_MAX_TASKS_IN_FLIGHT
708-
from ray.llm._internal.batch.processor.vllm_engine_proc import (
709-
build_vllm_engine_processor,
710-
vLLMEngineProcessorConfig,
711-
)
712-
713-
with patch("ray.data.ActorPoolStrategy") as mock_actor_pool:
714-
mock_actor_pool.return_value = MagicMock()
715-
716-
config = vLLMEngineProcessorConfig(
717-
model_source="unsloth/Llama-3.2-1B-Instruct",
718-
experimental=experimental_config,
719-
)
720-
build_vllm_engine_processor(config)
721-
722-
mock_actor_pool.assert_called()
723-
call_kwargs = mock_actor_pool.call_args[1]
724-
if experimental_config:
725-
assert (
726-
call_kwargs["max_tasks_in_flight_per_actor"]
727-
== experimental_config["max_tasks_in_flight_per_actor"]
728-
)
729-
else:
730-
assert (
731-
call_kwargs["max_tasks_in_flight_per_actor"]
732-
== DEFAULT_MAX_TASKS_IN_FLIGHT
733-
)
734-
735694

736695
if __name__ == "__main__":
737696
sys.exit(pytest.main(["-v", __file__]))

0 commit comments

Comments
 (0)