|
1 | 1 | import sys |
2 | 2 | from typing import Any, AsyncIterator, Dict, List, Type |
| 3 | +from unittest.mock import patch |
3 | 4 |
|
4 | 5 | import pydantic |
5 | 6 | import pytest |
6 | 7 |
|
7 | 8 | import ray |
8 | 9 | from ray.data.llm import build_processor |
9 | | -from ray.llm._internal.batch.processor import vLLMEngineProcessorConfig |
| 10 | +from ray.llm._internal.batch.processor import ( |
| 11 | + base as processor_base, |
| 12 | + vLLMEngineProcessorConfig, |
| 13 | +) |
10 | 14 | from ray.llm._internal.batch.processor.base import ( |
11 | 15 | Processor, |
12 | 16 | ProcessorBuilder, |
@@ -386,6 +390,93 @@ def test_with_tuple_concurrency(self, pair, expected): |
386 | 390 | assert conf.get_concurrency() == expected |
387 | 391 |
|
388 | 392 |
|
class TestOfflineProcessorConfig:
    """Checks how ``vLLMEngineProcessorConfig`` handles
    ``max_tasks_in_flight_per_actor``: plain pass-through, migration from the
    deprecated ``experimental`` entry, and the underutilization warning."""

    @pytest.mark.parametrize(
        "kwargs, expected",
        [
            ({"max_tasks_in_flight_per_actor": 10}, 10),
            ({}, None),
            # Field stays None; the formula runs in Ray Data, not here.
            ({"max_concurrent_batches": 4}, None),
        ],
    )
    def test_max_tasks_in_flight_per_actor_passthrough(self, kwargs, expected):
        """An explicit value is kept as given; an unset field remains None."""
        # 8 is the value this test expects when `max_concurrent_batches` is
        # not supplied by the caller.
        expected_batches = kwargs.get("max_concurrent_batches", 8)

        config = vLLMEngineProcessorConfig(
            model_source="unsloth/Llama-3.2-1B-Instruct",
            **kwargs,
        )

        assert config.max_tasks_in_flight_per_actor == expected
        assert config.max_concurrent_batches == expected_batches

    def test_experimental_max_tasks_in_flight_per_actor_deprecated(self):
        """A value given via `experimental['max_tasks_in_flight_per_actor']`
        lands on the top-level field and logs a deprecation warning. When the
        top-level field is also set it takes precedence, and the warning is
        still logged."""

        def has_deprecation_log(warning_mock):
            # Look only at the first positional argument of each
            # `logger.warning` call, i.e. the message itself.
            for recorded in warning_mock.call_args_list:
                text = recorded.args[0]
                if "max_tasks_in_flight_per_actor" in text and "deprecated" in text:
                    return True
            return False

        # Case 1: only the experimental entry is set -> value is migrated.
        with patch.object(processor_base.logger, "warning") as migrate_mock:
            migrated = vLLMEngineProcessorConfig(
                model_source="unsloth/Llama-3.2-1B-Instruct",
                experimental={"max_tasks_in_flight_per_actor": 10},
            )
        assert migrated.max_tasks_in_flight_per_actor == 10
        assert has_deprecation_log(migrate_mock)

        # Case 2: both are set -> top-level value wins, warning still logged.
        with patch.object(processor_base.logger, "warning") as override_mock:
            overridden = vLLMEngineProcessorConfig(
                model_source="unsloth/Llama-3.2-1B-Instruct",
                max_tasks_in_flight_per_actor=20,
                experimental={"max_tasks_in_flight_per_actor": 10},
            )
        assert overridden.max_tasks_in_flight_per_actor == 20
        assert has_deprecation_log(override_mock)

    def test_max_tasks_in_flight_under_max_concurrent_batches_warns(self):
        """A `max_tasks_in_flight_per_actor` below `max_concurrent_batches`
        is accepted unchanged but produces an "underutilize" warning."""
        with patch.object(processor_base.logger, "warning") as warning_mock:
            cfg = vLLMEngineProcessorConfig(
                model_source="unsloth/Llama-3.2-1B-Instruct",
                max_tasks_in_flight_per_actor=1,
                max_concurrent_batches=8,
            )

        # Neither value is rewritten by the config.
        assert cfg.max_tasks_in_flight_per_actor == 1
        assert cfg.max_concurrent_batches == 8

        required_fragments = (
            "max_tasks_in_flight_per_actor",
            "max_concurrent_batches",
            "underutilize",
        )
        logged = [recorded.args[0] for recorded in warning_mock.call_args_list]
        # At least one single message must mention all three fragments.
        assert any(
            all(fragment in text for fragment in required_fragments)
            for text in logged
        )

    @pytest.mark.parametrize(
        "kwargs",
        [
            {},
            {"max_tasks_in_flight_per_actor": 8, "max_concurrent_batches": 8},
            {"max_tasks_in_flight_per_actor": 16, "max_concurrent_batches": 8},
        ],
    )
    def test_max_tasks_in_flight_does_not_warn_when_not_underutilized(self, kwargs):
        """No "underutilize" warning when the field is unset, equal to, or
        above `max_concurrent_batches`."""
        with patch.object(processor_base.logger, "warning") as warning_mock:
            vLLMEngineProcessorConfig(
                model_source="unsloth/Llama-3.2-1B-Instruct",
                **kwargs,
            )

        logged = [recorded.args[0] for recorded in warning_mock.call_args_list]
        assert all("underutilize" not in text for text in logged)
| 478 | + |
| 479 | + |
389 | 480 | class TestMapKwargs: |
390 | 481 | """Tests for preprocess_map_kwargs and postprocess_map_kwargs.""" |
391 | 482 |
|
|
0 commit comments