From f98d5aff84edbef9f4fbc9d3e7a651b387e820f0 Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Fri, 8 May 2026 02:42:54 +0000 Subject: [PATCH 01/26] Change default bundles constructed for TPU in LLM to per-host and fix tests Signed-off-by: Ryan O'Leary --- .../serve/core/configs/accelerators.py | 71 +++++++-- .../serve/engines/vllm/vllm_models.py | 14 ++ .../tests/serve/cpu/configs/test_models.py | 46 ++++++ .../deployments/llm/test_llm_engine_tpu.py | 139 ++++++++++++------ 4 files changed, 216 insertions(+), 54 deletions(-) diff --git a/python/ray/llm/_internal/serve/core/configs/accelerators.py b/python/ray/llm/_internal/serve/core/configs/accelerators.py index 3ec11d809dd4..35a72db7e82a 100644 --- a/python/ray/llm/_internal/serve/core/configs/accelerators.py +++ b/python/ray/llm/_internal/serve/core/configs/accelerators.py @@ -6,9 +6,14 @@ from typing_extensions import Annotated import ray.util.accelerators.accelerators as accelerators +from ray._private.accelerators.tpu import get_chips_per_host from ray.llm._internal.serve.observability.logging import get_logger from ray.util.placement_group import PlacementGroup, placement_group -from ray.util.tpu import get_tpu_version_from_type, slice_placement_group +from ray.util.tpu import ( + get_num_chips_from_topology, + get_tpu_version_from_type, + slice_placement_group, +) logger = get_logger(__name__) @@ -27,6 +32,21 @@ def format_ray_accelerator_resource(accelerator_type_str: str) -> str: return f"accelerator_type:{accelerator_type_str}" +def get_inferred_tensor_parallel_size(topology: Optional[str]) -> Optional[int]: + """Infers the tensor parallel size from the TPU topology.""" + if not topology: + return None + + try: + return get_num_chips_from_topology(topology) + except ValueError as e: + logger.warning( + f"Failed to infer tensor_parallel_size from topology '{topology}': {e}. " + "Defaulting to None." 
+ ) + return None + + def infer_hardware_kind_from_bundles( placement_group_config: Optional[Dict[str, Any]] ) -> Optional[str]: @@ -180,10 +200,35 @@ def __init__(self, config: TPUConfig): def default_bundles( self, *, num_devices: int, accelerator_type_str: Optional[str] = None ): - bundle = {"TPU": 1} - if accelerator_type_str: - bundle[format_ray_accelerator_resource(accelerator_type_str)] = 0.001 - return [bundle.copy() for _ in range(num_devices)] + if not self._config.topology: + # Fallback to per-chip bundles if no topology is specified + bundle = {"TPU": 1} + if accelerator_type_str: + bundle[format_ray_accelerator_resource(accelerator_type_str)] = 0.001 + return [bundle.copy() for _ in range(num_devices)] + + # Topology is specified, compute per-host bundles + if not accelerator_type_str: + raise ValueError( + "`accelerator_type` must be specified when `topology` is present " + "in order to compute TPU resource requirements." + ) + version = get_tpu_version_from_type(accelerator_type_str) + chips_per_host = get_chips_per_host(self._config.topology, version) + + if num_devices > chips_per_host and num_devices % chips_per_host != 0: + raise ValueError( + f"num_devices ({num_devices}) must be a multiple of " + f"chips_per_host ({chips_per_host}) for TPU topologies." + ) + + num_hosts = max(1, num_devices // chips_per_host) + + tpu_resources = min(num_devices, chips_per_host) + bundle = {"TPU": tpu_resources} + bundle[format_ray_accelerator_resource(accelerator_type_str)] = 0.001 + + return [bundle.copy() for _ in range(num_hosts)] def create_placement_group( self, @@ -254,11 +299,15 @@ def requires_remote_initialization(self) -> bool: return True def get_remote_options(self, accelerator_type_str: str = None): - # TPUs use custom resource strings rather than a native kwarg - options: Dict[str, Any] = {"resources": {"TPU": 0.001}} - + # The PlacementGroupSchedulingStrategy natively handles routing the task to + # the correct hardware. 
We omit TPU resource requests to avoid consuming + # chips that the model engine workers must use. + options: Dict[str, Any] = {"resources": {}} if accelerator_type_str: - options["accelerator_type"] = accelerator_type_str + # Pin the task to the TPU accelerator to avoid scheduling on a CPU bundle. + options["label_selector"] = { + "ray.io/accelerator-type": accelerator_type_str + } return options def shutdown(self): @@ -270,7 +319,3 @@ def shutdown(self): logger.warning(f"Failed to shut down TPU slice PG: {e}") finally: self._slice_pg_wrapper = None - - def __del__(self): - """Ensure placement groups are cleaned up when this backend is garbage collected.""" - self.shutdown() diff --git a/python/ray/llm/_internal/serve/engines/vllm/vllm_models.py b/python/ray/llm/_internal/serve/engines/vllm/vllm_models.py index cfe6432ffdf4..95b06434d743 100644 --- a/python/ray/llm/_internal/serve/engines/vllm/vllm_models.py +++ b/python/ray/llm/_internal/serve/engines/vllm/vllm_models.py @@ -23,6 +23,7 @@ TPUAccelerator, TPUConfig, format_ray_accelerator_resource, + get_inferred_tensor_parallel_size, ) from ray.llm._internal.serve.core.configs.llm_config import ( AcceleratorType, @@ -193,6 +194,19 @@ def from_llm_config(cls, llm_config: LLMConfig) -> "VLLMEngineConfig": mirror_config = llm_config.model_loading_config.model_source all_engine_kwargs = llm_config.engine_kwargs.copy() + + # If tensor_parallel_size is not specified, try to infer it from topology + if "tensor_parallel_size" not in all_engine_kwargs: + if isinstance(llm_config.accelerator_config, TPUConfig): + total_chips = get_inferred_tensor_parallel_size( + llm_config.accelerator_config.topology + ) + if total_chips is not None: + all_engine_kwargs["tensor_parallel_size"] = total_chips + logger.info( + f"Inferred tensor_parallel_size={total_chips} from TPUConfig." 
+ ) + engine_kwargs = {} frontend_kwargs = {} diff --git a/python/ray/llm/tests/serve/cpu/configs/test_models.py b/python/ray/llm/tests/serve/cpu/configs/test_models.py index ff35dca1fe9e..9ab80c947421 100644 --- a/python/ray/llm/tests/serve/cpu/configs/test_models.py +++ b/python/ray/llm/tests/serve/cpu/configs/test_models.py @@ -417,6 +417,52 @@ def test_requires_deferred_placement_group(self): tpu_accel_with_topo = TPUAccelerator(TPUConfig(kind="tpu", topology="4x4")) assert tpu_accel_with_topo.requires_deferred_placement_group is True + @pytest.mark.parametrize( + "topology,num_devices,accelerator_type_str,expected_bundles_count,expected_chips_per_host", + [ + ("1x1", 1, "TPU-V6E", 1, 1), + ("1x1", 1, "TPU-V7X", 1, 1), + ("4x4", 16, "TPU-V6E", 4, 4), + ("2x2x2", 8, "TPU-V5P", 2, 4), + ("2x2", 4, "TPU-V5LITEPOD", 1, 4), + ("2x2x1", 4, "TPU-V4", 1, 4), + ("2x4", 8, "TPU-V6E", 1, 8), + ], + ) + def test_default_bundles_topology( + self, + topology, + num_devices, + accelerator_type_str, + expected_bundles_count, + expected_chips_per_host, + ): + """Test that different topologies return correct per-host bundles.""" + tpu_accel = TPUAccelerator(TPUConfig(kind="tpu", topology=topology)) + bundles = tpu_accel.default_bundles( + num_devices=num_devices, accelerator_type_str=accelerator_type_str + ) + + assert len(bundles) == expected_bundles_count + for bundle in bundles: + assert bundle["TPU"] == expected_chips_per_host + assert f"accelerator_type:{accelerator_type_str}" in bundle + + def test_default_bundles_topology_missing_accelerator_type_raises(self): + """Test that ValueError is raised when topology is present but accelerator type is missing.""" + tpu_accel = TPUAccelerator(TPUConfig(kind="tpu", topology="4x4")) + with pytest.raises( + ValueError, + match="`accelerator_type` must be specified when `topology` is present", + ): + tpu_accel.default_bundles(num_devices=16, accelerator_type_str=None) + + def 
test_default_bundles_topology_non_multiple_num_devices_raises(self): + """Test that ValueError is raised when num_devices is not a multiple of chips_per_host.""" + tpu_accel = TPUAccelerator(TPUConfig(kind="tpu", topology="4x4")) + with pytest.raises(ValueError, match="must be a multiple of chips_per_host"): + tpu_accel.default_bundles(num_devices=6, accelerator_type_str="TPU-V6E") + if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_engine_tpu.py b/python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_engine_tpu.py index 10fe2bfdf14d..ec10690de68b 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_engine_tpu.py +++ b/python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_engine_tpu.py @@ -26,29 +26,33 @@ def test_tpu_slice_placement_group_creation_default_resources(ray_tpu_cluster): llm_config = LLMConfig( model_loading_config=ModelLoadingConfig(model_id="test-tpu-model"), accelerator_type="TPU-V6E", - accelerator_config={"kind": "tpu", "topology": "4x4"}, + accelerator_config=TPUConfig(kind="tpu", topology="4x4"), ) engine_config = llm_config.get_engine_config() - pg = engine_config.get_or_create_pg() - assert isinstance(pg, PlacementGroup) + pg = None + try: + pg = engine_config.get_or_create_pg() - pg_table = placement_group_table(pg) - assert pg_table["strategy"] == "PACK" + assert isinstance(pg, PlacementGroup) - # 4x4 v6e = 16 chips. We default to 1 TPU chip per bundle. - assert len(pg_table["bundles"]) == 16 - for bundle in pg_table["bundles"].values(): - assert "TPU" in bundle - assert bundle["TPU"] == 1 + pg_table = placement_group_table(pg) + assert pg_table["strategy"] == "PACK" - # Let the backend tear down its own resources if it has any - engine_config.accelerator.shutdown() - try: - ray.util.remove_placement_group(pg) - except Exception: - pass # Already cleaned up by the wrapper + # 4x4 v6e = 16 chips. 
We default to 4 TPU chips per bundle (per-host). + assert len(pg_table["bundles"]) == 4 + for bundle in pg_table["bundles"].values(): + assert "TPU" in bundle + assert bundle["TPU"] == 4.0 + finally: + # Let the backend tear down its own resources if it has any + engine_config.accelerator.shutdown() + if pg is not None: + try: + ray.util.remove_placement_group(pg) + except Exception: + pass def test_tpu_slice_placement_group_creation_host_resources(ray_tpu_cluster): @@ -59,32 +63,36 @@ def test_tpu_slice_placement_group_creation_host_resources(ray_tpu_cluster): llm_config = LLMConfig( model_loading_config=ModelLoadingConfig(model_id="test-tpu-model"), accelerator_type="TPU-V6E", - accelerator_config={"kind": "tpu", "topology": "4x4"}, + accelerator_config=TPUConfig(kind="tpu", topology="4x4"), placement_group_config={ "strategy": "STRICT_SPREAD", - "bundles": [{"TPU": 4}], + "bundles": [{"TPU": 4}] * 4, }, ) engine_config = llm_config.get_engine_config() - pg = engine_config.get_or_create_pg() - - assert isinstance(pg, PlacementGroup) - pg_table = placement_group_table(pg) - assert pg_table["strategy"] == "STRICT_SPREAD" - # We should provision 4 host-level bundles instead of the default 16 chip-level bundles. - assert len(pg_table["bundles"]) == 4 - for bundle in pg_table["bundles"].values(): - assert "TPU" in bundle - assert bundle["TPU"] == 4 - - # Let the backend tear down its own resources if it has any - engine_config.accelerator.shutdown() + pg = None try: - ray.util.remove_placement_group(pg) - except Exception: - pass # Already cleaned up by the wrapper + pg = engine_config.get_or_create_pg() + + assert isinstance(pg, PlacementGroup) + + pg_table = placement_group_table(pg) + assert pg_table["strategy"] == "STRICT_SPREAD" + # We should provision 4 host-level bundles instead of the default 16 chip-level bundles. 
+ assert len(pg_table["bundles"]) == 4 + for bundle in pg_table["bundles"].values(): + assert "TPU" in bundle + assert bundle["TPU"] == 4 + finally: + # Let the backend tear down its own resources if it has any + engine_config.accelerator.shutdown() + if pg is not None: + try: + ray.util.remove_placement_group(pg) + except Exception: + pass def test_single_tpu_fallback(ray_tpu_cluster): @@ -221,15 +229,17 @@ def test_tpu_slice_placement_group_creation_cpu_driver_homogeneous_tpu_bundles_p pass -def test_tpu_serve_deployment_default_chip_level_bundles(ray_tpu_cluster): +def test_tpu_serve_deployment_default_host_level_bundles(ray_tpu_cluster): """ Verifies that a Serve deployment created for a multi-host TPU slice defaults - to chip-level bundles when no placement_group_config is specified. + to host-level bundles when no placement_group_config is specified. """ + from ray.llm._internal.serve.core.configs.accelerators import TPUConfig + llm_config = LLMConfig( model_loading_config=ModelLoadingConfig(model_id="test-tpu-model"), accelerator_type="TPU-V6E", - accelerator_config={"kind": "tpu", "topology": "4x4"}, + accelerator_config=TPUConfig(kind="tpu", topology="4x4"), ) app = serve.deployment(LLMServer).bind(llm_config, engine_cls=PGCreationMockEngine) @@ -256,10 +266,10 @@ def test_tpu_serve_deployment_default_chip_level_bundles(ray_tpu_cluster): worker_pg = [pg for pg in active_pgs if pg not in head_pgs][0] assert worker_pg["strategy"] == "PACK" - # 4x4 topology = 16 chips. Default is 16 bundles of 1 TPU. - assert len(worker_pg["bundles"]) == 16 + # 4x4 topology = 16 chips. Default is 4 bundles of 4 TPUs (per-host). 
+ assert len(worker_pg["bundles"]) == 4 for bundle in worker_pg["bundles"].values(): - assert bundle.get("TPU", 0) == 1 + assert bundle.get("TPU", 0) == 4.0 serve.shutdown() @@ -272,7 +282,7 @@ def test_tpu_serve_deployment_explicit_host_level_bundles(ray_tpu_cluster): llm_config = LLMConfig( model_loading_config=ModelLoadingConfig(model_id="test-tpu-model"), accelerator_type="TPU-V6E", - accelerator_config={"kind": "tpu", "topology": "4x4"}, + accelerator_config=TPUConfig(kind="tpu", topology="4x4"), placement_group_config={"bundle_per_worker": {"TPU": 4}}, ) @@ -308,5 +318,52 @@ def test_tpu_serve_deployment_explicit_host_level_bundles(ray_tpu_cluster): serve.shutdown() +def test_tpu_serve_deployment_explicit_per_chip_bundles(ray_tpu_cluster): + """ + Verifies that a user can explicitly request chip-level bundles (1 TPU per bundle) + for a full multi-host TPU slice via placement_group_config. + """ + from ray.llm._internal.serve.core.configs.accelerators import TPUConfig + + llm_config = LLMConfig( + model_loading_config=ModelLoadingConfig(model_id="test-tpu-model"), + accelerator_type="TPU-V6E", + accelerator_config=TPUConfig(kind="tpu", topology="4x4"), + placement_group_config={"bundle_per_worker": {"TPU": 1}}, + engine_kwargs={"tensor_parallel_size": 16}, + ) + + app = serve.deployment(LLMServer).bind(llm_config, engine_cls=PGCreationMockEngine) + serve.run(app) + + pg_table = ray.util.placement_group_table() + active_pgs = list( + {k: v for k, v in pg_table.items() if v["state"] == "CREATED"}.values() + ) + + assert ( + len(active_pgs) == 2 + ), "Expected 2 PGs - one for TPU Head, one for worker bundles" + + tpu_head_resource = "TPU-v6e-16-head" + head_pgs = [ + pg + for pg in active_pgs + if len(pg["bundles"]) == 1 + and tpu_head_resource in list(pg["bundles"].values())[0] + ] + assert len(head_pgs) == 1 + + worker_pg = [pg for pg in active_pgs if pg not in head_pgs][0] + + assert worker_pg["strategy"] == "PACK" + # 4x4 topology = 16 chips. 
Explicitly requested 16 bundles of 1 TPU. + assert len(worker_pg["bundles"]) == 16 + for bundle in worker_pg["bundles"].values(): + assert bundle.get("TPU", 0) == 1.0 + + serve.shutdown() + + if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) From 6a1511d507ed206cf6ea0e8a7caae3ff3ddb9b8b Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Wed, 6 May 2026 20:20:37 +0000 Subject: [PATCH 02/26] Improve lifecycle handling of SlicePlacementGroup and support explicit `bundle_label_selector` Signed-off-by: Ryan O'Leary --- python/ray/tests/test_tpu.py | 109 +++++++++++++++++++++++++++++++++++ python/ray/util/tpu.py | 97 +++++++++++++++++++++++++------ 2 files changed, 188 insertions(+), 18 deletions(-) diff --git a/python/ray/tests/test_tpu.py b/python/ray/tests/test_tpu.py index 3400cbbb6255..2a7e9b7bddeb 100644 --- a/python/ray/tests/test_tpu.py +++ b/python/ray/tests/test_tpu.py @@ -839,5 +839,114 @@ def test_slice_placement_group_chips_per_vm_override(ray_v6e_tpu_cluster): assert override_pg.bundle_resources["TPU"] == 4 +def test_user_bundle_label_selector_merged(ray_tpu_cluster): + """Verifies that user-passed bundle_label_selector is merged with dynamic TPU labels.""" + user_selectors = [{"env": "prod"}, {"env": "test"}] + + # 2x2x2 v4 = 2 hosts = 2 bundles + slice_pg = SlicePlacementGroup( + topology="2x2x2", accelerator_version="v4", bundle_label_selector=user_selectors + ) + + assert len(slice_pg._bundle_label_selector) == 2 + + # Verify slice 0 + assert slice_pg._bundle_label_selector[0]["env"] == "prod" + assert ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY in slice_pg._bundle_label_selector[0] + + # Verify slice 1 + assert slice_pg._bundle_label_selector[1]["env"] == "test" + assert ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY in slice_pg._bundle_label_selector[1] + + +def test_user_bundle_label_selector_collision_dynamic_wins(ray_v6e_tpu_cluster): + """Verifies that dynamic TPU labels take precedence on collision.""" + user_selectors = 
[{ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY: "user-requested-slice"}] + + # v6e-8 is single host (1 bundle) + slice_pg = SlicePlacementGroup( + topology="2x4", accelerator_version="v6e", bundle_label_selector=user_selectors + ) + + assert len(slice_pg._bundle_label_selector) == 1 + # The dynamic value should win (it generates test-v6e-slice-N) + actual_val = slice_pg._bundle_label_selector[0][ + ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY + ] + assert actual_val != "user-requested-slice" + assert "test-v6e-slice-" in actual_val + + +def test_user_bundle_label_selector_length_mismatch_raises(): + """Verifies that providing wrong length of selector list raises ValueError.""" + user_selectors = [{"env": "prod"}] # Only 1 provided but 2x2x2 v4 has 2 hosts + + with pytest.raises(ValueError, match="bundle_label_selector length"): + SlicePlacementGroup( + topology="2x2x2", + accelerator_version="v4", + bundle_label_selector=user_selectors, + ) + + +def test_release_head_pgs_idempotent(ray_tpu_cluster): + """Verifies that release_head_pgs() is idempotent.""" + slice_pg = SlicePlacementGroup(topology="2x2x2", accelerator_version="v4") + + assert len(slice_pg.head_placement_groups) == 1 + + slice_pg.release_head_pgs() + assert len(slice_pg.head_placement_groups) == 0 + + # Call again, should not raise + slice_pg.release_head_pgs() + assert len(slice_pg.head_placement_groups) == 0 + + +def test_shutdown_idempotent(ray_tpu_cluster): + """Verifies that shutdown() is idempotent.""" + slice_pg = SlicePlacementGroup(topology="2x2x2", accelerator_version="v4") + + slice_pg.shutdown() + assert slice_pg.placement_group is None + assert len(slice_pg.head_placement_groups) == 0 + + # Call again, should not raise + slice_pg.shutdown() + + +def test_shutdown_safe_after_construction_failure(): + """Verifies that shutdown() is safe to call on a partially-constructed instance.""" + with patch( + "ray.util.tpu.SlicePlacementGroup._reserve_slice", + side_effect=RuntimeError("Test failure"), + ): + 
with pytest.raises(RuntimeError, match="Test failure"): + SlicePlacementGroup(topology="2x2x2", accelerator_version="v4") + + # If the above didn't crash or leak resources, we are good. + # We can also manually construct a partial instance and call shutdown. + partial_pg = SlicePlacementGroup.__new__(SlicePlacementGroup) + partial_pg._head_pgs = [] + partial_pg._placement_group = None + + # Should not raise even though it's missing attributes + partial_pg.shutdown() + + +def test_release_head_pgs_after_ready_then_shutdown(ray_tpu_cluster): + """Validates Slice PG lifecycle: wait until ready, release head PGs, then shutdown.""" + slice_pg = SlicePlacementGroup(topology="2x2x2", accelerator_version="v4") + + # Wait for ready + ray.get(slice_pg.placement_group.ready()) + + slice_pg.release_head_pgs() + assert len(slice_pg.head_placement_groups) == 0 + + slice_pg.shutdown() + assert slice_pg.placement_group is None + + if __name__ == "__main__": sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/util/tpu.py b/python/ray/util/tpu.py index c311771b091f..615d4461e0f6 100644 --- a/python/ray/util/tpu.py +++ b/python/ray/util/tpu.py @@ -160,8 +160,10 @@ def get_tpu_worker_resources( """ accelerator_version = get_tpu_version_from_type(accelerator_type) - resolved_chips_per_vm = chips_per_vm or get_chips_per_host( - topology, accelerator_version + resolved_chips_per_vm = ( + chips_per_vm + if chips_per_vm is not None + else get_chips_per_host(topology, accelerator_version) ) total_chips_per_slice = get_num_chips_from_topology(topology) @@ -447,6 +449,8 @@ class SlicePlacementGroup: TPU head placement group to become ready. Defaults to ``DEFAULT_TPU_HEAD_RESERVATION_TIMEOUT_S``. Pass ``None`` to wait indefinitely. + bundle_label_selector: Optional list of label selectors to apply per bundle. These label + selectors are applied in addition to dynamic TPU slice name labels, which take precedence. 
Examples: @@ -490,7 +494,13 @@ def __init__( head_reservation_timeout_s: Optional[float] = ( DEFAULT_TPU_HEAD_RESERVATION_TIMEOUT_S ), + bundle_label_selector: Optional[List[Dict[str, str]]] = None, ): + self._head_pgs: List[PlacementGroup] = [] + self._bundle_label_selector: List[Dict[str, str]] = [] + self._placement_group: Optional[PlacementGroup] = None + self._user_bundle_label_selector = bundle_label_selector or [] + self._topology = topology.strip().lower() self._accelerator_version = get_tpu_version_from_type( accelerator_version.strip() @@ -508,8 +518,10 @@ def __init__( chips_per_vm=chips_per_vm, ) - self._chips_per_host = chips_per_vm or get_chips_per_host( - self._topology, self._accelerator_version + self._chips_per_host = ( + chips_per_vm + if chips_per_vm is not None + else get_chips_per_host(self._topology, self._accelerator_version) ) # Within Ray, a "host" corresponds to a user-visible compute VM. @@ -518,10 +530,7 @@ def __init__( hosts_per_slice = max(1, total_chips // self._chips_per_host) self._num_hosts = hosts_per_slice * self._num_slices - self._head_pgs: List[PlacementGroup] = [] - self._bundle_label_selector: List[Dict[str, str]] = [] self._validate_tpu_config() - self._placement_group = None # Reserve a TPU slice of the provided accelerator version and topology. self._placement_group = self._reserve_slice( @@ -549,6 +558,15 @@ def _reserve_slice( lifetime: Optional[str] = None, ) -> PlacementGroup: """Performs the two-step scheduling to reserve a TPU slice.""" + if ( + self._user_bundle_label_selector + and len(self._user_bundle_label_selector) != self._num_bundles + ): + raise ValueError( + f"bundle_label_selector length ({len(self._user_bundle_label_selector)}) must " + f"match the number of bundles ({self._num_bundles})." 
+ ) + self._bundle_label_selector = [] bundles = [] bundles_per_slice = self._num_bundles // self._num_slices @@ -557,7 +575,7 @@ def _reserve_slice( accelerator_type = "TPU-" + self.accelerator_version.upper() try: - for _ in range(self.num_slices): + for slice_idx in range(self.num_slices): reservation = reserve_tpu_slice( self._topology, accelerator_type, @@ -575,10 +593,20 @@ def _reserve_slice( slice_name, head_pg = reservation self._head_pgs.append(head_pg) - # Reserving a slice is done through constructing num_hosts bundles, each with a label selector for - # the unique name of an available TPU slice. - selector = {ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY: slice_name} - self._bundle_label_selector.extend([selector] * bundles_per_slice) + dynamic_labels = {ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY: slice_name} + + for bundle_idx in range(bundles_per_slice): + global_bundle_idx = slice_idx * bundles_per_slice + bundle_idx + + user_labels = ( + self._user_bundle_label_selector[global_bundle_idx] + if global_bundle_idx < len(self._user_bundle_label_selector) + else {} + ) + # Dynamic TPU slice labels take precedence; user labels fill in the rest. + merged_labels = {**user_labels, **dynamic_labels} + self._bundle_label_selector.append(merged_labels) + bundles += [ self._bundle_resources.copy() for _ in range(bundles_per_slice) ] @@ -647,14 +675,47 @@ def bundle_resources(self) -> Dict[str, float]: """The resources that are assigned to each bundle.""" return self._bundle_resources + @DeveloperAPI(stability="alpha") + def release_head_pgs(self) -> None: + """Remove all internal head placement groups. + + The head PGs exist only to atomically claim a TPU slice's label during + the race window between slice selection and worker-PG construction. + Once the worker PG's bundles are scheduled, the worker PG holds the TPU + resources on every host in the slice and the head PGs are redundant. 
+ + Callers should invoke this idempotent call after `self.placement_group.ready()` + resolves successfully. + """ + head_pgs = getattr(self, "_head_pgs", []) + self._head_pgs = [] + for head_pg in head_pgs: + try: + remove_placement_group(head_pg) + except Exception: + logger.exception( + "Failed to remove TPU head placement group %s; the " + "slice reservation marker may leak until the creator " + "process exits.", + getattr(head_pg, "id", head_pg), + ) + def shutdown(self): - """Removes the worker placement group and all internal head PGs.""" - if self._placement_group: - remove_placement_group(self._placement_group) + """Remove the worker placement group and all internal head PGs. + + Idempotent. Safe to call on a partially-constructed instance. + """ + worker_pg = getattr(self, "_placement_group", None) + if worker_pg is not None: self._placement_group = None - for head_pg in self._head_pgs: - remove_placement_group(head_pg) - self._head_pgs = [] + try: + remove_placement_group(worker_pg) + except Exception: + logger.exception( + "Failed to remove TPU worker placement group %s.", + getattr(worker_pg, "id", worker_pg), + ) + self.release_head_pgs() @PublicAPI(stability="alpha") From 5ec15c00b319eb6ed67d760b7bbe281129db36b5 Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Fri, 8 May 2026 02:41:17 +0000 Subject: [PATCH 03/26] Add AcceleratorConfig to Serve and fix gang scheduling Signed-off-by: Ryan O'Leary --- .../serve/core/configs/accelerators.py | 71 ++------- .../serve/engines/vllm/vllm_models.py | 14 -- .../tests/serve/cpu/configs/test_models.py | 46 ------ .../deployments/llm/test_llm_engine_tpu.py | 139 ++++++------------ python/ray/serve/_private/common.py | 12 +- python/ray/serve/_private/config.py | 12 ++ python/ray/serve/_private/default_impl.py | 96 +++++++++++- .../serve/_private/deployment_scheduler.py | 53 ++++++- python/ray/serve/_private/deployment_state.py | 102 +++++++++++-- python/ray/serve/api.py | 26 ++++ python/ray/serve/config.py | 51 
++++++- python/ray/serve/tests/BUILD.bazel | 1 + .../serve/tests/test_accelerator_config.py | 107 ++++++++++++++ python/ray/serve/tests/unit/BUILD.bazel | 1 + .../tests/unit/test_accelerator_config.py | 119 +++++++++++++++ .../serve/tests/unit/test_deployment_state.py | 68 +++++++++ src/ray/protobuf/serve.proto | 3 + 17 files changed, 676 insertions(+), 245 deletions(-) create mode 100644 python/ray/serve/tests/test_accelerator_config.py create mode 100644 python/ray/serve/tests/unit/test_accelerator_config.py diff --git a/python/ray/llm/_internal/serve/core/configs/accelerators.py b/python/ray/llm/_internal/serve/core/configs/accelerators.py index 35a72db7e82a..3ec11d809dd4 100644 --- a/python/ray/llm/_internal/serve/core/configs/accelerators.py +++ b/python/ray/llm/_internal/serve/core/configs/accelerators.py @@ -6,14 +6,9 @@ from typing_extensions import Annotated import ray.util.accelerators.accelerators as accelerators -from ray._private.accelerators.tpu import get_chips_per_host from ray.llm._internal.serve.observability.logging import get_logger from ray.util.placement_group import PlacementGroup, placement_group -from ray.util.tpu import ( - get_num_chips_from_topology, - get_tpu_version_from_type, - slice_placement_group, -) +from ray.util.tpu import get_tpu_version_from_type, slice_placement_group logger = get_logger(__name__) @@ -32,21 +27,6 @@ def format_ray_accelerator_resource(accelerator_type_str: str) -> str: return f"accelerator_type:{accelerator_type_str}" -def get_inferred_tensor_parallel_size(topology: Optional[str]) -> Optional[int]: - """Infers the tensor parallel size from the TPU topology.""" - if not topology: - return None - - try: - return get_num_chips_from_topology(topology) - except ValueError as e: - logger.warning( - f"Failed to infer tensor_parallel_size from topology '{topology}': {e}. " - "Defaulting to None." 
- ) - return None - - def infer_hardware_kind_from_bundles( placement_group_config: Optional[Dict[str, Any]] ) -> Optional[str]: @@ -200,35 +180,10 @@ def __init__(self, config: TPUConfig): def default_bundles( self, *, num_devices: int, accelerator_type_str: Optional[str] = None ): - if not self._config.topology: - # Fallback to per-chip bundles if no topology is specified - bundle = {"TPU": 1} - if accelerator_type_str: - bundle[format_ray_accelerator_resource(accelerator_type_str)] = 0.001 - return [bundle.copy() for _ in range(num_devices)] - - # Topology is specified, compute per-host bundles - if not accelerator_type_str: - raise ValueError( - "`accelerator_type` must be specified when `topology` is present " - "in order to compute TPU resource requirements." - ) - version = get_tpu_version_from_type(accelerator_type_str) - chips_per_host = get_chips_per_host(self._config.topology, version) - - if num_devices > chips_per_host and num_devices % chips_per_host != 0: - raise ValueError( - f"num_devices ({num_devices}) must be a multiple of " - f"chips_per_host ({chips_per_host}) for TPU topologies." - ) - - num_hosts = max(1, num_devices // chips_per_host) - - tpu_resources = min(num_devices, chips_per_host) - bundle = {"TPU": tpu_resources} - bundle[format_ray_accelerator_resource(accelerator_type_str)] = 0.001 - - return [bundle.copy() for _ in range(num_hosts)] + bundle = {"TPU": 1} + if accelerator_type_str: + bundle[format_ray_accelerator_resource(accelerator_type_str)] = 0.001 + return [bundle.copy() for _ in range(num_devices)] def create_placement_group( self, @@ -299,15 +254,11 @@ def requires_remote_initialization(self) -> bool: return True def get_remote_options(self, accelerator_type_str: str = None): - # The PlacementGroupSchedulingStrategy natively handles routing the task to - # the correct hardware. We omit TPU resource requests to avoid consuming - # chips that the model engine workers must use. 
- options: Dict[str, Any] = {"resources": {}} + # TPUs use custom resource strings rather than a native kwarg + options: Dict[str, Any] = {"resources": {"TPU": 0.001}} + if accelerator_type_str: - # Pin the task to the TPU accelerator to avoid scheduling on a CPU bundle. - options["label_selector"] = { - "ray.io/accelerator-type": accelerator_type_str - } + options["accelerator_type"] = accelerator_type_str return options def shutdown(self): @@ -319,3 +270,7 @@ def shutdown(self): logger.warning(f"Failed to shut down TPU slice PG: {e}") finally: self._slice_pg_wrapper = None + + def __del__(self): + """Ensure placement groups are cleaned up when this backend is garbage collected.""" + self.shutdown() diff --git a/python/ray/llm/_internal/serve/engines/vllm/vllm_models.py b/python/ray/llm/_internal/serve/engines/vllm/vllm_models.py index 95b06434d743..cfe6432ffdf4 100644 --- a/python/ray/llm/_internal/serve/engines/vllm/vllm_models.py +++ b/python/ray/llm/_internal/serve/engines/vllm/vllm_models.py @@ -23,7 +23,6 @@ TPUAccelerator, TPUConfig, format_ray_accelerator_resource, - get_inferred_tensor_parallel_size, ) from ray.llm._internal.serve.core.configs.llm_config import ( AcceleratorType, @@ -194,19 +193,6 @@ def from_llm_config(cls, llm_config: LLMConfig) -> "VLLMEngineConfig": mirror_config = llm_config.model_loading_config.model_source all_engine_kwargs = llm_config.engine_kwargs.copy() - - # If tensor_parallel_size is not specified, try to infer it from topology - if "tensor_parallel_size" not in all_engine_kwargs: - if isinstance(llm_config.accelerator_config, TPUConfig): - total_chips = get_inferred_tensor_parallel_size( - llm_config.accelerator_config.topology - ) - if total_chips is not None: - all_engine_kwargs["tensor_parallel_size"] = total_chips - logger.info( - f"Inferred tensor_parallel_size={total_chips} from TPUConfig." 
- ) - engine_kwargs = {} frontend_kwargs = {} diff --git a/python/ray/llm/tests/serve/cpu/configs/test_models.py b/python/ray/llm/tests/serve/cpu/configs/test_models.py index 9ab80c947421..ff35dca1fe9e 100644 --- a/python/ray/llm/tests/serve/cpu/configs/test_models.py +++ b/python/ray/llm/tests/serve/cpu/configs/test_models.py @@ -417,52 +417,6 @@ def test_requires_deferred_placement_group(self): tpu_accel_with_topo = TPUAccelerator(TPUConfig(kind="tpu", topology="4x4")) assert tpu_accel_with_topo.requires_deferred_placement_group is True - @pytest.mark.parametrize( - "topology,num_devices,accelerator_type_str,expected_bundles_count,expected_chips_per_host", - [ - ("1x1", 1, "TPU-V6E", 1, 1), - ("1x1", 1, "TPU-V7X", 1, 1), - ("4x4", 16, "TPU-V6E", 4, 4), - ("2x2x2", 8, "TPU-V5P", 2, 4), - ("2x2", 4, "TPU-V5LITEPOD", 1, 4), - ("2x2x1", 4, "TPU-V4", 1, 4), - ("2x4", 8, "TPU-V6E", 1, 8), - ], - ) - def test_default_bundles_topology( - self, - topology, - num_devices, - accelerator_type_str, - expected_bundles_count, - expected_chips_per_host, - ): - """Test that different topologies return correct per-host bundles.""" - tpu_accel = TPUAccelerator(TPUConfig(kind="tpu", topology=topology)) - bundles = tpu_accel.default_bundles( - num_devices=num_devices, accelerator_type_str=accelerator_type_str - ) - - assert len(bundles) == expected_bundles_count - for bundle in bundles: - assert bundle["TPU"] == expected_chips_per_host - assert f"accelerator_type:{accelerator_type_str}" in bundle - - def test_default_bundles_topology_missing_accelerator_type_raises(self): - """Test that ValueError is raised when topology is present but accelerator type is missing.""" - tpu_accel = TPUAccelerator(TPUConfig(kind="tpu", topology="4x4")) - with pytest.raises( - ValueError, - match="`accelerator_type` must be specified when `topology` is present", - ): - tpu_accel.default_bundles(num_devices=16, accelerator_type_str=None) - - def 
test_default_bundles_topology_non_multiple_num_devices_raises(self): - """Test that ValueError is raised when num_devices is not a multiple of chips_per_host.""" - tpu_accel = TPUAccelerator(TPUConfig(kind="tpu", topology="4x4")) - with pytest.raises(ValueError, match="must be a multiple of chips_per_host"): - tpu_accel.default_bundles(num_devices=6, accelerator_type_str="TPU-V6E") - if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_engine_tpu.py b/python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_engine_tpu.py index ec10690de68b..10fe2bfdf14d 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_engine_tpu.py +++ b/python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_engine_tpu.py @@ -26,33 +26,29 @@ def test_tpu_slice_placement_group_creation_default_resources(ray_tpu_cluster): llm_config = LLMConfig( model_loading_config=ModelLoadingConfig(model_id="test-tpu-model"), accelerator_type="TPU-V6E", - accelerator_config=TPUConfig(kind="tpu", topology="4x4"), + accelerator_config={"kind": "tpu", "topology": "4x4"}, ) engine_config = llm_config.get_engine_config() + pg = engine_config.get_or_create_pg() - pg = None - try: - pg = engine_config.get_or_create_pg() + assert isinstance(pg, PlacementGroup) - assert isinstance(pg, PlacementGroup) + pg_table = placement_group_table(pg) + assert pg_table["strategy"] == "PACK" - pg_table = placement_group_table(pg) - assert pg_table["strategy"] == "PACK" + # 4x4 v6e = 16 chips. We default to 1 TPU chip per bundle. + assert len(pg_table["bundles"]) == 16 + for bundle in pg_table["bundles"].values(): + assert "TPU" in bundle + assert bundle["TPU"] == 1 - # 4x4 v6e = 16 chips. We default to 4 TPU chips per bundle (per-host). 
- assert len(pg_table["bundles"]) == 4 - for bundle in pg_table["bundles"].values(): - assert "TPU" in bundle - assert bundle["TPU"] == 4.0 - finally: - # Let the backend tear down its own resources if it has any - engine_config.accelerator.shutdown() - if pg is not None: - try: - ray.util.remove_placement_group(pg) - except Exception: - pass + # Let the backend tear down its own resources if it has any + engine_config.accelerator.shutdown() + try: + ray.util.remove_placement_group(pg) + except Exception: + pass # Already cleaned up by the wrapper def test_tpu_slice_placement_group_creation_host_resources(ray_tpu_cluster): @@ -63,36 +59,32 @@ def test_tpu_slice_placement_group_creation_host_resources(ray_tpu_cluster): llm_config = LLMConfig( model_loading_config=ModelLoadingConfig(model_id="test-tpu-model"), accelerator_type="TPU-V6E", - accelerator_config=TPUConfig(kind="tpu", topology="4x4"), + accelerator_config={"kind": "tpu", "topology": "4x4"}, placement_group_config={ "strategy": "STRICT_SPREAD", - "bundles": [{"TPU": 4}] * 4, + "bundles": [{"TPU": 4}], }, ) engine_config = llm_config.get_engine_config() + pg = engine_config.get_or_create_pg() + + assert isinstance(pg, PlacementGroup) - pg = None + pg_table = placement_group_table(pg) + assert pg_table["strategy"] == "STRICT_SPREAD" + # We should provision 4 host-level bundles instead of the default 16 chip-level bundles. + assert len(pg_table["bundles"]) == 4 + for bundle in pg_table["bundles"].values(): + assert "TPU" in bundle + assert bundle["TPU"] == 4 + + # Let the backend tear down its own resources if it has any + engine_config.accelerator.shutdown() try: - pg = engine_config.get_or_create_pg() - - assert isinstance(pg, PlacementGroup) - - pg_table = placement_group_table(pg) - assert pg_table["strategy"] == "STRICT_SPREAD" - # We should provision 4 host-level bundles instead of the default 16 chip-level bundles. 
- assert len(pg_table["bundles"]) == 4 - for bundle in pg_table["bundles"].values(): - assert "TPU" in bundle - assert bundle["TPU"] == 4 - finally: - # Let the backend tear down its own resources if it has any - engine_config.accelerator.shutdown() - if pg is not None: - try: - ray.util.remove_placement_group(pg) - except Exception: - pass + ray.util.remove_placement_group(pg) + except Exception: + pass # Already cleaned up by the wrapper def test_single_tpu_fallback(ray_tpu_cluster): @@ -229,17 +221,15 @@ def test_tpu_slice_placement_group_creation_cpu_driver_homogeneous_tpu_bundles_p pass -def test_tpu_serve_deployment_default_host_level_bundles(ray_tpu_cluster): +def test_tpu_serve_deployment_default_chip_level_bundles(ray_tpu_cluster): """ Verifies that a Serve deployment created for a multi-host TPU slice defaults - to host-level bundles when no placement_group_config is specified. + to chip-level bundles when no placement_group_config is specified. """ - from ray.llm._internal.serve.core.configs.accelerators import TPUConfig - llm_config = LLMConfig( model_loading_config=ModelLoadingConfig(model_id="test-tpu-model"), accelerator_type="TPU-V6E", - accelerator_config=TPUConfig(kind="tpu", topology="4x4"), + accelerator_config={"kind": "tpu", "topology": "4x4"}, ) app = serve.deployment(LLMServer).bind(llm_config, engine_cls=PGCreationMockEngine) @@ -266,10 +256,10 @@ def test_tpu_serve_deployment_default_host_level_bundles(ray_tpu_cluster): worker_pg = [pg for pg in active_pgs if pg not in head_pgs][0] assert worker_pg["strategy"] == "PACK" - # 4x4 topology = 16 chips. Default is 4 bundles of 4 TPUs (per-host). - assert len(worker_pg["bundles"]) == 4 + # 4x4 topology = 16 chips. Default is 16 bundles of 1 TPU. 
+ assert len(worker_pg["bundles"]) == 16 for bundle in worker_pg["bundles"].values(): - assert bundle.get("TPU", 0) == 4.0 + assert bundle.get("TPU", 0) == 1 serve.shutdown() @@ -282,7 +272,7 @@ def test_tpu_serve_deployment_explicit_host_level_bundles(ray_tpu_cluster): llm_config = LLMConfig( model_loading_config=ModelLoadingConfig(model_id="test-tpu-model"), accelerator_type="TPU-V6E", - accelerator_config=TPUConfig(kind="tpu", topology="4x4"), + accelerator_config={"kind": "tpu", "topology": "4x4"}, placement_group_config={"bundle_per_worker": {"TPU": 4}}, ) @@ -318,52 +308,5 @@ def test_tpu_serve_deployment_explicit_host_level_bundles(ray_tpu_cluster): serve.shutdown() -def test_tpu_serve_deployment_explicit_per_chip_bundles(ray_tpu_cluster): - """ - Verifies that a user can explicitly request chip-level bundles (1 TPU per bundle) - for a full multi-host TPU slice via placement_group_config. - """ - from ray.llm._internal.serve.core.configs.accelerators import TPUConfig - - llm_config = LLMConfig( - model_loading_config=ModelLoadingConfig(model_id="test-tpu-model"), - accelerator_type="TPU-V6E", - accelerator_config=TPUConfig(kind="tpu", topology="4x4"), - placement_group_config={"bundle_per_worker": {"TPU": 1}}, - engine_kwargs={"tensor_parallel_size": 16}, - ) - - app = serve.deployment(LLMServer).bind(llm_config, engine_cls=PGCreationMockEngine) - serve.run(app) - - pg_table = ray.util.placement_group_table() - active_pgs = list( - {k: v for k, v in pg_table.items() if v["state"] == "CREATED"}.values() - ) - - assert ( - len(active_pgs) == 2 - ), "Expected 2 PGs - one for TPU Head, one for worker bundles" - - tpu_head_resource = "TPU-v6e-16-head" - head_pgs = [ - pg - for pg in active_pgs - if len(pg["bundles"]) == 1 - and tpu_head_resource in list(pg["bundles"].values())[0] - ] - assert len(head_pgs) == 1 - - worker_pg = [pg for pg in active_pgs if pg not in head_pgs][0] - - assert worker_pg["strategy"] == "PACK" - # 4x4 topology = 16 chips. 
Explicitly requested 16 bundles of 1 TPU. - assert len(worker_pg["bundles"]) == 16 - for bundle in worker_pg["bundles"].values(): - assert bundle.get("TPU", 0) == 1.0 - - serve.shutdown() - - if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/serve/_private/common.py b/python/ray/serve/_private/common.py index c095a3ee44e5..324c0d8bb7dc 100644 --- a/python/ray/serve/_private/common.py +++ b/python/ray/serve/_private/common.py @@ -1,7 +1,7 @@ import json from dataclasses import asdict, dataclass, field from enum import Enum -from typing import Any, Awaitable, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union from starlette.types import Scope @@ -18,6 +18,10 @@ from ray.util.annotations import PublicAPI from ray.util.placement_group import PlacementGroup +if TYPE_CHECKING: + from ray.serve._private.default_impl import _ReplicaPlacementGroup + from ray.serve.config import AcceleratorConfig + REPLICA_ID_FULL_ID_STR_PREFIX = "SERVE_REPLICA::" GANG_PG_NAME_PREFIX = "SERVE_GANG::" @@ -897,6 +901,7 @@ class CreatePlacementGroupRequest: runtime_env: Optional[str] = None bundle_label_selector: Optional[List[Dict[str, str]]] = None fallback_strategy: Optional[List[Dict[str, Any]]] = None + accelerator_config: Optional["AcceleratorConfig"] = None @dataclass @@ -924,6 +929,9 @@ class GangPlacementGroupRequest: replica_pg_fallback_strategy: Optional[List[Dict[str, Any]]] = None """Fallback strategy for per-replica placement group bundles.""" + accelerator_config: Optional["AcceleratorConfig"] = None + """Optional accelerator configuration for TPU/GPU provisioning.""" + @dataclass class GangReservationResult: @@ -932,7 +940,7 @@ class GangReservationResult: success: bool """True when all gang PGs were created successfully.""" error_message: Optional[str] = None - gang_pgs: Optional[List[PlacementGroup]] = None + gang_pgs: Optional[List[Union[PlacementGroup, 
"_ReplicaPlacementGroup"]]] = None gang_ids: Optional[List[str]] = None gang_pg_names: Optional[List[str]] = None diff --git a/python/ray/serve/_private/config.py b/python/ray/serve/_private/config.py index f34839511689..461e0fe8c0e1 100644 --- a/python/ray/serve/_private/config.py +++ b/python/ray/serve/_private/config.py @@ -32,6 +32,7 @@ ) from ray.serve._private.utils import DEFAULT, DeploymentOptionUpdateType from ray.serve.config import ( + AcceleratorConfig, AggregationFunction, AutoscalingConfig, DeploymentActorConfig, @@ -191,6 +192,10 @@ class DeploymentConfig(BaseModel): update_type=DeploymentOptionUpdateType.NeedsActorReconfigure, ) + accelerator_config: Optional[AcceleratorConfig] = Field( + default=None, update_type=DeploymentOptionUpdateType.HeavyWeight + ) + # This flag is used to let replica know they are deployed from # a different language. is_cross_language: bool = False @@ -323,6 +328,8 @@ def needs_pickle(self): def to_proto(self): data = self.model_dump() + if data.get("accelerator_config") is not None: + data["accelerator_config"] = cloudpickle.dumps(self.accelerator_config) if data.get("user_config") is not None: if self.needs_pickle(): data["user_config"] = cloudpickle.dumps(data["user_config"]) @@ -430,6 +437,11 @@ def from_proto(cls, proto: DeploymentConfigProto): data["is_cross_language"] if "is_cross_language" in data else False ) needs_pickle = _needs_pickle(deployment_language, is_cross_language) + if "accelerator_config" in data: + if data["accelerator_config"] != b"": + data["accelerator_config"] = cloudpickle.loads(proto.accelerator_config) + else: + data["accelerator_config"] = None if "user_config" in data: if data["user_config"] != b"": if needs_pickle: diff --git a/python/ray/serve/_private/default_impl.py b/python/ray/serve/_private/default_impl.py index a68efdffd038..c84b8da9bed0 100644 --- a/python/ray/serve/_private/default_impl.py +++ b/python/ray/serve/_private/default_impl.py @@ -1,5 +1,7 @@ import asyncio -from typing 
import Callable, Optional, Tuple +import logging +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple, Union import ray from ray._common.constants import HEAD_NODE_RESOURCE_NAME @@ -42,7 +44,11 @@ inside_ray_client_context, resolve_deployment_response, ) -from ray.util.placement_group import PlacementGroup +from ray.serve.config import TPUAcceleratorConfig +from ray.util.placement_group import PlacementGroup, remove_placement_group +from ray.util.tpu import SlicePlacementGroup, slice_placement_group + +logger = logging.getLogger(__name__) # NOTE: Please read carefully before changing! # @@ -51,11 +57,93 @@ # API modified w/o substantial enough justification +@dataclass +class _ReplicaPlacementGroup: + """Internal Serve handle for a replica's placement group(s). + + Wraps the worker PG and any accelerator-specific cleanup hooks so the + controller doesn't need to know whether the underlying request was a + plain CPU/GPU PG or a TPU slice reservation. + """ + + placement_group: PlacementGroup + _slice_pg: Optional[SlicePlacementGroup] = None + + def release_reservation_holders(self) -> None: + """Call after ``placement_group.ready()`` resolves successfully. + + Releases any internal reservation-holder PGs (e.g. TPU head PGs) + that were only needed to claim resources during scheduling. No-op + for non-accelerator deployments. + """ + if self._slice_pg is not None: + self._slice_pg.release_head_pgs() + + def shutdown(self) -> None: + """Tear down the replica's PG(s). 
Idempotent.""" + if self._slice_pg is not None: + self._slice_pg.shutdown() + self._slice_pg = None + self.placement_group = None + elif self.placement_group is not None: + + try: + remove_placement_group(self.placement_group) + except Exception: + logger.exception("Failed to remove placement group.") + finally: + self.placement_group = None + + +def _create_replica_placement_group( + request: CreatePlacementGroupRequest, +) -> _ReplicaPlacementGroup: + """Internal entry point that supports accelerator-specific dispatch.""" + accelerator_config = request.accelerator_config + + if isinstance(accelerator_config, TPUAcceleratorConfig): + slice_pg = _default_create_tpu_placement_group( + tpu_config=accelerator_config, + strategy=request.strategy, + name=request.name, + lifetime="detached", + bundle_label_selector=request.bundle_label_selector, + ) + return _ReplicaPlacementGroup( + placement_group=slice_pg.placement_group, + _slice_pg=slice_pg, + ) + + pg = _default_create_placement_group(request) + return _ReplicaPlacementGroup(placement_group=pg) + + +def _default_create_tpu_placement_group( + tpu_config: TPUAcceleratorConfig, + strategy: str, + name: str, + lifetime: Optional[str], + bundle_label_selector: Optional[List[Dict[str, str]]] = None, +) -> SlicePlacementGroup: + return slice_placement_group( + topology=tpu_config.topology, + accelerator_version=tpu_config.accelerator_version, + num_slices=tpu_config.num_slices, + chips_per_vm=tpu_config.chips_per_vm, + strategy=strategy, + name=name, + lifetime=lifetime, + bundle_label_selector=bundle_label_selector, + ) + + def create_cluster_node_info_cache(gcs_client: GcsClient) -> ClusterNodeInfoCache: return DefaultClusterNodeInfoCache(gcs_client) -CreatePlacementGroupFn = Callable[[CreatePlacementGroupRequest], PlacementGroup] +CreatePlacementGroupFn = Callable[ + [CreatePlacementGroupRequest], Union[PlacementGroup, _ReplicaPlacementGroup] +] def _default_create_placement_group( @@ -81,7 +169,7 @@ def 
create_deployment_scheduler( cluster_node_info_cache, head_node_id, create_placement_group_fn=create_placement_group_fn_override - or _default_create_placement_group, + or _create_replica_placement_group, ) diff --git a/python/ray/serve/_private/deployment_scheduler.py b/python/ray/serve/_private/deployment_scheduler.py index 8cd17ef5d175..2fdf3fa7190b 100644 --- a/python/ray/serve/_private/deployment_scheduler.py +++ b/python/ray/serve/_private/deployment_scheduler.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from enum import Enum from functools import total_ordering -from typing import Any, Callable, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union import ray from ray._raylet import node_labels_match_selector @@ -27,6 +27,7 @@ RAY_SERVE_USE_PACK_SCHEDULING_STRATEGY, SERVE_LOGGER_NAME, ) +from ray.serve.config import AcceleratorConfig from ray.util.placement_group import PlacementGroup from ray.util.scheduling_strategies import ( LabelMatchExpressionsT, @@ -35,6 +36,9 @@ PlacementGroupSchedulingStrategy, ) +if TYPE_CHECKING: + from ray.serve._private.default_impl import _ReplicaPlacementGroup + logger = logging.getLogger(SERVE_LOGGER_NAME) @@ -198,6 +202,7 @@ class ReplicaSchedulingRequest: placement_group_strategy: Optional[str] = None placement_group_bundle_label_selector: Optional[List[Dict[str, str]]] = None placement_group_fallback_strategy: Optional[List[Dict[str, Any]]] = None + accelerator_config: Optional[AcceleratorConfig] = None max_replicas_per_node: Optional[int] = None # Gang scheduling fields -- if set, replica should be scheduled on # the reserved gang placement group at the specified bundle index. 
@@ -636,12 +641,21 @@ def _schedule_replica( replica_id = scheduling_request.replica_id deployment_id = replica_id.deployment_id placement_group = None + slice_pg = None scheduling_strategy = default_scheduling_strategy if scheduling_request.gang_placement_group is not None: # Gang scheduling -- use the reserved gang placement group - placement_group = scheduling_request.gang_placement_group + pg_wrapper = scheduling_request.gang_placement_group + placement_group = ( + pg_wrapper.placement_group + if hasattr(pg_wrapper, "placement_group") + else pg_wrapper + ) + # Preserve the wrapper for cleanup of head PGs + slice_pg = pg_wrapper if hasattr(pg_wrapper, "placement_group") else None + scheduling_strategy = PlacementGroupSchedulingStrategy( placement_group=placement_group, placement_group_bundle_index=scheduling_request.gang_pg_index, @@ -651,21 +665,32 @@ def _schedule_replica( target_labels = None target_node_id = None elif scheduling_request.placement_group_bundles is not None: + slice_pg = None placement_group_strategy = ( scheduling_request.placement_group_strategy if scheduling_request.placement_group_strategy else "PACK" ) try: - pg = self._create_placement_group_fn( + pg_result = self._create_placement_group_fn( CreatePlacementGroupRequest( bundles=scheduling_request.placement_group_bundles, strategy=placement_group_strategy, target_node_id=target_node_id, name=scheduling_request.actor_options["name"], bundle_label_selector=scheduling_request.placement_group_bundle_label_selector, - ) + accelerator_config=scheduling_request.accelerator_config, + ), ) + + from ray.serve._private.default_impl import _ReplicaPlacementGroup + + if isinstance(pg_result, _ReplicaPlacementGroup): + pg = pg_result.placement_group + slice_pg = pg_result + else: + pg = pg_result + slice_pg = None except Exception: # We add a defensive exception here, so the controller can # make progress even if the placement group isn't created. 
@@ -720,6 +745,17 @@ def _schedule_replica( scheduling_request.status = ( ReplicaSchedulingRequestStatus.ACTOR_CREATION_FAILED ) + + # Only clean up single-replica PGs. Gang PGs are managed elsewhere. + if scheduling_request.gang_placement_group is None: + if slice_pg is not None: + slice_pg.shutdown() + elif ( + placement_group is not None + and scheduling_request.placement_group_bundles is not None + ): + ray.util.remove_placement_group(placement_group) + return False del self._pending_replicas[deployment_id][replica_id] @@ -731,7 +767,11 @@ def _schedule_replica( placement_group = scheduling_strategy.placement_group scheduling_request.status = ReplicaSchedulingRequestStatus.SUCCEEDED - scheduling_request.on_scheduled(actor_handle, placement_group=placement_group) + scheduling_request.on_scheduled( + actor_handle, + placement_group=placement_group, + placement_group_manager=slice_pg, + ) return True @abstractmethod @@ -816,7 +856,7 @@ def _prepare_gangs_for_deployment( # Flatten per-replica bundles to form a placement group to atomically reserve resources # required for each gang - gang_pgs: List[PlacementGroup] = [] + gang_pgs: List[Union[PlacementGroup, "_ReplicaPlacementGroup"]] = [] gang_ids: List[str] = [] gang_pg_names: List[str] = [] for gang_index in range(num_gangs): @@ -867,6 +907,7 @@ def _prepare_gangs_for_deployment( name=pg_name, bundle_label_selector=label_selector, fallback_strategy=fallback_strategy, + accelerator_config=request.accelerator_config, ) ) gang_pgs.append(pg) diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index 6de2a62e5ed4..8a55016adebd 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -10,12 +10,15 @@ from copy import copy from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, 
Optional, Set, Tuple, Union import ray from ray import ObjectRef, cloudpickle from ray._common import ray_constants from ray.actor import ActorHandle + +if TYPE_CHECKING: + from ray.serve._private.default_impl import _ReplicaPlacementGroup from ray.exceptions import ( RayActorError, RayError, @@ -724,7 +727,7 @@ def __init__( # we trigger `initialize_and_get_metadata`. self._was_initialized_obj_ref: Optional[ObjectRef] = None # Set to True when `check_ready()` determines the actor cannot be - # recovered (e.g., the previous controller crashed before the actor + # recovered (e.g. the previous controller crashed before the actor # finished its initial setup). The reconciler treats this case as a # silent drop / replace rather than a deploy failure, since the # underlying cause is a controller-side crash, not user code. @@ -779,6 +782,7 @@ def __init__( self._last_record_routing_stats_time: float = 0.0 self._has_user_routing_stats_method: bool = False self._ingress: bool = False + self._replica_pg = None # Outbound deployments polling state self._outbound_deployments: Optional[List[DeploymentID]] = None @@ -1013,7 +1017,9 @@ def start( self, deployment_info: DeploymentInfo, assign_rank_callback: Callable[[ReplicaID], ReplicaRank], - gang_placement_group: Optional[PlacementGroup] = None, + gang_placement_group: Optional[ + Union[PlacementGroup, "_ReplicaPlacementGroup"] + ] = None, gang_pg_index: Optional[int] = None, gang_context: Optional[GangContext] = None, ) -> ReplicaSchedulingRequest: @@ -1152,6 +1158,7 @@ def start( placement_group_fallback_strategy=( deployment_info.replica_config.placement_group_fallback_strategy ), + accelerator_config=deployment_info.deployment_config.accelerator_config, max_replicas_per_node=( deployment_info.replica_config.max_replicas_per_node ), @@ -1164,9 +1171,11 @@ def on_scheduled( self, actor_handle: ActorHandle, placement_group: Optional[PlacementGroup] = None, + placement_group_manager: Optional[Any] = None, ): self._actor_handle 
= actor_handle self._placement_group = placement_group + self._replica_pg = placement_group_manager if self._is_cross_language: self._actor_handle = JavaActorHandleProxy(self._actor_handle) @@ -1465,6 +1474,9 @@ def check_ready(self) -> Tuple[ReplicaStartupStatus, Optional[str]]: ) return ReplicaStartupStatus.FAILED, repr(e) + if self._replica_pg is not None: + self._replica_pg.release_reservation_holders() + return ReplicaStartupStatus.SUCCEEDED, None @property @@ -1512,15 +1524,30 @@ def check_stopped(self) -> bool: finally: # Remove the placement group both if the actor has already been deleted or # it was just killed above. - if stopped and self._placement_group is not None: - try: - ray.util.remove_placement_group(self._placement_group) - except ValueError: - # ValueError thrown from ray.util.remove_placement_group means the - # placement group has already been removed. - logger.debug( - f"Placement group for {self._replica_id} was already removed." - ) + if stopped: + # Check for gang placement group first to avoid shutting down + # the shared replica_pg before all replicas are done. + if self._gang_placement_group is not None: + # Avoid calling shutdown() or remove_placement_group() here + # since replicas in Gang PG might still be draining. + self._gang_placement_group = None + self._placement_group = None + self._replica_pg = None + elif self._replica_pg is not None: + self._replica_pg.shutdown() + self._replica_pg = None + self._placement_group = None + elif self._placement_group is not None: + try: + ray.util.remove_placement_group(self._placement_group) + except ValueError: + # ValueError thrown from ray.util.remove_placement_group means the + # placement group has already been removed. + logger.debug( + f"Placement group for {self._replica_id} was already removed." 
+ ) + finally: + self._placement_group = None return stopped @@ -1899,7 +1926,9 @@ def start( self, deployment_info: DeploymentInfo, assign_rank_callback: Callable[[ReplicaID], ReplicaRank], - gang_placement_group: Optional[PlacementGroup] = None, + gang_placement_group: Optional[ + Union[PlacementGroup, "_ReplicaPlacementGroup"] + ] = None, gang_pg_index: Optional[int] = None, gang_context: Optional[GangContext] = None, ) -> ReplicaSchedulingRequest: @@ -2852,6 +2881,8 @@ def __init__( # Updated on replica creation during upscaling and permanent removal during downscaling. self._gang_id_by_replica: Dict[ReplicaID, str] = {} self._replicas_by_gang_id: Dict[str, Set[ReplicaID]] = defaultdict(set) + # Track the actual PG objects to clean them up when the gang empties + self._gang_pg_by_id: Dict[str, Any] = {} # Deployment-scoped actor lifecycle (per deployment) self._deployment_actors = DeploymentActorContainer(self._id) @@ -3957,6 +3988,9 @@ def _add_upscale_gang_replicas( ) for gang_pg, gang_id, pg_name in zip(gang_pgs, gang_ids, gang_pg_names): + # Track the PG object for later cleanup + self._gang_pg_by_id[gang_id] = gang_pg + member_replica_ids = [ ReplicaID(get_random_string(), deployment_id=self._id) for _ in range(gang_size) @@ -4332,8 +4366,9 @@ def _register_gang_replica(self, replica_id: ReplicaID, gang_id: str) -> None: self._gang_id_by_replica[replica_id] = gang_id self._replicas_by_gang_id[gang_id].add(replica_id) - def _unregister_gang_replica(self, replica_id: ReplicaID) -> None: + def _unregister_gang_replica(self, replica: "DeploymentReplica") -> None: """Remove a replica from the gang membership bookkeeping.""" + replica_id = replica.replica_id gang_id = self._gang_id_by_replica.pop(replica_id, None) if gang_id is not None: members = self._replicas_by_gang_id.get(gang_id) @@ -4342,6 +4377,42 @@ def _unregister_gang_replica(self, replica_id: ReplicaID) -> None: if not members: self._replicas_by_gang_id.pop(gang_id, None) + gang_pg = 
self._gang_pg_by_id.pop(gang_id, None) + + # Fallback for controller actor recovery, if the in-memory `_gang_pg_by_id` dict + # is empty, fetch the placement group from GCS. + if ( + gang_pg is None + and replica.gang_context + and replica.gang_context.pg_name + ): + try: + gang_pg = ray.util.get_placement_group( + replica.gang_context.pg_name + ) + except ValueError: + pass # PG doesn't exist in Ray, nothing to clean up + + if gang_pg is not None: + try: + if hasattr(gang_pg, "shutdown"): + gang_pg.shutdown() + else: + placement_group = ( + gang_pg.placement_group + if hasattr(gang_pg, "placement_group") + else gang_pg + ) + ray.util.remove_placement_group(placement_group) + except ValueError: + logger.debug( + f"Gang placement group for {gang_id} was already removed." + ) + except Exception: + logger.exception( + f"Failed to remove gang placement group for {gang_id}." + ) + def _clear_health_gauge_cache(self, replica_unique_id: str) -> None: """Remove a replica from the health-gauge cache (after it has fully stopped and been removed from tracking).""" @@ -4722,7 +4793,7 @@ def _check_and_update_transitioning_replicas(self): f"Released rank from replica {replica_id} in deployment {self._id}" ) self._autoscaling_state_manager.on_replica_stopped(replica.replica_id) - self._unregister_gang_replica(replica.replica_id) + self._unregister_gang_replica(replica) def _reconfigure_replicas_with_new_ranks( self, replicas_to_reconfigure: List["DeploymentReplica"] @@ -5970,6 +6041,7 @@ def _reserve_gang_placement_groups( replica_pg_fallback_strategy=( replica_config.placement_group_fallback_strategy ), + accelerator_config=deployment_state._target_state.info.deployment_config.accelerator_config, ) if not gang_requests: diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 976031a290f4..ad742a78c6e4 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -40,12 +40,14 @@ wait_for_interrupt, ) from ray.serve.config import ( + AcceleratorConfig, 
AutoscalingConfig, DeploymentActorConfig, GangSchedulingConfig, HTTPOptions, ProxyLocation, RequestRouterConfig, + TPUAcceleratorConfig, gRPCOptions, ) from ray.serve.context import ( @@ -447,6 +449,25 @@ async def __del__(self): return decorator +def _resolve_accelerator_config( + value: Union[Dict, AcceleratorConfig, None], +) -> Optional[AcceleratorConfig]: + + if value is None or isinstance(value, AcceleratorConfig): + return value + if isinstance(value, dict): + accelerator_type = value.get("accelerator_type") + if accelerator_type == "tpu": + return TPUAcceleratorConfig(**value) + raise ValueError( + f"Unknown accelerator_type {accelerator_type!r}. " + f"Supported types: 'tpu'." + ) + raise TypeError( + f"accelerator_config must be a dict or AcceleratorConfig, got {type(value)}." + ) + + @PublicAPI(stability="stable") def deployment( _func_or_class: Optional[Callable] = None, @@ -464,6 +485,7 @@ def deployment( user_config: Default[Optional[Any]] = DEFAULT.VALUE, max_ongoing_requests: Default[int] = DEFAULT.VALUE, max_queued_requests: Default[int] = DEFAULT.VALUE, + accelerator_config: Default[Union[Dict, AcceleratorConfig, None]] = DEFAULT.VALUE, autoscaling_config: Default[Union[Dict, AutoscalingConfig, None]] = DEFAULT.VALUE, graceful_shutdown_wait_loop_s: Default[float] = DEFAULT.VALUE, graceful_shutdown_timeout_s: Default[float] = DEFAULT.VALUE, @@ -634,11 +656,15 @@ class MyDeployment: if isinstance(logging_config, LoggingConfig): logging_config = logging_config.model_dump() + if accelerator_config is not DEFAULT.VALUE: + accelerator_config = _resolve_accelerator_config(accelerator_config) + deployment_config = DeploymentConfig.from_default( num_replicas=num_replicas if num_replicas is not None else 1, user_config=user_config, max_ongoing_requests=max_ongoing_requests, max_queued_requests=max_queued_requests, + accelerator_config=accelerator_config, autoscaling_config=autoscaling_config, graceful_shutdown_wait_loop_s=graceful_shutdown_wait_loop_s, 
graceful_shutdown_timeout_s=graceful_shutdown_timeout_s, diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index acf9c565346c..d3b6a393185c 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -4,7 +4,7 @@ import warnings from enum import Enum from functools import cached_property -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union from pydantic import ( BaseModel, @@ -42,7 +42,7 @@ SERVE_LOGGER_NAME, ) from ray.serve._private.utils import validate_ssl_config -from ray.util.annotations import Deprecated, PublicAPI +from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI logger = logging.getLogger(SERVE_LOGGER_NAME) @@ -709,6 +709,53 @@ def get_target_ongoing_requests(self) -> PositiveFloat: return self.target_ongoing_requests +@DeveloperAPI(stability="alpha") +class AcceleratorConfig(BaseModel): + """Base class for structured accelerator configurations. + + Use a concrete subclass — e.g. :class:`TPUAcceleratorConfig` — when + declaring a deployment's accelerator requirements via + ``serve.deployment(accelerator_config=...)``. + """ + + accelerator_type: str = Field( + ..., description="Discriminator identifying the accelerator config type." + ) + + model_config = {"frozen": True, "extra": "forbid"} + + +@DeveloperAPI(stability="alpha") +class TPUAcceleratorConfig(AcceleratorConfig): + """TPU slice specification for a Serve deployment. + + Mirrors the parameters of :func:`ray.util.tpu.slice_placement_group`. + Ray Serve uses this config to provision a TPU slice placement group + per replica and to manage its lifecycle through the controller. + + Example: + >>> from ray.serve.config import TPUAcceleratorConfig + >>> config = TPUAcceleratorConfig(topology="4x4", accelerator_version="v6e") + """ + + accelerator_type: Literal["tpu"] = "tpu" + + topology: str = Field( + ..., description="TPU pod topology, e.g. 
'2x2', '4x4', '2x2x2'." + ) + accelerator_version: str = Field( + ..., description="TPU accelerator version, e.g. 'v4', 'v5p', 'v6e'." + ) + num_slices: int = Field(default=1, ge=1, description="Number of slices to reserve.") + chips_per_vm: Optional[int] = Field( + default=None, + description=( + "Override for chips per host. Defaults to the canonical value " + "for the given accelerator_version." + ), + ) + + # Keep in sync with ServeDeploymentMode in dashboard/client/src/type/serve.ts @Deprecated class DeploymentMode(str, Enum): diff --git a/python/ray/serve/tests/BUILD.bazel b/python/ray/serve/tests/BUILD.bazel index a66511a3016d..c6124230b9e4 100644 --- a/python/ray/serve/tests/BUILD.bazel +++ b/python/ray/serve/tests/BUILD.bazel @@ -101,6 +101,7 @@ py_test_module_list( py_test_module_list( size = "medium", files = [ + "test_accelerator_config.py", "test_actor_replica_wrapper.py", "test_backpressure.py", "test_backpressure_grpc.py", diff --git a/python/ray/serve/tests/test_accelerator_config.py b/python/ray/serve/tests/test_accelerator_config.py new file mode 100644 index 000000000000..11d0d34a5972 --- /dev/null +++ b/python/ray/serve/tests/test_accelerator_config.py @@ -0,0 +1,107 @@ +from unittest.mock import patch + +import pytest + +import ray +from ray import serve +from ray.cluster_utils import Cluster +from ray.serve._private.common import CreatePlacementGroupRequest +from ray.serve._private.default_impl import ( + _create_replica_placement_group, + _ReplicaPlacementGroup, +) +from ray.serve.config import TPUAcceleratorConfig + + +@pytest.fixture +def mock_tpu_cluster(): + # Simulates a Ray cluster with a multi-host TPU v6e-16 slice (4x4 topology). + pod_type = "v6e-16" + topology = "4x4" + cluster = Cluster() + # Head node + cluster.add_node(num_cpus=4) + + # TPU nodes: A 4x4 v6e slice has 16 chips. We simulate 4 hosts with 4 chips each. 
+ for i in range(4): + env_vars = { + "TPU_NAME": "test-slice", + "TPU_WORKER_ID": str(i), + "TPU_ACCELERATOR_TYPE": pod_type, + "TPU_TOPOLOGY": topology, + } + labels = { + "ray.io/tpu-slice-name": "test-slice", + "ray.io/tpu-worker-id": str(i), + "ray.io/tpu-pod-type": pod_type, + } + resources = {"TPU": 4, "accelerator_type:TPU-V6E": 4} + + # The first node is the "head" of the slice + if i == 0: + resources[f"TPU-{pod_type}-head"] = 1 + + cluster.add_node( + num_cpus=8, + resources=resources, + labels=labels, + env_vars=env_vars, + ) + + cluster.wait_for_nodes() + ray.init(address=cluster.address, ignore_reinit_error=True) + serve.start() + yield cluster + serve.shutdown() + ray.shutdown() + cluster.shutdown() + + +def test_tpu_accelerator_config_integration(mock_tpu_cluster): + """Test that AcceleratorConfig correctly creates SlicePlacementGroup in a mock cluster.""" + + tpu_config = TPUAcceleratorConfig(topology="4x4", accelerator_version="v6e") + + request = CreatePlacementGroupRequest( + bundles=[{"CPU": 1}], # Ignored since accel_config will override it + strategy="SPREAD", + target_node_id=None, + name="test-tpu-pg", + accelerator_config=tpu_config, + ) + + # This should call _create_tpu_placement_group and return a wrapper + replica_pg = _create_replica_placement_group(request) + + assert isinstance(replica_pg, _ReplicaPlacementGroup) + assert replica_pg._slice_pg is not None + + # Verify the placement group is ready + ray.get(replica_pg.placement_group.ready(), timeout=20) + + # Verify cleanup + replica_pg.shutdown() + assert replica_pg._slice_pg is None + + +def test_tpu_accelerator_config_timeout_cleanup(mock_tpu_cluster): + """Test that SlicePlacementGroup cleans up head PGs on timeout.""" + + # Request a topology that requires 8 hosts (v6e-32) when cluster only has 4. 
+ tpu_config = TPUAcceleratorConfig(topology="4x8", accelerator_version="v6e") + + request = CreatePlacementGroupRequest( + bundles=[{"CPU": 1}], + strategy="SPREAD", + target_node_id=None, + name="test-tpu-timeout-pg", + accelerator_config=tpu_config, + ) + + # Patch timeout to be short, and mock remove_placement_group to verify cleanup + with patch("ray._private.accelerators.tpu.remove_placement_group") as mock_remove: + with patch("ray.util.tpu.DEFAULT_TPU_HEAD_RESERVATION_TIMEOUT_S", 2.0): + with pytest.raises(TimeoutError, match="Failed to reserve TPU head"): + _create_replica_placement_group(request) + + assert mock_remove.called diff --git a/python/ray/serve/tests/unit/BUILD.bazel b/python/ray/serve/tests/unit/BUILD.bazel index c7ab2f0db3bb..33b3fd649fa0 100644 --- a/python/ray/serve/tests/unit/BUILD.bazel +++ b/python/ray/serve/tests/unit/BUILD.bazel @@ -41,6 +41,7 @@ py_test_module_list( "RAY_SERVE_FAIL_ON_RANK_ERROR": "1", }, files = [ + "test_accelerator_config.py", "test_deployment_scheduler.py", "test_deployment_state.py", ], diff --git a/python/ray/serve/tests/unit/test_accelerator_config.py b/python/ray/serve/tests/unit/test_accelerator_config.py new file mode 100644 index 000000000000..f77682bc17ba --- /dev/null +++ b/python/ray/serve/tests/unit/test_accelerator_config.py @@ -0,0 +1,119 @@ +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from ray.serve._private.common import CreatePlacementGroupRequest +from ray.serve._private.default_impl import ( + _create_replica_placement_group, + _ReplicaPlacementGroup, +) +from ray.serve.api import deployment +from ray.serve.config import TPUAcceleratorConfig + + +def test_tpu_accelerator_config_construction(): + config = TPUAcceleratorConfig(topology="4x4", accelerator_version="v6e") + assert config.accelerator_type == "tpu" + assert config.topology == "4x4" + assert config.num_slices == 1 # default + + +def test_tpu_accelerator_config_immutable(): + config 
= TPUAcceleratorConfig(topology="4x4", accelerator_version="v6e") + with pytest.raises(ValidationError): + config.topology = "2x2" + + +def test_tpu_accelerator_config_extra_forbid(): + with pytest.raises(ValidationError): + TPUAcceleratorConfig(topology="4x4", accelerator_version="v6e", bogus_field=1) + + +def test_deployment_options_accept_tpu_config_instance(): + config = TPUAcceleratorConfig(topology="4x4", accelerator_version="v6e") + + @deployment(accelerator_config=config) + class D: + pass + + assert isinstance(D._deployment_config.accelerator_config, TPUAcceleratorConfig) + + +def test_deployment_options_accept_dict_form(): + @deployment( + accelerator_config={ + "accelerator_type": "tpu", + "topology": "4x4", + "accelerator_version": "v6e", + } + ) + class D: + pass + + cfg = D._deployment_config.accelerator_config + assert isinstance(cfg, TPUAcceleratorConfig) + assert cfg.topology == "4x4" + + +def test_deployment_options_dict_unknown_accelerator_type_raises(): + with pytest.raises(ValueError, match="Unknown accelerator_type"): + + @deployment(accelerator_config={"accelerator_type": "xpu"}) + class D: + pass + + +def test_create_replica_placement_group_tpu_dispatch(): + config = TPUAcceleratorConfig(topology="4x4", accelerator_version="v6e") + request = CreatePlacementGroupRequest( + bundles=[], + strategy="SPREAD", + target_node_id="", + name="test", + accelerator_config=config, + ) + + fake_slice_pg = MagicMock() + fake_slice_pg.placement_group = MagicMock() + + with patch( + "ray.serve._private.default_impl.slice_placement_group" + ) as mock_slice_pg: + mock_slice_pg.return_value = fake_slice_pg + + result = _create_replica_placement_group(request) + + assert mock_slice_pg.called + assert result._slice_pg is not None + assert result.placement_group == fake_slice_pg.placement_group + mock_slice_pg.assert_called_once() + + +def test_replica_pg_shutdown_idempotent(): + """Test that _ReplicaPlacementGroup shutdown is idempotent.""" + # Path 1: No 
accelerator + mock_pg = MagicMock() + adapter = _ReplicaPlacementGroup(placement_group=mock_pg) + + with patch("ray.serve._private.default_impl.remove_placement_group") as mock_remove: + adapter.shutdown() + mock_remove.assert_called_once_with(mock_pg) + + # Call again, should not raise or call remove again + adapter.shutdown() + assert mock_remove.call_count == 1 + + # Path 2: With accelerator + mock_slice_pg = MagicMock() + adapter_with_accel = _ReplicaPlacementGroup( + placement_group=mock_pg, _slice_pg=mock_slice_pg + ) + + adapter_with_accel.shutdown() + mock_slice_pg.shutdown.assert_called_once() + assert adapter_with_accel._slice_pg is None + + # Call again, should not raise or call shutdown again + adapter_with_accel.shutdown() + assert mock_slice_pg.shutdown.call_count == 1 diff --git a/python/ray/serve/tests/unit/test_deployment_state.py b/python/ray/serve/tests/unit/test_deployment_state.py index 860a2ed4bc3e..f2e86663afa2 100644 --- a/python/ray/serve/tests/unit/test_deployment_state.py +++ b/python/ray/serve/tests/unit/test_deployment_state.py @@ -8549,6 +8549,74 @@ def test_gang_downscale_stops_complete_gangs(self, mock_deployment_state_manager ) assert ds.curr_status_info.status == DeploymentStatus.HEALTHY + def test_gang_pg_cleanup_on_downscale(self, mock_deployment_state_manager): + """Verify Gang PG is only destroyed when the last replica finishes stopping.""" + create_dsm, _, _, _ = mock_deployment_state_manager + dsm: DeploymentStateManager = create_dsm( + create_placement_group_fn_override=lambda *args, **kwargs: Mock(), + ) + gang_size = 2 + initial_replicas = 2 + deployment_id = DeploymentID(name="gang_pg_cleanup", app_name="app") + + info, version = deployment_info( + num_replicas=initial_replicas, + version="v1", + gang_scheduling_config=GangSchedulingConfig(gang_size=gang_size), + ) + dsm.deploy(deployment_id, info) + ds = dsm._deployment_states[deployment_id] + + # Mock the gang PG to track when shutdown method is called + mock_gang_pg = 
MagicMock() + dsm._deployment_scheduler.schedule_gang_placement_groups = Mock( + return_value={ + deployment_id: GangReservationResult( + success=True, + gang_pgs=[mock_gang_pg], + gang_ids=["g1"], + gang_pg_names=["SERVE_GANG::g1"], + ) + } + ) + + # Start all replicas + dsm.update() + for replica in ds._replicas.get([ReplicaState.STARTING]): + replica._actor.set_ready() + dsm.update() + check_counts( + ds, total=initial_replicas, by_state=[(ReplicaState.RUNNING, 2, version)] + ) + + # Scale down to 0 to trigger gang removal + new_info, _ = deployment_info( + num_replicas=0, + version="v1", + gang_scheduling_config=GangSchedulingConfig(gang_size=gang_size), + ) + dsm.deploy(deployment_id, new_info) + dsm.update() + + stopping = ds._replicas.get([ReplicaState.STOPPING]) + assert len(stopping) == 2 + + # Mark first replica done stopping + stopping[0]._actor.set_done_stopping() + dsm.update() + + # The gang PG is not removed yet, because sibling is still stopping + mock_gang_pg.shutdown.assert_not_called() + assert "g1" in ds._gang_pg_by_id + + # Mark second replica done stopping + stopping[1]._actor.set_done_stopping() + dsm.update() + + # The gang PG is removed now that the gang is empty + mock_gang_pg.shutdown.assert_called_once() + assert "g1" not in ds._gang_pg_by_id + def test_gang_downscale_prefers_pending_gang(self, mock_deployment_state_manager): """Downscaling prefers the gang that still has a pending replica.""" create_dsm, _, _, _ = mock_deployment_state_manager diff --git a/src/ray/protobuf/serve.proto b/src/ray/protobuf/serve.proto index a92c0ac3248a..01773799add0 100644 --- a/src/ray/protobuf/serve.proto +++ b/src/ray/protobuf/serve.proto @@ -249,6 +249,9 @@ message DeploymentConfig { // rolling upgrade) is distinguishable from an explicit value. When unset, // the Python side falls back to DEFAULT_ROLLING_UPDATE_PERCENTAGE (0.2). optional double rolling_update_percentage = 23; + + // Structured accelerator configuration for a Serve deployment. 
+ bytes accelerator_config = 24; } // Deployment language. From 1c1dda88889e92813d0b5f386c3ca988b01d73c3 Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Sat, 9 May 2026 01:28:14 +0000 Subject: [PATCH 04/26] fix tests, change discriminator to 'kind', and fix cleanup logic Signed-off-by: Ryan O'Leary --- python/ray/serve/_private/default_impl.py | 1 - .../serve/_private/deployment_scheduler.py | 1 + python/ray/serve/_private/deployment_state.py | 44 ++++++++++--------- python/ray/serve/api.py | 9 ++-- python/ray/serve/config.py | 4 +- .../serve/tests/test_accelerator_config.py | 4 ++ .../serve/tests/test_deployment_scheduler.py | 2 +- .../tests/unit/test_accelerator_config.py | 12 +++-- .../tests/unit/test_deployment_scheduler.py | 8 ++-- 9 files changed, 49 insertions(+), 36 deletions(-) diff --git a/python/ray/serve/_private/default_impl.py b/python/ray/serve/_private/default_impl.py index c84b8da9bed0..793de57bdaab 100644 --- a/python/ray/serve/_private/default_impl.py +++ b/python/ray/serve/_private/default_impl.py @@ -86,7 +86,6 @@ def shutdown(self) -> None: self._slice_pg = None self.placement_group = None elif self.placement_group is not None: - try: remove_placement_group(self.placement_group) except Exception: diff --git a/python/ray/serve/_private/deployment_scheduler.py b/python/ray/serve/_private/deployment_scheduler.py index 2fdf3fa7190b..5944ccefd469 100644 --- a/python/ray/serve/_private/deployment_scheduler.py +++ b/python/ray/serve/_private/deployment_scheduler.py @@ -683,6 +683,7 @@ def _schedule_replica( ), ) + # Inline import to avoid circular dependency with default_impl from ray.serve._private.default_impl import _ReplicaPlacementGroup if isinstance(pg_result, _ReplicaPlacementGroup): diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index 8a55016adebd..5280a97ba2a3 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -727,7 
+727,7 @@ def __init__( # we trigger `initialize_and_get_metadata`. self._was_initialized_obj_ref: Optional[ObjectRef] = None # Set to True when `check_ready()` determines the actor cannot be - # recovered (e.g. the previous controller crashed before the actor + # recovered (e.g., the previous controller crashed before the actor # finished its initial setup). The reconciler treats this case as a # silent drop / replace rather than a deploy failure, since the # underlying cause is a controller-side crash, not user code. @@ -783,6 +783,8 @@ def __init__( self._has_user_routing_stats_method: bool = False self._ingress: bool = False self._replica_pg = None + self._gang_placement_group = None + self._gang_pg_index = None # Outbound deployments polling state self._outbound_deployments: Optional[List[DeploymentID]] = None @@ -1525,29 +1527,29 @@ def check_stopped(self) -> bool: # Remove the placement group both if the actor has already been deleted or # it was just killed above. if stopped: - # Check for gang placement group first to avoid shutting down - # the shared replica_pg before all replicas are done. - if self._gang_placement_group is not None: - # Avoid calling shutdown() or remove_placement_group() here - # since replicas in Gang PG might still be draining. + try: + # Gang PGs are shared and managed by the DeploymentStateManager. + # We do nothing here to avoid shutting them down prematurely. + if self._gang_placement_group is not None: + pass + # Replicas with accelerator/wrapper PGs handle their own shutdown. + elif self._replica_pg is not None: + self._replica_pg.shutdown() + # Standard single-replica placement groups. + elif self._placement_group is not None: + try: + ray.util.remove_placement_group(self._placement_group) + except ValueError: + # ValueError thrown from ray.util.remove_placement_group means the + # placement group has already been removed. + logger.debug( + f"Placement group for {self._replica_id} was already removed." 
+ ) + finally: + # Clear references to prevent dangling state. self._gang_placement_group = None - self._placement_group = None - self._replica_pg = None - elif self._replica_pg is not None: - self._replica_pg.shutdown() self._replica_pg = None self._placement_group = None - elif self._placement_group is not None: - try: - ray.util.remove_placement_group(self._placement_group) - except ValueError: - # ValueError thrown from ray.util.remove_placement_group means the - # placement group has already been removed. - logger.debug( - f"Placement group for {self._replica_id} was already removed." - ) - finally: - self._placement_group = None return stopped diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index ad742a78c6e4..9b0028e36407 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -456,11 +456,11 @@ def _resolve_accelerator_config( if value is None or isinstance(value, AcceleratorConfig): return value if isinstance(value, dict): - accelerator_type = value.get("accelerator_type") - if accelerator_type == "tpu": + kind = value.get("kind") + if kind == "tpu": return TPUAcceleratorConfig(**value) raise ValueError( - f"Unknown accelerator_type {accelerator_type!r}. " + f"Unknown accelerator kind {kind!r}. " f"Supported types: 'tpu'." ) raise TypeError( @@ -556,6 +556,9 @@ class MyDeployment: Once this limit is reached, subsequent requests will raise a BackPressureError (for handles) or return an HTTP 503 status code (for HTTP requests). Defaults to -1 (no limit). + accelerator_config: Configuration for hardware accelerators, such as TPUs. + Can be passed as an unstructured dictionary or a structured `AcceleratorConfig` + subclass (e.g. `TPUAcceleratorConfig`). See `AcceleratorConfig` for options. autoscaling_config: Parameters to configure autoscaling behavior. If this is set, `num_replicas` should be "auto" or not set. 
graceful_shutdown_wait_loop_s: Duration that replicas wait until there is diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index d3b6a393185c..70bebb2e2262 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -718,7 +718,7 @@ class AcceleratorConfig(BaseModel): ``serve.deployment(accelerator_config=...)``. """ - accelerator_type: str = Field( + kind: str = Field( ..., description="Discriminator identifying the accelerator config type." ) @@ -738,7 +738,7 @@ class TPUAcceleratorConfig(AcceleratorConfig): >>> config = TPUAcceleratorConfig(topology="4x4", accelerator_version="v6e") """ - accelerator_type: Literal["tpu"] = "tpu" + kind: Literal["tpu"] = "tpu" topology: str = Field( ..., description="TPU pod topology, e.g. '2x2', '4x4', '2x2x2'." diff --git a/python/ray/serve/tests/test_accelerator_config.py b/python/ray/serve/tests/test_accelerator_config.py index 11d0d34a5972..8b0acde69a3b 100644 --- a/python/ray/serve/tests/test_accelerator_config.py +++ b/python/ray/serve/tests/test_accelerator_config.py @@ -105,3 +105,7 @@ def test_tpu_accelerator_config_timeout_cleanup(mock_tpu_cluster): _create_replica_placement_group(request) assert mock_remove.called + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_deployment_scheduler.py b/python/ray/serve/tests/test_deployment_scheduler.py index b3f8688fecc0..c6332b1230db 100644 --- a/python/ray/serve/tests/test_deployment_scheduler.py +++ b/python/ray/serve/tests/test_deployment_scheduler.py @@ -68,7 +68,7 @@ def test_spread_deployment_scheduling_policy_upscale( replica_actor_handles = [] replica_placement_groups = [] - def on_scheduled(actor_handle, placement_group): + def on_scheduled(actor_handle, placement_group=None, placement_group_manager=None): replica_actor_handles.append(actor_handle) replica_placement_groups.append(placement_group) diff --git a/python/ray/serve/tests/unit/test_accelerator_config.py 
b/python/ray/serve/tests/unit/test_accelerator_config.py index f77682bc17ba..2950673218ae 100644 --- a/python/ray/serve/tests/unit/test_accelerator_config.py +++ b/python/ray/serve/tests/unit/test_accelerator_config.py @@ -14,7 +14,7 @@ def test_tpu_accelerator_config_construction(): config = TPUAcceleratorConfig(topology="4x4", accelerator_version="v6e") - assert config.accelerator_type == "tpu" + assert config.kind == "tpu" assert config.topology == "4x4" assert config.num_slices == 1 # default @@ -43,7 +43,7 @@ class D: def test_deployment_options_accept_dict_form(): @deployment( accelerator_config={ - "accelerator_type": "tpu", + "kind": "tpu", "topology": "4x4", "accelerator_version": "v6e", } @@ -57,9 +57,9 @@ class D: def test_deployment_options_dict_unknown_accelerator_type_raises(): - with pytest.raises(ValueError, match="Unknown accelerator_type"): + with pytest.raises(ValueError, match="Unknown accelerator kind"): - @deployment(accelerator_config={"accelerator_type": "xpu"}) + @deployment(accelerator_config={"kind": "xpu"}) class D: pass @@ -117,3 +117,7 @@ def test_replica_pg_shutdown_idempotent(): # Call again, should not raise or call shutdown again adapter_with_accel.shutdown() assert mock_slice_pg.shutdown.call_count == 1 + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/unit/test_deployment_scheduler.py b/python/ray/serve/tests/unit/test_deployment_scheduler.py index 0aebe93c5fa2..b5f7d6d278c4 100644 --- a/python/ray/serve/tests/unit/test_deployment_scheduler.py +++ b/python/ray/serve/tests/unit/test_deployment_scheduler.py @@ -578,7 +578,7 @@ def test_schedule_replica(): scheduling_strategy = None - def set_scheduling_strategy(actor_handle, placement_group): + def set_scheduling_strategy(actor_handle, placement_group=None, placement_group_manager=None): nonlocal scheduling_strategy scheduling_strategy = actor_handle._options["scheduling_strategy"] @@ -921,7 +921,7 @@ def 
test_downscale_single_deployment(): actor_resources={"CPU": 1}, actor_options={}, actor_init_args=(), - on_scheduled=lambda actor_handle, placement_group: actor_handle, + on_scheduled=lambda actor_handle, *args, **kwargs: actor_handle, ), ] }, @@ -1445,7 +1445,7 @@ def test_max_replicas_per_node(self): state = defaultdict(int) - def on_scheduled(actor_handle, placement_group): + def on_scheduled(actor_handle, placement_group=None, placement_group_manager=None): scheduling_strategy = actor_handle._options["scheduling_strategy"] if isinstance(scheduling_strategy, NodeAffinitySchedulingStrategy): state[scheduling_strategy.node_id] += 1 @@ -1603,7 +1603,7 @@ def test_custom_resources(self): # Despite trying to schedule on node that minimizes fragmentation, # should respect custom resources and schedule onto node2 - def on_scheduled(actor_handle, placement_group): + def on_scheduled(actor_handle, placement_group=None, placement_group_manager=None): scheduling_strategy = actor_handle._options["scheduling_strategy"] assert isinstance(scheduling_strategy, NodeAffinitySchedulingStrategy) assert scheduling_strategy.node_id == node_id_2 From 649229e3cb80dcc919046f73fdcabcefd2d234a6 Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Sat, 9 May 2026 01:43:57 +0000 Subject: [PATCH 05/26] fix import and var name Signed-off-by: Ryan O'Leary --- python/ray/serve/_private/deployment_scheduler.py | 6 +++--- python/ray/serve/tests/test_accelerator_config.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/ray/serve/_private/deployment_scheduler.py b/python/ray/serve/_private/deployment_scheduler.py index 5944ccefd469..bb156cbf7fac 100644 --- a/python/ray/serve/_private/deployment_scheduler.py +++ b/python/ray/serve/_private/deployment_scheduler.py @@ -687,10 +687,10 @@ def _schedule_replica( from ray.serve._private.default_impl import _ReplicaPlacementGroup if isinstance(pg_result, _ReplicaPlacementGroup): - pg = pg_result.placement_group + placement_group = 
pg_result.placement_group slice_pg = pg_result else: - pg = pg_result + placement_group = pg_result slice_pg = None except Exception: # We add a defensive exception here, so the controller can @@ -704,7 +704,7 @@ def _schedule_replica( ) return False scheduling_strategy = PlacementGroupSchedulingStrategy( - placement_group=pg, + placement_group=placement_group, placement_group_capture_child_tasks=True, ) target_labels = None diff --git a/python/ray/serve/tests/test_accelerator_config.py b/python/ray/serve/tests/test_accelerator_config.py index 8b0acde69a3b..f4dc83c5d67d 100644 --- a/python/ray/serve/tests/test_accelerator_config.py +++ b/python/ray/serve/tests/test_accelerator_config.py @@ -1,3 +1,4 @@ +import sys from unittest.mock import patch import pytest From 779d95d56e6a24a1448bb0e70674215ed9883e7c Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Sat, 9 May 2026 02:53:48 +0000 Subject: [PATCH 06/26] add missing import Signed-off-by: Ryan O'Leary scope down PR and remove gang scheduling changes Signed-off-by: Ryan O'Leary --- python/ray/serve/_private/common.py | 9 +- python/ray/serve/_private/default_impl.py | 19 +-- .../serve/_private/deployment_scheduler.py | 56 ++++----- python/ray/serve/_private/deployment_state.py | 69 +++-------- python/ray/serve/api.py | 3 +- .../serve/tests/test_accelerator_config.py | 40 ++++--- .../serve/tests/test_deployment_scheduler.py | 4 +- .../tests/unit/test_accelerator_config.py | 109 +++++++++++------- .../tests/unit/test_deployment_scheduler.py | 60 +++++++--- .../serve/tests/unit/test_deployment_state.py | 101 +++++----------- 10 files changed, 224 insertions(+), 246 deletions(-) diff --git a/python/ray/serve/_private/common.py b/python/ray/serve/_private/common.py index 324c0d8bb7dc..0e3c31eff1a6 100644 --- a/python/ray/serve/_private/common.py +++ b/python/ray/serve/_private/common.py @@ -1,7 +1,7 @@ import json from dataclasses import asdict, dataclass, field from enum import Enum -from typing import TYPE_CHECKING, 
Any, Awaitable, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional from starlette.types import Scope @@ -19,7 +19,6 @@ from ray.util.placement_group import PlacementGroup if TYPE_CHECKING: - from ray.serve._private.default_impl import _ReplicaPlacementGroup from ray.serve.config import AcceleratorConfig REPLICA_ID_FULL_ID_STR_PREFIX = "SERVE_REPLICA::" @@ -902,6 +901,7 @@ class CreatePlacementGroupRequest: bundle_label_selector: Optional[List[Dict[str, str]]] = None fallback_strategy: Optional[List[Dict[str, Any]]] = None accelerator_config: Optional["AcceleratorConfig"] = None + lifetime: Optional[str] = "detached" @dataclass @@ -929,9 +929,6 @@ class GangPlacementGroupRequest: replica_pg_fallback_strategy: Optional[List[Dict[str, Any]]] = None """Fallback strategy for per-replica placement group bundles.""" - accelerator_config: Optional["AcceleratorConfig"] = None - """Optional accelerator configuration for TPU/GPU provisioning.""" - @dataclass class GangReservationResult: @@ -940,7 +937,7 @@ class GangReservationResult: success: bool """True when all gang PGs were created successfully.""" error_message: Optional[str] = None - gang_pgs: Optional[List[Union[PlacementGroup, "_ReplicaPlacementGroup"]]] = None + gang_pgs: Optional[List[PlacementGroup]] = None gang_ids: Optional[List[str]] = None gang_pg_names: Optional[List[str]] = None diff --git a/python/ray/serve/_private/default_impl.py b/python/ray/serve/_private/default_impl.py index 793de57bdaab..7899386d2424 100644 --- a/python/ray/serve/_private/default_impl.py +++ b/python/ray/serve/_private/default_impl.py @@ -1,7 +1,7 @@ import asyncio import logging from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple import ray from ray._common.constants import HEAD_NODE_RESOURCE_NAME @@ -58,7 +58,7 @@ @dataclass -class _ReplicaPlacementGroup: 
+class ReplicaPlacementGroup: """Internal Serve handle for a replica's placement group(s). Wraps the worker PG and any accelerator-specific cleanup hooks so the @@ -96,7 +96,7 @@ def shutdown(self) -> None: def _create_replica_placement_group( request: CreatePlacementGroupRequest, -) -> _ReplicaPlacementGroup: +) -> ReplicaPlacementGroup: """Internal entry point that supports accelerator-specific dispatch.""" accelerator_config = request.accelerator_config @@ -105,16 +105,16 @@ def _create_replica_placement_group( tpu_config=accelerator_config, strategy=request.strategy, name=request.name, - lifetime="detached", + lifetime=request.lifetime, bundle_label_selector=request.bundle_label_selector, ) - return _ReplicaPlacementGroup( + return ReplicaPlacementGroup( placement_group=slice_pg.placement_group, _slice_pg=slice_pg, ) pg = _default_create_placement_group(request) - return _ReplicaPlacementGroup(placement_group=pg) + return ReplicaPlacementGroup(placement_group=pg) def _default_create_tpu_placement_group( @@ -140,8 +140,9 @@ def create_cluster_node_info_cache(gcs_client: GcsClient) -> ClusterNodeInfoCach return DefaultClusterNodeInfoCache(gcs_client) -CreatePlacementGroupFn = Callable[ - [CreatePlacementGroupRequest], Union[PlacementGroup, _ReplicaPlacementGroup] +CreatePlacementGroupFn = Callable[[CreatePlacementGroupRequest], PlacementGroup] +CreateReplicaPlacementGroupFn = Callable[ + [CreatePlacementGroupRequest], ReplicaPlacementGroup ] @@ -153,7 +154,7 @@ def _default_create_placement_group( request.strategy, _soft_target_node_id=request.target_node_id, name=request.name, - lifetime="detached", + lifetime=request.lifetime, bundle_label_selector=request.bundle_label_selector, ) diff --git a/python/ray/serve/_private/deployment_scheduler.py b/python/ray/serve/_private/deployment_scheduler.py index bb156cbf7fac..1a9db2a31df4 100644 --- a/python/ray/serve/_private/deployment_scheduler.py +++ b/python/ray/serve/_private/deployment_scheduler.py @@ -7,7 +7,7 @@ 
from dataclasses import dataclass from enum import Enum from functools import total_ordering -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple import ray from ray._raylet import node_labels_match_selector @@ -36,9 +36,6 @@ PlacementGroupSchedulingStrategy, ) -if TYPE_CHECKING: - from ray.serve._private.default_impl import _ReplicaPlacementGroup - logger = logging.getLogger(SERVE_LOGGER_NAME) @@ -641,20 +638,15 @@ def _schedule_replica( replica_id = scheduling_request.replica_id deployment_id = replica_id.deployment_id placement_group = None - slice_pg = None + replica_pg = None scheduling_strategy = default_scheduling_strategy if scheduling_request.gang_placement_group is not None: - # Gang scheduling -- use the reserved gang placement group - pg_wrapper = scheduling_request.gang_placement_group - placement_group = ( - pg_wrapper.placement_group - if hasattr(pg_wrapper, "placement_group") - else pg_wrapper - ) - # Preserve the wrapper for cleanup of head PGs - slice_pg = pg_wrapper if hasattr(pg_wrapper, "placement_group") else None + # Gang scheduling -- use the reserved gang placement group. + # Gang PGs are always bare PlacementGroup objects; accelerator + # deployments bypass gang scheduling entirely (see deployment_state). 
+ placement_group = scheduling_request.gang_placement_group scheduling_strategy = PlacementGroupSchedulingStrategy( placement_group=placement_group, @@ -665,7 +657,7 @@ def _schedule_replica( target_labels = None target_node_id = None elif scheduling_request.placement_group_bundles is not None: - slice_pg = None + replica_pg = None placement_group_strategy = ( scheduling_request.placement_group_strategy if scheduling_request.placement_group_strategy @@ -683,15 +675,13 @@ def _schedule_replica( ), ) - # Inline import to avoid circular dependency with default_impl - from ray.serve._private.default_impl import _ReplicaPlacementGroup + from ray.serve._private.default_impl import ReplicaPlacementGroup - if isinstance(pg_result, _ReplicaPlacementGroup): - placement_group = pg_result.placement_group - slice_pg = pg_result - else: - placement_group = pg_result - slice_pg = None + assert isinstance( + pg_result, ReplicaPlacementGroup + ), "_create_placement_group_fn must return a ReplicaPlacementGroup." + pg = pg_result.placement_group + replica_pg = pg_result except Exception: # We add a defensive exception here, so the controller can # make progress even if the placement group isn't created. @@ -704,7 +694,7 @@ def _schedule_replica( ) return False scheduling_strategy = PlacementGroupSchedulingStrategy( - placement_group=placement_group, + placement_group=pg, placement_group_capture_child_tasks=True, ) target_labels = None @@ -748,14 +738,11 @@ def _schedule_replica( ) # Only clean up single-replica PGs. Gang PGs are managed elsewhere. 
- if scheduling_request.gang_placement_group is None: - if slice_pg is not None: - slice_pg.shutdown() - elif ( - placement_group is not None - and scheduling_request.placement_group_bundles is not None - ): - ray.util.remove_placement_group(placement_group) + if ( + scheduling_request.gang_placement_group is None + and replica_pg is not None + ): + replica_pg.shutdown() return False @@ -771,7 +758,7 @@ def _schedule_replica( scheduling_request.on_scheduled( actor_handle, placement_group=placement_group, - placement_group_manager=slice_pg, + placement_group_manager=replica_pg, ) return True @@ -857,7 +844,7 @@ def _prepare_gangs_for_deployment( # Flatten per-replica bundles to form a placement group to atomically reserve resources # required for each gang - gang_pgs: List[Union[PlacementGroup, "_ReplicaPlacementGroup"]] = [] + gang_pgs: List[PlacementGroup] = [] gang_ids: List[str] = [] gang_pg_names: List[str] = [] for gang_index in range(num_gangs): @@ -908,7 +895,6 @@ def _prepare_gangs_for_deployment( name=pg_name, bundle_label_selector=label_selector, fallback_strategy=fallback_strategy, - accelerator_config=request.accelerator_config, ) ) gang_pgs.append(pg) diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index 5280a97ba2a3..1cc597f420de 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -18,7 +18,7 @@ from ray.actor import ActorHandle if TYPE_CHECKING: - from ray.serve._private.default_impl import _ReplicaPlacementGroup + from ray.serve._private.default_impl import ReplicaPlacementGroup from ray.exceptions import ( RayActorError, RayError, @@ -1020,7 +1020,7 @@ def start( deployment_info: DeploymentInfo, assign_rank_callback: Callable[[ReplicaID], ReplicaRank], gang_placement_group: Optional[ - Union[PlacementGroup, "_ReplicaPlacementGroup"] + Union[PlacementGroup, "ReplicaPlacementGroup"] ] = None, gang_pg_index: Optional[int] = None, 
gang_context: Optional[GangContext] = None, @@ -1528,14 +1528,16 @@ def check_stopped(self) -> bool: # it was just killed above. if stopped: try: - # Gang PGs are shared and managed by the DeploymentStateManager. - # We do nothing here to avoid shutting them down prematurely. + # 1. Gang PGs are shared and managed by the DeploymentStateManager. + # We do nothing active here to avoid shutting them down prematurely. if self._gang_placement_group is not None: pass - # Replicas with accelerator/wrapper PGs handle their own shutdown. + + # 2. Replicas with accelerator/wrapper PGs handle their own shutdown. elif self._replica_pg is not None: self._replica_pg.shutdown() - # Standard single-replica placement groups. + + # 3. Standard single-replica placement groups. elif self._placement_group is not None: try: ray.util.remove_placement_group(self._placement_group) @@ -1546,7 +1548,7 @@ def check_stopped(self) -> bool: f"Placement group for {self._replica_id} was already removed." ) finally: - # Clear references to prevent dangling state. + # Always clear references to prevent memory leaks and dangling state. self._gang_placement_group = None self._replica_pg = None self._placement_group = None @@ -1929,7 +1931,7 @@ def start( deployment_info: DeploymentInfo, assign_rank_callback: Callable[[ReplicaID], ReplicaRank], gang_placement_group: Optional[ - Union[PlacementGroup, "_ReplicaPlacementGroup"] + Union[PlacementGroup, "ReplicaPlacementGroup"] ] = None, gang_pg_index: Optional[int] = None, gang_context: Optional[GangContext] = None, @@ -2883,8 +2885,6 @@ def __init__( # Updated on replica creation during upscaling and permanent removal during downscaling. 
self._gang_id_by_replica: Dict[ReplicaID, str] = {} self._replicas_by_gang_id: Dict[str, Set[ReplicaID]] = defaultdict(set) - # Track the actual PG objects to clean them up when the gang empties - self._gang_pg_by_id: Dict[str, Any] = {} # Deployment-scoped actor lifecycle (per deployment) self._deployment_actors = DeploymentActorContainer(self._id) @@ -3266,6 +3266,11 @@ def get_gang_config(self): @property def _is_gang_deployment(self) -> bool: """Returns True if this deployment uses gang scheduling.""" + if ( + self._target_state is not None + and self._target_state.info.deployment_config.accelerator_config is not None + ): + return False return self.get_gang_config() is not None def _get_target_replica_delta(self) -> int: @@ -3990,8 +3995,6 @@ def _add_upscale_gang_replicas( ) for gang_pg, gang_id, pg_name in zip(gang_pgs, gang_ids, gang_pg_names): - # Track the PG object for later cleanup - self._gang_pg_by_id[gang_id] = gang_pg member_replica_ids = [ ReplicaID(get_random_string(), deployment_id=self._id) @@ -4368,9 +4371,8 @@ def _register_gang_replica(self, replica_id: ReplicaID, gang_id: str) -> None: self._gang_id_by_replica[replica_id] = gang_id self._replicas_by_gang_id[gang_id].add(replica_id) - def _unregister_gang_replica(self, replica: "DeploymentReplica") -> None: + def _unregister_gang_replica(self, replica_id: ReplicaID) -> None: """Remove a replica from the gang membership bookkeeping.""" - replica_id = replica.replica_id gang_id = self._gang_id_by_replica.pop(replica_id, None) if gang_id is not None: members = self._replicas_by_gang_id.get(gang_id) @@ -4379,42 +4381,6 @@ def _unregister_gang_replica(self, replica: "DeploymentReplica") -> None: if not members: self._replicas_by_gang_id.pop(gang_id, None) - gang_pg = self._gang_pg_by_id.pop(gang_id, None) - - # Fallback for controller actor recovery, if the in-memory `_gang_pg_by_id` dict - # is empty, fetch the placement group from GCS. 
- if ( - gang_pg is None - and replica.gang_context - and replica.gang_context.pg_name - ): - try: - gang_pg = ray.util.get_placement_group( - replica.gang_context.pg_name - ) - except ValueError: - pass # PG doesn't exist in Ray, nothing to clean up - - if gang_pg is not None: - try: - if hasattr(gang_pg, "shutdown"): - gang_pg.shutdown() - else: - placement_group = ( - gang_pg.placement_group - if hasattr(gang_pg, "placement_group") - else gang_pg - ) - ray.util.remove_placement_group(placement_group) - except ValueError: - logger.debug( - f"Gang placement group for {gang_id} was already removed." - ) - except Exception: - logger.exception( - f"Failed to remove gang placement group for {gang_id}." - ) - def _clear_health_gauge_cache(self, replica_unique_id: str) -> None: """Remove a replica from the health-gauge cache (after it has fully stopped and been removed from tracking).""" @@ -4795,7 +4761,7 @@ def _check_and_update_transitioning_replicas(self): f"Released rank from replica {replica_id} in deployment {self._id}" ) self._autoscaling_state_manager.on_replica_stopped(replica.replica_id) - self._unregister_gang_replica(replica) + self._unregister_gang_replica(replica.replica_id) def _reconfigure_replicas_with_new_ranks( self, replicas_to_reconfigure: List["DeploymentReplica"] @@ -6043,7 +6009,6 @@ def _reserve_gang_placement_groups( replica_pg_fallback_strategy=( replica_config.placement_group_fallback_strategy ), - accelerator_config=deployment_state._target_state.info.deployment_config.accelerator_config, ) if not gang_requests: diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 9b0028e36407..0be229666855 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -460,8 +460,7 @@ def _resolve_accelerator_config( if kind == "tpu": return TPUAcceleratorConfig(**value) raise ValueError( - f"Unknown accelerator kind {kind!r}. " - f"Supported types: 'tpu'." + f"Unknown accelerator kind {kind!r}. " f"Supported types: 'tpu'." 
) raise TypeError( f"accelerator_config must be a dict or AcceleratorConfig, got {type(value)}." diff --git a/python/ray/serve/tests/test_accelerator_config.py b/python/ray/serve/tests/test_accelerator_config.py index f4dc83c5d67d..25171ae5dd62 100644 --- a/python/ray/serve/tests/test_accelerator_config.py +++ b/python/ray/serve/tests/test_accelerator_config.py @@ -1,5 +1,5 @@ import sys -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest @@ -8,13 +8,13 @@ from ray.cluster_utils import Cluster from ray.serve._private.common import CreatePlacementGroupRequest from ray.serve._private.default_impl import ( + ReplicaPlacementGroup, _create_replica_placement_group, - _ReplicaPlacementGroup, ) from ray.serve.config import TPUAcceleratorConfig -@pytest.fixture +@pytest.fixture(scope="module") def mock_tpu_cluster(): # Simulates a Ray cluster with a multi-host TPU v6e-16 slice (4x4 topology). pod_type = "v6e-16" @@ -64,17 +64,18 @@ def test_tpu_accelerator_config_integration(mock_tpu_cluster): tpu_config = TPUAcceleratorConfig(topology="4x4", accelerator_version="v6e") request = CreatePlacementGroupRequest( - bundles=[{"CPU": 1}], # Ignored since accel_config will override it + bundles=[{"CPU": 1}], strategy="SPREAD", target_node_id=None, name="test-tpu-pg", accelerator_config=tpu_config, + lifetime="detached", ) # This should call _create_tpu_placement_group and return a wrapper replica_pg = _create_replica_placement_group(request) - assert isinstance(replica_pg, _ReplicaPlacementGroup) + assert isinstance(replica_pg, ReplicaPlacementGroup) assert replica_pg._slice_pg is not None # Verify the placement group is ready @@ -84,12 +85,16 @@ def test_tpu_accelerator_config_integration(mock_tpu_cluster): replica_pg.shutdown() assert replica_pg._slice_pg is None + # Verify idempotency of shutdown logic + replica_pg.shutdown() + assert replica_pg._slice_pg is None -def test_tpu_accelerator_config_timeout_cleanup(mock_tpu_cluster): - """Test 
that SlicePlacementGroup cleans up head PGs on timeout.""" - # Request a topology that requires 8 hosts (v6e-32) when cluster only has 4. - tpu_config = TPUAcceleratorConfig(topology="4x8", accelerator_version="v6e") +def test_tpu_accelerator_config_partial_failure_cleanup(mock_tpu_cluster): + """Test that SlicePlacementGroup cleans up head PGs if a multi-slice reservation fails.""" + + # Request 2 slices to test partial failure cleanup + tpu_config = TPUAcceleratorConfig(topology="4x4", accelerator_version="v6e", num_slices=2) request = CreatePlacementGroupRequest( bundles=[{"CPU": 1}], @@ -97,15 +102,24 @@ def test_tpu_accelerator_config_timeout_cleanup(mock_tpu_cluster): target_node_id=None, name="test-tpu-timeout-pg", accelerator_config=tpu_config, + lifetime="detached", ) - # Patch timeout to be short, and mock remove_placement_group to verify cleanup - with patch("ray._private.accelerators.tpu.remove_placement_group") as mock_remove: - with patch("ray.util.tpu.DEFAULT_TPU_HEAD_RESERVATION_TIMEOUT_S", 2.0): + # Patch remove_placement_group where it is USED (ray.util.tpu) + with patch("ray.util.tpu.remove_placement_group") as mock_remove: + with patch("ray.util.tpu.reserve_tpu_slice") as mock_reserve: + # Succeed for first slice, fail for second + mock_head_pg = MagicMock() + mock_reserve.side_effect = [ + ("slice-1", mock_head_pg), + TimeoutError("Failed to reserve TPU head"), + ] + with pytest.raises(TimeoutError, match="Failed to reserve TPU head"): _create_replica_placement_group(request) - assert mock_remove.called + # Verify that the first slice's head PG was cleanly rolled back + mock_remove.assert_called_once_with(mock_head_pg) if __name__ == "__main__": diff --git a/python/ray/serve/tests/test_deployment_scheduler.py b/python/ray/serve/tests/test_deployment_scheduler.py index c6332b1230db..7c5ac5bbf2fb 100644 --- a/python/ray/serve/tests/test_deployment_scheduler.py +++ b/python/ray/serve/tests/test_deployment_scheduler.py @@ -68,7 +68,9 @@ def 
test_spread_deployment_scheduling_policy_upscale( replica_actor_handles = [] replica_placement_groups = [] - def on_scheduled(actor_handle, placement_group=None, placement_group_manager=None): + def on_scheduled( + actor_handle, placement_group=None, placement_group_manager=None + ): replica_actor_handles.append(actor_handle) replica_placement_groups.append(placement_group) diff --git a/python/ray/serve/tests/unit/test_accelerator_config.py b/python/ray/serve/tests/unit/test_accelerator_config.py index 2950673218ae..4b1a37f23649 100644 --- a/python/ray/serve/tests/unit/test_accelerator_config.py +++ b/python/ray/serve/tests/unit/test_accelerator_config.py @@ -1,3 +1,4 @@ +import sys from unittest.mock import MagicMock, patch import pytest @@ -5,11 +6,13 @@ from ray.serve._private.common import CreatePlacementGroupRequest from ray.serve._private.default_impl import ( + ReplicaPlacementGroup, _create_replica_placement_group, - _ReplicaPlacementGroup, + _default_create_placement_group, ) from ray.serve.api import deployment from ray.serve.config import TPUAcceleratorConfig +from ray.util.placement_group import PlacementGroup def test_tpu_accelerator_config_construction(): @@ -64,59 +67,85 @@ class D: pass -def test_create_replica_placement_group_tpu_dispatch(): - config = TPUAcceleratorConfig(topology="4x4", accelerator_version="v6e") +@pytest.mark.parametrize( + "invalid_kwargs", + [ + {"topology": "4x4"}, # missing accelerator_version + {"accelerator_version": "v6e"}, # missing topology + {"topology": 123, "accelerator_version": "v6e"}, # topology should be str + { + "topology": "4x4", + "accelerator_version": "v6e", + "num_slices": "two", + }, # num_slices should be int + { + "topology": "4x4", + "accelerator_version": "v6e", + "num_slices": 0, + }, # num_slices must be >= 1 + ], +) +def test_tpu_accelerator_config_validation(invalid_kwargs): + with pytest.raises(ValidationError): + TPUAcceleratorConfig(**invalid_kwargs) + + +@pytest.mark.parametrize( + 
"creation_fn, expects_wrapper", + [ + (_default_create_placement_group, False), + (_create_replica_placement_group, True), + ], +) +def test_placement_group_creation_types(creation_fn, expects_wrapper): + """Verify that external overrides return bare PGs while internal ones return wrappers.""" request = CreatePlacementGroupRequest( - bundles=[], + bundles=[{"CPU": 1.0}], strategy="SPREAD", target_node_id="", name="test", - accelerator_config=config, ) - fake_slice_pg = MagicMock() - fake_slice_pg.placement_group = MagicMock() - - with patch( - "ray.serve._private.default_impl.slice_placement_group" - ) as mock_slice_pg: - mock_slice_pg.return_value = fake_slice_pg + mock_pg = MagicMock(spec=PlacementGroup) + with patch("ray.util.placement_group", return_value=mock_pg): + result = creation_fn(request) - result = _create_replica_placement_group(request) + if expects_wrapper: + assert isinstance(result, ReplicaPlacementGroup) + assert result.placement_group == mock_pg + else: + assert result == mock_pg + assert not isinstance(result, ReplicaPlacementGroup) - assert mock_slice_pg.called - assert result._slice_pg is not None - assert result.placement_group == fake_slice_pg.placement_group - mock_slice_pg.assert_called_once() - -def test_replica_pg_shutdown_idempotent(): - """Test that _ReplicaPlacementGroup shutdown is idempotent.""" - # Path 1: No accelerator +@pytest.mark.parametrize("with_accelerator", [False, True]) +def test_replica_pg_shutdown_idempotent(with_accelerator): + """Test that ReplicaPlacementGroup shutdown is idempotent.""" mock_pg = MagicMock() - adapter = _ReplicaPlacementGroup(placement_group=mock_pg) - with patch("ray.serve._private.default_impl.remove_placement_group") as mock_remove: - adapter.shutdown() - mock_remove.assert_called_once_with(mock_pg) + if with_accelerator: + mock_slice_pg = MagicMock() + adapter = ReplicaPlacementGroup( + placement_group=mock_pg, _slice_pg=mock_slice_pg + ) - # Call again, should not raise or call remove again 
adapter.shutdown() - assert mock_remove.call_count == 1 + mock_slice_pg.shutdown.assert_called_once() + assert adapter._slice_pg is None - # Path 2: With accelerator - mock_slice_pg = MagicMock() - adapter_with_accel = _ReplicaPlacementGroup( - placement_group=mock_pg, _slice_pg=mock_slice_pg - ) - - adapter_with_accel.shutdown() - mock_slice_pg.shutdown.assert_called_once() - assert adapter_with_accel._slice_pg is None - - # Call again, should not raise or call shutdown again - adapter_with_accel.shutdown() - assert mock_slice_pg.shutdown.call_count == 1 + adapter.shutdown() + assert mock_slice_pg.shutdown.call_count == 1 + else: + adapter = ReplicaPlacementGroup(placement_group=mock_pg) + + with patch( + "ray.serve._private.default_impl.remove_placement_group" + ) as mock_remove: + adapter.shutdown() + mock_remove.assert_called_once_with(mock_pg) + + adapter.shutdown() + assert mock_remove.call_count == 1 if __name__ == "__main__": diff --git a/python/ray/serve/tests/unit/test_deployment_scheduler.py b/python/ray/serve/tests/unit/test_deployment_scheduler.py index b5f7d6d278c4..b51928a197cf 100644 --- a/python/ray/serve/tests/unit/test_deployment_scheduler.py +++ b/python/ray/serve/tests/unit/test_deployment_scheduler.py @@ -570,7 +570,9 @@ def test_schedule_replica(): scheduler = default_impl.create_deployment_scheduler( cluster_node_info_cache, head_node_id_override="fake-head-node-id", - create_placement_group_fn_override=lambda request: MockPlacementGroup(request), + create_placement_group_fn_override=lambda request: default_impl.ReplicaPlacementGroup( + placement_group=MockPlacementGroup(request) + ), ) scheduler.on_deployment_created(d_id, SpreadDeploymentSchedulingPolicy()) @@ -578,7 +580,9 @@ def test_schedule_replica(): scheduling_strategy = None - def set_scheduling_strategy(actor_handle, placement_group=None, placement_group_manager=None): + def set_scheduling_strategy( + actor_handle, placement_group=None, placement_group_manager=None + ): nonlocal 
scheduling_strategy scheduling_strategy = actor_handle._options["scheduling_strategy"] @@ -1246,7 +1250,10 @@ def test_basic(self): assert len(on_scheduled_mock.call_args_list) == 2 for call in on_scheduled_mock.call_args_list: - assert call.kwargs == {"placement_group": None} + assert call.kwargs == { + "placement_group": None, + "placement_group_manager": None, + } assert len(call.args) == 1 scheduling_strategy = call.args[0]._options["scheduling_strategy"] assert isinstance(scheduling_strategy, NodeAffinitySchedulingStrategy) @@ -1254,7 +1261,7 @@ def test_basic(self): assert len(on_scheduled_mock2.call_args_list) == 1 call = on_scheduled_mock2.call_args_list[0] - assert call.kwargs == {"placement_group": None} + assert call.kwargs == {"placement_group": None, "placement_group_manager": None} assert len(call.args) == 1 scheduling_strategy = call.args[0]._options["scheduling_strategy"] assert isinstance(scheduling_strategy, NodeAffinitySchedulingStrategy) @@ -1270,9 +1277,9 @@ def test_placement_groups(self): scheduler = default_impl.create_deployment_scheduler( cluster_node_info_cache, head_node_id_override="fake-head-node-id", - create_placement_group_fn_override=lambda *args, **kwargs: MockPlacementGroup( # noqa - *args, **kwargs - ), + create_placement_group_fn_override=lambda *args, **kwargs: default_impl.ReplicaPlacementGroup( + placement_group=MockPlacementGroup(*args, **kwargs) + ), # noqa ) _ = ray.util.placement_group @@ -1413,7 +1420,10 @@ def test_heterogeneous_resources(self): scheduling_strategy = call.args[0]._options["scheduling_strategy"] assert isinstance(scheduling_strategy, NodeAffinitySchedulingStrategy) assert scheduling_strategy.node_id == node_id_1 - assert call.kwargs == {"placement_group": None} + assert call.kwargs == { + "placement_group": None, + "placement_group_manager": None, + } def test_max_replicas_per_node(self): """Test that at most `max_replicas_per_node` number of replicas @@ -1431,9 +1441,9 @@ def 
test_max_replicas_per_node(self): scheduler = default_impl.create_deployment_scheduler( cluster_node_info_cache, head_node_id_override="fake-head-node-id", - create_placement_group_fn_override=lambda *args, **kwargs: MockPlacementGroup( # noqa - *args, **kwargs - ), + create_placement_group_fn_override=lambda *args, **kwargs: default_impl.ReplicaPlacementGroup( + placement_group=MockPlacementGroup(*args, **kwargs) + ), # noqa ) scheduler.on_deployment_created(d_id1, SpreadDeploymentSchedulingPolicy()) scheduler.on_deployment_deployed( @@ -1445,7 +1455,9 @@ def test_max_replicas_per_node(self): state = defaultdict(int) - def on_scheduled(actor_handle, placement_group=None, placement_group_manager=None): + def on_scheduled( + actor_handle, placement_group=None, placement_group_manager=None + ): scheduling_strategy = actor_handle._options["scheduling_strategy"] if isinstance(scheduling_strategy, NodeAffinitySchedulingStrategy): state[scheduling_strategy.node_id] += 1 @@ -1589,9 +1601,9 @@ def test_custom_resources(self): scheduler = default_impl.create_deployment_scheduler( cluster_node_info_cache, head_node_id_override="fake-head-node-id", - create_placement_group_fn_override=lambda *args, **kwargs: MockPlacementGroup( # noqa - *args, **kwargs - ), + create_placement_group_fn_override=lambda *args, **kwargs: default_impl.ReplicaPlacementGroup( + placement_group=MockPlacementGroup(*args, **kwargs) + ), # noqa ) scheduler.on_deployment_created(d_id, SpreadDeploymentSchedulingPolicy()) scheduler.on_deployment_deployed( @@ -1603,7 +1615,9 @@ def test_custom_resources(self): # Despite trying to schedule on node that minimizes fragmentation, # should respect custom resources and schedule onto node2 - def on_scheduled(actor_handle, placement_group=None, placement_group_manager=None): + def on_scheduled( + actor_handle, placement_group=None, placement_group_manager=None + ): scheduling_strategy = actor_handle._options["scheduling_strategy"] assert 
isinstance(scheduling_strategy, NodeAffinitySchedulingStrategy) assert scheduling_strategy.node_id == node_id_2 @@ -1718,7 +1732,9 @@ def fail_once_create_pg(request): call_count += 1 if call_count == 1: raise RuntimeError("Simulated PG creation failure") - return MockPlacementGroup(request) + return default_impl.ReplicaPlacementGroup( + placement_group=MockPlacementGroup(request) + ) scheduler = default_impl.create_deployment_scheduler( cluster_node_info_cache, @@ -1858,7 +1874,10 @@ def test_pack_prefers_newly_non_idle_node(self): strategy1 = call1.args[0]._options["scheduling_strategy"] assert isinstance(strategy1, NodeAffinitySchedulingStrategy) assert strategy1.node_id == node_id_1 - assert call1.kwargs == {"placement_group": None} + assert call1.kwargs == { + "placement_group": None, + "placement_group_manager": None, + } # The CPU replica should also go to node 1 (now non-idle) rather # than node 2 (idle but tighter fit). The PACK scheduler prefers @@ -1868,7 +1887,10 @@ def test_pack_prefers_newly_non_idle_node(self): strategy2 = call2.args[0]._options["scheduling_strategy"] assert isinstance(strategy2, NodeAffinitySchedulingStrategy) assert strategy2.node_id == node_id_1 - assert call2.kwargs == {"placement_group": None} + assert call2.kwargs == { + "placement_group": None, + "placement_group_manager": None, + } class TestScheduleGangPlacementGroups: diff --git a/python/ray/serve/tests/unit/test_deployment_state.py b/python/ray/serve/tests/unit/test_deployment_state.py index f2e86663afa2..7f98cd324b9e 100644 --- a/python/ray/serve/tests/unit/test_deployment_state.py +++ b/python/ray/serve/tests/unit/test_deployment_state.py @@ -65,7 +65,11 @@ get_capacity_adjusted_num_replicas, get_random_string, ) -from ray.serve.config import DeploymentActorConfig, GangSchedulingConfig +from ray.serve.config import ( + DeploymentActorConfig, + GangSchedulingConfig, + TPUAcceleratorConfig, +) from ray.serve.schema import ReplicaRank from ray.util.placement_group import 
validate_placement_group @@ -8549,74 +8553,6 @@ def test_gang_downscale_stops_complete_gangs(self, mock_deployment_state_manager ) assert ds.curr_status_info.status == DeploymentStatus.HEALTHY - def test_gang_pg_cleanup_on_downscale(self, mock_deployment_state_manager): - """Verify Gang PG is only destroyed when the last replica finishes stopping.""" - create_dsm, _, _, _ = mock_deployment_state_manager - dsm: DeploymentStateManager = create_dsm( - create_placement_group_fn_override=lambda *args, **kwargs: Mock(), - ) - gang_size = 2 - initial_replicas = 2 - deployment_id = DeploymentID(name="gang_pg_cleanup", app_name="app") - - info, version = deployment_info( - num_replicas=initial_replicas, - version="v1", - gang_scheduling_config=GangSchedulingConfig(gang_size=gang_size), - ) - dsm.deploy(deployment_id, info) - ds = dsm._deployment_states[deployment_id] - - # Mock the gang PG to track when shutdown method is called - mock_gang_pg = MagicMock() - dsm._deployment_scheduler.schedule_gang_placement_groups = Mock( - return_value={ - deployment_id: GangReservationResult( - success=True, - gang_pgs=[mock_gang_pg], - gang_ids=["g1"], - gang_pg_names=["SERVE_GANG::g1"], - ) - } - ) - - # Start all replicas - dsm.update() - for replica in ds._replicas.get([ReplicaState.STARTING]): - replica._actor.set_ready() - dsm.update() - check_counts( - ds, total=initial_replicas, by_state=[(ReplicaState.RUNNING, 2, version)] - ) - - # Scale down to 0 to trigger gang removal - new_info, _ = deployment_info( - num_replicas=0, - version="v1", - gang_scheduling_config=GangSchedulingConfig(gang_size=gang_size), - ) - dsm.deploy(deployment_id, new_info) - dsm.update() - - stopping = ds._replicas.get([ReplicaState.STOPPING]) - assert len(stopping) == 2 - - # Mark first replica done stopping - stopping[0]._actor.set_done_stopping() - dsm.update() - - # The gang PG is not removed yet, because sibling is still stopping - mock_gang_pg.shutdown.assert_not_called() - assert "g1" in 
ds._gang_pg_by_id - - # Mark second replica done stopping - stopping[1]._actor.set_done_stopping() - dsm.update() - - # The gang PG is removed now that the gang is empty - mock_gang_pg.shutdown.assert_called_once() - assert "g1" not in ds._gang_pg_by_id - def test_gang_downscale_prefers_pending_gang(self, mock_deployment_state_manager): """Downscaling prefers the gang that still has a pending replica.""" create_dsm, _, _, _ = mock_deployment_state_manager @@ -8690,6 +8626,33 @@ def test_gang_downscale_prefers_pending_gang(self, mock_deployment_state_manager ) assert ds.curr_status_info.status == DeploymentStatus.HEALTHY + def test_accelerator_deployment_skips_gang_setup( + self, mock_deployment_state_manager + ): + """A deployment with accelerator_config should not create gang PG state.""" + create_dsm, _, _, _ = mock_deployment_state_manager + dsm: DeploymentStateManager = create_dsm() + deployment_id = DeploymentID(name="accelerator_skips_gang", app_name="app") + + info, version = deployment_info( + num_replicas=2, + version="v1", + accelerator_config=TPUAcceleratorConfig( + topology="4x4", accelerator_version="v6e" + ), + gang_scheduling_config=GangSchedulingConfig(gang_size=2), + ) + + dsm.deploy(deployment_id, info) + ds = dsm._deployment_states[deployment_id] + ds._add_upscale_gang_replicas = MagicMock() + ds._add_upscale_replicas = MagicMock(return_value=[]) + + dsm.update() + + ds._add_upscale_gang_replicas.assert_not_called() + ds._add_upscale_replicas.assert_called_once() + class TestGangHealthCheck: def _deploy_gang(self, mock_deployment_state_manager, gang_size, num_replicas): From 3a1d72439827d31974c4a2d782ca449ac7d11c98 Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Mon, 11 May 2026 19:49:52 +0000 Subject: [PATCH 07/26] lint and remove unused type alias Signed-off-by: Ryan O'Leary --- python/ray/serve/_private/default_impl.py | 3 --- python/ray/serve/tests/test_accelerator_config.py | 4 +++- 2 files changed, 3 insertions(+), 4 deletions(-) diff 
--git a/python/ray/serve/_private/default_impl.py b/python/ray/serve/_private/default_impl.py index 7899386d2424..a5405991325b 100644 --- a/python/ray/serve/_private/default_impl.py +++ b/python/ray/serve/_private/default_impl.py @@ -141,9 +141,6 @@ def create_cluster_node_info_cache(gcs_client: GcsClient) -> ClusterNodeInfoCach CreatePlacementGroupFn = Callable[[CreatePlacementGroupRequest], PlacementGroup] -CreateReplicaPlacementGroupFn = Callable[ - [CreatePlacementGroupRequest], ReplicaPlacementGroup -] def _default_create_placement_group( diff --git a/python/ray/serve/tests/test_accelerator_config.py b/python/ray/serve/tests/test_accelerator_config.py index 25171ae5dd62..89ccded4db2e 100644 --- a/python/ray/serve/tests/test_accelerator_config.py +++ b/python/ray/serve/tests/test_accelerator_config.py @@ -94,7 +94,9 @@ def test_tpu_accelerator_config_partial_failure_cleanup(mock_tpu_cluster): """Test that SlicePlacementGroup cleans up head PGs if a multi-slice reservation fails.""" # Request 2 slices to test partial failure cleanup - tpu_config = TPUAcceleratorConfig(topology="4x4", accelerator_version="v6e", num_slices=2) + tpu_config = TPUAcceleratorConfig( + topology="4x4", accelerator_version="v6e", num_slices=2 + ) request = CreatePlacementGroupRequest( bundles=[{"CPU": 1}], From d60312358732460240d358d601f746bda539de61 Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Mon, 11 May 2026 20:06:49 +0000 Subject: [PATCH 08/26] add comment to inline import Signed-off-by: Ryan O'Leary --- python/ray/serve/_private/deployment_scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/serve/_private/deployment_scheduler.py b/python/ray/serve/_private/deployment_scheduler.py index 1a9db2a31df4..72783df6b1f0 100644 --- a/python/ray/serve/_private/deployment_scheduler.py +++ b/python/ray/serve/_private/deployment_scheduler.py @@ -674,7 +674,7 @@ def _schedule_replica( accelerator_config=scheduling_request.accelerator_config, ), ) - + # 
Import ReplicaPlacementGroup inline here to avoid circular dependency with default_impl from ray.serve._private.default_impl import ReplicaPlacementGroup assert isinstance( From 80162c25325b894aac19ce5cb9bae7499e85edbc Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Mon, 11 May 2026 20:08:46 +0000 Subject: [PATCH 09/26] Tighten typing for placement-group fields after PR restructure Signed-off-by: Ryan O'Leary --- python/ray/serve/_private/deployment_state.py | 14 ++--- .../tests/unit/test_accelerator_config.py | 53 +++++++++++++------ 2 files changed, 41 insertions(+), 26 deletions(-) diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index 1cc597f420de..e0be76cdea2f 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -10,7 +10,7 @@ from copy import copy from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple import ray from ray import ObjectRef, cloudpickle @@ -782,7 +782,7 @@ def __init__( self._last_record_routing_stats_time: float = 0.0 self._has_user_routing_stats_method: bool = False self._ingress: bool = False - self._replica_pg = None + self._replica_pg: Optional["ReplicaPlacementGroup"] = None self._gang_placement_group = None self._gang_pg_index = None @@ -1019,9 +1019,7 @@ def start( self, deployment_info: DeploymentInfo, assign_rank_callback: Callable[[ReplicaID], ReplicaRank], - gang_placement_group: Optional[ - Union[PlacementGroup, "ReplicaPlacementGroup"] - ] = None, + gang_placement_group: Optional[PlacementGroup] = None, gang_pg_index: Optional[int] = None, gang_context: Optional[GangContext] = None, ) -> ReplicaSchedulingRequest: @@ -1173,7 +1171,7 @@ def on_scheduled( self, actor_handle: ActorHandle, placement_group: Optional[PlacementGroup] = None, - 
placement_group_manager: Optional[Any] = None, + placement_group_manager: Optional["ReplicaPlacementGroup"] = None, ): self._actor_handle = actor_handle self._placement_group = placement_group @@ -1930,9 +1928,7 @@ def start( self, deployment_info: DeploymentInfo, assign_rank_callback: Callable[[ReplicaID], ReplicaRank], - gang_placement_group: Optional[ - Union[PlacementGroup, "ReplicaPlacementGroup"] - ] = None, + gang_placement_group: Optional[PlacementGroup] = None, gang_pg_index: Optional[int] = None, gang_context: Optional[GangContext] = None, ) -> ReplicaSchedulingRequest: diff --git a/python/ray/serve/tests/unit/test_accelerator_config.py b/python/ray/serve/tests/unit/test_accelerator_config.py index 4b1a37f23649..0450760adaf4 100644 --- a/python/ray/serve/tests/unit/test_accelerator_config.py +++ b/python/ray/serve/tests/unit/test_accelerator_config.py @@ -8,7 +8,6 @@ from ray.serve._private.default_impl import ( ReplicaPlacementGroup, _create_replica_placement_group, - _default_create_placement_group, ) from ray.serve.api import deployment from ray.serve.config import TPUAcceleratorConfig @@ -90,32 +89,52 @@ def test_tpu_accelerator_config_validation(invalid_kwargs): TPUAcceleratorConfig(**invalid_kwargs) -@pytest.mark.parametrize( - "creation_fn, expects_wrapper", - [ - (_default_create_placement_group, False), - (_create_replica_placement_group, True), - ], -) -def test_placement_group_creation_types(creation_fn, expects_wrapper): - """Verify that external overrides return bare PGs while internal ones return wrappers.""" +@pytest.mark.parametrize("with_accelerator", [False, True]) +def test_placement_group_creation_types(with_accelerator): + """Verify that _create_replica_placement_group always returns wrappers.""" + + accelerator_config = None + if with_accelerator: + accelerator_config = TPUAcceleratorConfig( + topology="4x4", accelerator_version="v6e" + ) + request = CreatePlacementGroupRequest( bundles=[{"CPU": 1.0}], strategy="SPREAD", 
target_node_id="", name="test", + accelerator_config=accelerator_config, ) mock_pg = MagicMock(spec=PlacementGroup) - with patch("ray.util.placement_group", return_value=mock_pg): - result = creation_fn(request) - if expects_wrapper: - assert isinstance(result, ReplicaPlacementGroup) - assert result.placement_group == mock_pg + # Accelerator path. Returns a wrapper holding a SlicePlacementGroup. + if with_accelerator: + mock_slice_pg = MagicMock() + mock_slice_pg.placement_group = mock_pg + with patch( + "ray.serve._private.default_impl.slice_placement_group", + return_value=mock_slice_pg, + ): + result = _create_replica_placement_group(request) + # Non-accelerator path. Returns a wrapper holding a regular PG. + else: + with patch("ray.util.placement_group", return_value=mock_pg): + result = _create_replica_placement_group(request) + + assert isinstance(result, ReplicaPlacementGroup), ( + "_create_replica_placement_group must always return a ReplicaPlacementGroup, " + "regardless of whether accelerator_config is set." + ) + assert result.placement_group == mock_pg + + if with_accelerator: + assert ( + result._slice_pg is not None + ), "Accelerator path must set _slice_pg for cleanup tracking." else: - assert result == mock_pg - assert not isinstance(result, ReplicaPlacementGroup) + assert result._slice_pg is None, "Non-accelerator path must not set _slice_pg." 
@pytest.mark.parametrize("with_accelerator", [False, True]) From e45ee822a357ce3826d59291f81a2116fb792e51 Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Mon, 11 May 2026 20:49:08 +0000 Subject: [PATCH 10/26] remove added whitespace Signed-off-by: Ryan O'Leary --- python/ray/serve/_private/deployment_state.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index e0be76cdea2f..d6d2aa3812c3 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -3991,7 +3991,6 @@ def _add_upscale_gang_replicas( ) for gang_pg, gang_id, pg_name in zip(gang_pgs, gang_ids, gang_pg_names): - member_replica_ids = [ ReplicaID(get_random_string(), deployment_id=self._id) for _ in range(gang_size) From f96ef1e86189812e8dce5e6f105857735cee0a2c Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Mon, 11 May 2026 20:59:07 +0000 Subject: [PATCH 11/26] fix external placement group function override Signed-off-by: Ryan O'Leary --- .../serve/_private/deployment_scheduler.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/python/ray/serve/_private/deployment_scheduler.py b/python/ray/serve/_private/deployment_scheduler.py index 72783df6b1f0..3bddfacc5eb1 100644 --- a/python/ray/serve/_private/deployment_scheduler.py +++ b/python/ray/serve/_private/deployment_scheduler.py @@ -677,11 +677,12 @@ def _schedule_replica( # Import ReplicaPlacementGroup inline here to avoid circular dependency with default_impl from ray.serve._private.default_impl import ReplicaPlacementGroup - assert isinstance( - pg_result, ReplicaPlacementGroup - ), "_create_placement_group_fn must return a ReplicaPlacementGroup." 
- pg = pg_result.placement_group - replica_pg = pg_result + if isinstance(pg_result, ReplicaPlacementGroup): + placement_group = pg_result.placement_group + replica_pg = pg_result + else: + placement_group = pg_result + replica_pg = None except Exception: # We add a defensive exception here, so the controller can # make progress even if the placement group isn't created. @@ -694,7 +695,7 @@ def _schedule_replica( ) return False scheduling_strategy = PlacementGroupSchedulingStrategy( - placement_group=pg, + placement_group=placement_group, placement_group_capture_child_tasks=True, ) target_labels = None @@ -738,11 +739,14 @@ def _schedule_replica( ) # Only clean up single-replica PGs. Gang PGs are managed elsewhere. - if ( - scheduling_request.gang_placement_group is None - and replica_pg is not None - ): - replica_pg.shutdown() + if scheduling_request.gang_placement_group is None: + if replica_pg is not None: + replica_pg.shutdown() + elif ( + placement_group is not None + and scheduling_request.placement_group_bundles is not None + ): + ray.util.remove_placement_group(placement_group) return False From 5d95c7882b7e5d9e99776e240b9ef298aaf2b740 Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Mon, 11 May 2026 23:00:22 +0000 Subject: [PATCH 12/26] add resources_per_bundle and fix bundles defaulting logic, also add tests Signed-off-by: Ryan O'Leary --- python/ray/serve/_private/common.py | 18 ++++- python/ray/serve/_private/default_impl.py | 22 ++++- .../serve/_private/deployment_scheduler.py | 10 ++- python/ray/serve/config.py | 15 ++++ .../tests/unit/test_accelerator_config.py | 80 ++++++++++++++++++ .../tests/unit/test_deployment_scheduler.py | 81 +++++++++++++++++++ 6 files changed, 220 insertions(+), 6 deletions(-) diff --git a/python/ray/serve/_private/common.py b/python/ray/serve/_private/common.py index 1bdb2ca2f158..39ceedfb63e8 100644 --- a/python/ray/serve/_private/common.py +++ b/python/ray/serve/_private/common.py @@ -896,10 +896,20 @@ class 
ReplicaQueueLengthInfo: @dataclass(frozen=True) class CreatePlacementGroupRequest: - bundles: List[Dict[str, float]] - strategy: str - target_node_id: str - name: str + """Internal request for creating a per-replica placement group. + + Either ``bundles`` or ``accelerator_config`` must be provided: + - For plain CPU/GPU deployments, the caller provides ``bundles`` and the + default path creates a standard PlacementGroup. + - For accelerator deployments (e.g. TPU), the caller provides + ``accelerator_config`` and the dispatch derives bundles from the + structured config (e.g. TPU topology -> per-host bundles). + """ + + bundles: Optional[List[Dict[str, float]]] = None + strategy: str = "PACK" + target_node_id: Optional[str] = None + name: str = "" runtime_env: Optional[str] = None bundle_label_selector: Optional[List[Dict[str, str]]] = None fallback_strategy: Optional[List[Dict[str, Any]]] = None diff --git a/python/ray/serve/_private/default_impl.py b/python/ray/serve/_private/default_impl.py index a5405991325b..9fc87a330a33 100644 --- a/python/ray/serve/_private/default_impl.py +++ b/python/ray/serve/_private/default_impl.py @@ -97,7 +97,18 @@ def shutdown(self) -> None: def _create_replica_placement_group( request: CreatePlacementGroupRequest, ) -> ReplicaPlacementGroup: - """Internal entry point that supports accelerator-specific dispatch.""" + """Internal entry point that supports accelerator-specific dispatch. + + Dispatches on ``request.accelerator_config``: + - TPUAcceleratorConfig: derive bundles from topology via + slice_placement_group; ``request.bundles`` is ignored. + - None: use ``request.bundles`` to create a standard PlacementGroup. + + Raises ValueError if neither bundles nor a recognized accelerator + config is provided - this catches users setting an unrecognized + accelerator_config type without explicit bundles, which would + otherwise schedule with no PG at all. 
+ """ accelerator_config = request.accelerator_config if isinstance(accelerator_config, TPUAcceleratorConfig): @@ -113,6 +124,14 @@ def _create_replica_placement_group( _slice_pg=slice_pg, ) + if request.bundles is None: + raise ValueError( + "CreatePlacementGroupRequest requires either non-None bundles " + "or a recognized accelerator_config. Got accelerator_config=" + f"{type(accelerator_config).__name__ if accelerator_config else None}, " + "bundles=None." + ) + pg = _default_create_placement_group(request) return ReplicaPlacementGroup(placement_group=pg) @@ -129,6 +148,7 @@ def _default_create_tpu_placement_group( accelerator_version=tpu_config.accelerator_version, num_slices=tpu_config.num_slices, chips_per_vm=tpu_config.chips_per_vm, + resources_per_bundle=tpu_config.resources_per_bundle, strategy=strategy, name=name, lifetime=lifetime, diff --git a/python/ray/serve/_private/deployment_scheduler.py b/python/ray/serve/_private/deployment_scheduler.py index 3bddfacc5eb1..4d1a4e3caa3d 100644 --- a/python/ray/serve/_private/deployment_scheduler.py +++ b/python/ray/serve/_private/deployment_scheduler.py @@ -656,7 +656,15 @@ def _schedule_replica( # TODO (jeffreywang): Add support for target labels and node affinity target_labels = None target_node_id = None - elif scheduling_request.placement_group_bundles is not None: + elif ( + scheduling_request.placement_group_bundles is not None + or scheduling_request.accelerator_config is not None + ): + # Per-replica PG path. Entered when either: + # - The user provided explicit bundles (CPU/GPU deployments), or + # - The user provided an accelerator_config that derives its own + # bundles from structured fields (e.g. TPUAcceleratorConfig + # derives bundles from topology via slice_placement_group). 
replica_pg = None placement_group_strategy = ( scheduling_request.placement_group_strategy diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index 70bebb2e2262..ef4bb3f6a42d 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -733,6 +733,12 @@ class TPUAcceleratorConfig(AcceleratorConfig): Ray Serve uses this config to provision a TPU slice placement group per replica and to manage its lifecycle through the controller. + When set on a deployment, this config drives placement-group creation + entirely. The deployment's ``placement_group_bundles`` and + ``placement_group_strategy`` fields are ignored - the bundles are + derived from ``topology`` (or optionally ``resources_per_bundle``), + and the strategy is chosen internally to honor slice gang scheduling. + Example: >>> from ray.serve.config import TPUAcceleratorConfig >>> config = TPUAcceleratorConfig(topology="4x4", accelerator_version="v6e") @@ -754,6 +760,15 @@ class TPUAcceleratorConfig(AcceleratorConfig): "for the given accelerator_version." ), ) + resources_per_bundle: Optional[Dict[str, float]] = Field( + default=None, + description=( + "Resources to include in every worker bundle. When unspecified, " + "SlicePlacementGroup defaults to one bundle per TPU host with " + "the bundle resources set to the number of chips on that host. " + "See ray.util.tpu.slice_placement_group for details." 
+ ), + ) # Keep in sync with ServeDeploymentMode in dashboard/client/src/type/serve.ts diff --git a/python/ray/serve/tests/unit/test_accelerator_config.py b/python/ray/serve/tests/unit/test_accelerator_config.py index 0450760adaf4..518502228fd8 100644 --- a/python/ray/serve/tests/unit/test_accelerator_config.py +++ b/python/ray/serve/tests/unit/test_accelerator_config.py @@ -12,6 +12,7 @@ from ray.serve.api import deployment from ray.serve.config import TPUAcceleratorConfig from ray.util.placement_group import PlacementGroup +from ray.util.tpu import SlicePlacementGroup def test_tpu_accelerator_config_construction(): @@ -167,5 +168,84 @@ def test_replica_pg_shutdown_idempotent(with_accelerator): assert mock_remove.call_count == 1 +def test_create_replica_placement_group_rejects_no_bundles_no_config(): + """Without bundles or a recognized accelerator_config, raises ValueError. + + Catches future accelerator types added to AcceleratorConfig but not + wired into _create_replica_placement_group. 
+ """ + request = CreatePlacementGroupRequest( + bundles=None, + strategy="PACK", + target_node_id="", + name="test", + accelerator_config=None, + ) + with pytest.raises(ValueError, match="requires either non-None bundles"): + _create_replica_placement_group(request) + + +def test_create_replica_placement_group_tpu_ignores_bundles(): + """TPU dispatch ignores request.bundles -- they're derived from topology.""" + request = CreatePlacementGroupRequest( + bundles=[{"CPU": 1}], + strategy="PACK", + target_node_id="", + name="test", + accelerator_config=TPUAcceleratorConfig( + topology="2x2", accelerator_version="v6e" + ), + ) + + mock_slice_pg = MagicMock() + mock_slice_pg.placement_group = MagicMock(spec=PlacementGroup) + + with patch( + "ray.serve._private.default_impl.slice_placement_group", + return_value=mock_slice_pg, + ) as mock_slice_pg_func: + result = _create_replica_placement_group(request) + + mock_slice_pg_func.assert_called_once() + + assert result.placement_group == mock_slice_pg.placement_group + + +def test_tpu_config_resources_per_bundle_forwarded_to_slice_pg(monkeypatch): + """The resources_per_bundle field is forwarded to slice_placement_group.""" + captured = {} + + # Mock slice_placement_group to capture the arguments it receives. + def mock_slice_pg(**kwargs): + captured.update(kwargs) + mock = MagicMock(spec=SlicePlacementGroup) + mock.placement_group = MagicMock() + return mock + + monkeypatch.setattr( + "ray.serve._private.default_impl.slice_placement_group", + mock_slice_pg, + ) + + # Create a config with custom resources per bundle. + config = TPUAcceleratorConfig( + topology="4x4", + accelerator_version="v6e", + resources_per_bundle={"TPU": 1, "memory": 1_000_000}, + ) + request = CreatePlacementGroupRequest( + accelerator_config=config, + name="test", + ) + + # Call the dispatch function. + _create_replica_placement_group(request) + + # Verify that resources_per_bundle and other fields were forwarded correctly. 
+ assert captured["resources_per_bundle"] == {"TPU": 1, "memory": 1_000_000} + assert captured["topology"] == "4x4" + assert captured["accelerator_version"] == "v6e" + + if __name__ == "__main__": sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/unit/test_deployment_scheduler.py b/python/ray/serve/tests/unit/test_deployment_scheduler.py index b51928a197cf..07e62d6e20e7 100644 --- a/python/ray/serve/tests/unit/test_deployment_scheduler.py +++ b/python/ray/serve/tests/unit/test_deployment_scheduler.py @@ -37,6 +37,7 @@ MockClusterNodeInfoCache, MockPlacementGroup, ) +from ray.serve.config import TPUAcceleratorConfig from ray.tests.conftest import * # noqa from ray.util.scheduling_strategies import ( In, @@ -712,6 +713,86 @@ def set_scheduling_strategy( } +@pytest.mark.parametrize( + "bundles, acc_config_present, expect_pg_created", + [ + # Accelerator config only -> enters PG path + (None, True, True), + # Bundles only -> enters PG path + ([{"CPU": 1}], False, True), + # Both set -> enters PG path + ([{"CPU": 1}], True, True), + # Neither set -> falls through to default scheduling + (None, False, False), + ], +) +def test_schedule_replica_dispatch(bundles, acc_config_present, expect_pg_created): + """Validate that scheduler routes to PG path correctly based on bundles and config.""" + # Setup deployment IDs and cache. + d_id = DeploymentID("deployment_test", "app1") + cluster_node_info_cache = MockClusterNodeInfoCache() + captured_requests = [] + + # Mock create_placement_group_fn to record what it received. + def mock_create_pg(request): + captured_requests.append(request) + return default_impl.ReplicaPlacementGroup( + placement_group=MockPlacementGroup(request) + ) + + # Initialize scheduler with mock PG creator. 
+ scheduler = default_impl.create_deployment_scheduler( + cluster_node_info_cache, + head_node_id_override="fake-head-node-id", + create_placement_group_fn_override=mock_create_pg, + ) + scheduler.on_deployment_created(d_id, SpreadDeploymentSchedulingPolicy()) + scheduler.on_deployment_deployed(d_id, rconfig(ray_actor_options={"num_cpus": 1})) + + r0_id = ReplicaID(unique_id="r0", deployment_id=d_id) + + acc_config = None + if acc_config_present: + acc_config = TPUAcceleratorConfig(topology="2x2", accelerator_version="v6e") + + scheduling_strategy = None + + def set_scheduling_strategy(actor_handle, *args, **kwargs): + nonlocal scheduling_strategy + scheduling_strategy = actor_handle._options["scheduling_strategy"] + + # Construct the scheduling request. + scheduling_request = ReplicaSchedulingRequest( + replica_id=r0_id, + actor_def=MockActorClass(), + actor_resources={"CPU": 1}, + placement_group_bundles=bundles, + accelerator_config=acc_config, + actor_options={"name": "r0"}, + actor_init_args=(), + on_scheduled=set_scheduling_strategy, + ) + + scheduler._pending_replicas[d_id][r0_id] = scheduling_request + + # Call _schedule_replica. + scheduler._schedule_replica( + scheduling_request=scheduling_request, + default_scheduling_strategy="some_default", + target_node_id=None, + target_labels=None, + ) + + # Verify scheduling params are as expected. + if expect_pg_created: + assert len(captured_requests) == 1 + assert captured_requests[0].accelerator_config == acc_config + assert captured_requests[0].bundles == bundles + else: + assert len(captured_requests) == 0 + assert scheduling_strategy == "some_default" + + def test_downscale_multiple_deployments(): """Test to make sure downscale prefers replicas without node id and then replicas on a node with fewest replicas of all deployments. 
From afb07a00f2bb3aad0c246732b4c8941604166b24 Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Tue, 12 May 2026 01:52:47 +0000 Subject: [PATCH 13/26] Safely unwrap ReplicaPlacementGroup for gangs and fix type alias Signed-off-by: Ryan O'Leary --- python/ray/serve/_private/default_impl.py | 6 ++++-- python/ray/serve/_private/deployment_scheduler.py | 14 +++++++++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/python/ray/serve/_private/default_impl.py b/python/ray/serve/_private/default_impl.py index 9fc87a330a33..1265c8adbd88 100644 --- a/python/ray/serve/_private/default_impl.py +++ b/python/ray/serve/_private/default_impl.py @@ -1,7 +1,7 @@ import asyncio import logging from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple, Union import ray from ray._common.constants import HEAD_NODE_RESOURCE_NAME @@ -160,7 +160,9 @@ def create_cluster_node_info_cache(gcs_client: GcsClient) -> ClusterNodeInfoCach return DefaultClusterNodeInfoCache(gcs_client) -CreatePlacementGroupFn = Callable[[CreatePlacementGroupRequest], PlacementGroup] +CreatePlacementGroupFn = Callable[ + [CreatePlacementGroupRequest], Union[PlacementGroup, ReplicaPlacementGroup] +] def _default_create_placement_group( diff --git a/python/ray/serve/_private/deployment_scheduler.py b/python/ray/serve/_private/deployment_scheduler.py index 4d1a4e3caa3d..3771835ccab3 100644 --- a/python/ray/serve/_private/deployment_scheduler.py +++ b/python/ray/serve/_private/deployment_scheduler.py @@ -899,7 +899,7 @@ def _prepare_gangs_for_deployment( ) try: - pg = self._create_placement_group_fn( + pg_result = self._create_placement_group_fn( CreatePlacementGroupRequest( bundles=bundles, strategy=request.gang_placement_strategy, @@ -909,6 +909,18 @@ def _prepare_gangs_for_deployment( fallback_strategy=fallback_strategy, ) ) + + # Unwrap the ReplicaPlacementGroup to get the underlying PlacementGroup. 
+ # Gang scheduling currently does not support accelerator_config (since it's + # handled by the specific accelerator backend), so we don't need the + # wrapper. Inline import here is required to avoid circular dependencies. + from ray.serve._private.default_impl import ReplicaPlacementGroup + + if isinstance(pg_result, ReplicaPlacementGroup): + pg = pg_result.placement_group + else: + pg = pg_result + gang_pgs.append(pg) gang_ids.append(gang_id) gang_pg_names.append(pg_name) From 70b6a6f0a6a27fc0f5ec787a3fef7ba8138f475f Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Tue, 12 May 2026 02:29:38 +0000 Subject: [PATCH 14/26] Fix placement group leakage on actor creation failure for custom overrides with accelerators Signed-off-by: Ryan O'Leary --- python/ray/serve/_private/deployment_scheduler.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/ray/serve/_private/deployment_scheduler.py b/python/ray/serve/_private/deployment_scheduler.py index 3771835ccab3..55e5e2a5bcba 100644 --- a/python/ray/serve/_private/deployment_scheduler.py +++ b/python/ray/serve/_private/deployment_scheduler.py @@ -750,10 +750,7 @@ def _schedule_replica( if scheduling_request.gang_placement_group is None: if replica_pg is not None: replica_pg.shutdown() - elif ( - placement_group is not None - and scheduling_request.placement_group_bundles is not None - ): + elif placement_group is not None: ray.util.remove_placement_group(placement_group) return False From 89dc61ff4d71e374461d6090498ff1cb6b3e9ec2 Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Tue, 12 May 2026 03:00:40 +0000 Subject: [PATCH 15/26] Release TPU reservation holders in cross-language replica startup path to prevent placement group leaks Signed-off-by: Ryan O'Leary --- python/ray/serve/_private/deployment_state.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index d6d2aa3812c3..a5960c6ceb02 100644 
--- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -1438,6 +1438,8 @@ def check_ready(self) -> Tuple[ReplicaStartupStatus, Optional[str]]: try: # TODO(simon): fully implement reconfigure for Java replicas. if self._is_cross_language: + if self._replica_pg is not None: + self._replica_pg.release_reservation_holders() return ReplicaStartupStatus.SUCCEEDED, None # todo: The replica's userconfig whitch java client created From f82eab9dfb27bc90bffa91f97b649820925f54d9 Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Tue, 12 May 2026 03:13:20 +0000 Subject: [PATCH 16/26] Remove redundant replica_pg reassignment in deployment scheduler Signed-off-by: Ryan O'Leary --- python/ray/serve/_private/deployment_scheduler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/ray/serve/_private/deployment_scheduler.py b/python/ray/serve/_private/deployment_scheduler.py index 55e5e2a5bcba..58f586a4a229 100644 --- a/python/ray/serve/_private/deployment_scheduler.py +++ b/python/ray/serve/_private/deployment_scheduler.py @@ -665,7 +665,6 @@ def _schedule_replica( # - The user provided an accelerator_config that derives its own # bundles from structured fields (e.g. TPUAcceleratorConfig # derives bundles from topology via slice_placement_group). 
- replica_pg = None placement_group_strategy = ( scheduling_request.placement_group_strategy if scheduling_request.placement_group_strategy From 19c8aeb131f4fa697ca9f930fe4fc1796afee29d Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Tue, 12 May 2026 03:33:05 +0000 Subject: [PATCH 17/26] Safeguard check_stopped placement group teardown with robust exception handling to prevent controller crashes Signed-off-by: Ryan O'Leary --- python/ray/serve/_private/deployment_state.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index a5960c6ceb02..e76f7f80e57a 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -1547,6 +1547,10 @@ def check_stopped(self) -> bool: logger.debug( f"Placement group for {self._replica_id} was already removed." ) + except Exception: + logger.exception( + f"Unexpected error shutting down placement groups for {self._replica_id}." + ) finally: # Always clear references to prevent memory leaks and dangling state. self._gang_placement_group = None From 8eab85a3930f7d59a3791327c197bf1c805495f5 Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Tue, 12 May 2026 11:45:30 +0000 Subject: [PATCH 18/26] fix gang pg cleanup to fix tests Signed-off-by: Ryan O'Leary --- python/ray/serve/_private/deployment_state.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index e76f7f80e57a..b9b3c9640061 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -1528,16 +1528,20 @@ def check_stopped(self) -> bool: # it was just killed above. if stopped: try: - # 1. Gang PGs are shared and managed by the DeploymentStateManager. - # We do nothing active here to avoid shutting them down prematurely. 
+ # Teardown shared gang placement group. The first replica in the + # gang to stop deletes it. Subsequent replicas catch ValueError. if self._gang_placement_group is not None: - pass + try: + ray.util.remove_placement_group(self._gang_placement_group) + except ValueError: + # Already removed by another replica in this gang. + pass - # 2. Replicas with accelerator/wrapper PGs handle their own shutdown. + # Replicas with accelerator/wrapper PGs handle their own shutdown. elif self._replica_pg is not None: self._replica_pg.shutdown() - # 3. Standard single-replica placement groups. + # Standard single-replica placement groups. elif self._placement_group is not None: try: ray.util.remove_placement_group(self._placement_group) @@ -1552,7 +1556,7 @@ def check_stopped(self) -> bool: f"Unexpected error shutting down placement groups for {self._replica_id}." ) finally: - # Always clear references to prevent memory leaks and dangling state. + # Clear references to prevent memory leaks and dangling state. 
self._gang_placement_group = None self._replica_pg = None self._placement_group = None From e84850ad52df0e3cf120c15fce902602d35453ab Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Tue, 12 May 2026 21:19:10 +0000 Subject: [PATCH 19/26] add check in api for accelerator_config and gang at same time Signed-off-by: Ryan O'Leary --- python/ray/serve/api.py | 15 +++++- .../tests/unit/test_accelerator_config.py | 48 ++++++++++++++++++- 2 files changed, 61 insertions(+), 2 deletions(-) diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index c54c9ccdcbb7..6695de3245d6 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -669,9 +669,22 @@ class MyDeployment: if isinstance(logging_config, LoggingConfig): logging_config = logging_config.model_dump() - if accelerator_config is not DEFAULT.VALUE: + if accelerator_config is not DEFAULT.VALUE and accelerator_config is not None: accelerator_config = _resolve_accelerator_config(accelerator_config) + if ( + gang_scheduling_config is not DEFAULT.VALUE + and gang_scheduling_config is not None + ): + # The only supported accelerator_config currently is for TPU, which utilizes + # SlicePlacementGroup internally for atomic scheduling of SPMD workers. This + # check can be loosened if additional accelerator configs are added in the + # future that don't manage their own gang scheduling. + raise ValueError( + "Cannot specify both `accelerator_config` and `gang_scheduling_config`. " + "Accelerator configurations automatically manage their own gang scheduling." 
+ ) + deployment_config = DeploymentConfig.from_default( num_replicas=num_replicas if num_replicas is not None else 1, user_config=user_config, diff --git a/python/ray/serve/tests/unit/test_accelerator_config.py b/python/ray/serve/tests/unit/test_accelerator_config.py index 518502228fd8..dbc5b12c0089 100644 --- a/python/ray/serve/tests/unit/test_accelerator_config.py +++ b/python/ray/serve/tests/unit/test_accelerator_config.py @@ -10,7 +10,7 @@ _create_replica_placement_group, ) from ray.serve.api import deployment -from ray.serve.config import TPUAcceleratorConfig +from ray.serve.config import GangSchedulingConfig, TPUAcceleratorConfig from ray.util.placement_group import PlacementGroup from ray.util.tpu import SlicePlacementGroup @@ -247,5 +247,51 @@ def mock_slice_pg(**kwargs): assert captured["accelerator_version"] == "v6e" +@pytest.mark.parametrize( + "options, should_raise", + [ + ({}, False), + ( + { + "accelerator_config": TPUAcceleratorConfig( + topology="2x2", accelerator_version="v6e" + ) + }, + False, + ), + ( + { + "gang_scheduling_config": GangSchedulingConfig(gang_size=2), + "num_replicas": 2, + }, + False, + ), + ( + { + "accelerator_config": TPUAcceleratorConfig( + topology="2x2", accelerator_version="v6e" + ), + "gang_scheduling_config": GangSchedulingConfig(gang_size=2), + "num_replicas": 2, + }, + True, + ), + ], +) +def test_deployment_config_mutual_exclusivity(options, should_raise): + """accelerator_config and gang_scheduling_config validation matrix.""" + + def create_deployment(): + @deployment(**options) + class D: + pass + + if should_raise: + with pytest.raises(ValueError, match="Cannot specify both"): + create_deployment() + else: + create_deployment() + + if __name__ == "__main__": sys.exit(pytest.main(["-v", "-s", __file__])) From acefb0631e5e2e20ef468811caf6153d1a1a2198 Mon Sep 17 00:00:00 2001 From: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com> Date: Thu, 21 May 2026 00:08:39 -0700 Subject: [PATCH 20/26] Apply 
suggestions from code review Co-authored-by: Jeffrey Wang Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com> --- python/ray/serve/_private/default_impl.py | 2 +- python/ray/serve/api.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/serve/_private/default_impl.py b/python/ray/serve/_private/default_impl.py index 1265c8adbd88..d2943067f87b 100644 --- a/python/ray/serve/_private/default_impl.py +++ b/python/ray/serve/_private/default_impl.py @@ -66,7 +66,7 @@ class ReplicaPlacementGroup: plain CPU/GPU PG or a TPU slice reservation. """ - placement_group: PlacementGroup + placement_group: Optional[PlacementGroup] _slice_pg: Optional[SlicePlacementGroup] = None def release_reservation_holders(self) -> None: diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 6695de3245d6..548259d52142 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -471,7 +471,7 @@ def _resolve_accelerator_config( if kind == "tpu": return TPUAcceleratorConfig(**value) raise ValueError( - f"Unknown accelerator kind {kind!r}. " f"Supported types: 'tpu'." + f"Unknown accelerator kind {kind!r}. Supported types: 'tpu'." ) raise TypeError( f"accelerator_config must be a dict or AcceleratorConfig, got {type(value)}." 
From deb9767b017b3dd9245c7becb5b45500cd76eaf6 Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Thu, 21 May 2026 08:13:47 +0000 Subject: [PATCH 21/26] remove circular dependency / import, add constants for commonly used strings, change from Dev API to PublicAPI, and fix other comments Signed-off-by: Ryan O'Leary --- python/ray/serve/_private/common.py | 23 +-- python/ray/serve/_private/default_impl.py | 127 ++--------------- .../serve/_private/deployment_scheduler.py | 12 +- python/ray/serve/_private/deployment_state.py | 10 +- .../serve/_private/placement_group_utils.py | 132 ++++++++++++++++++ python/ray/serve/_private/version.py | 9 ++ python/ray/serve/api.py | 7 +- python/ray/serve/config.py | 10 +- .../serve/tests/test_accelerator_config.py | 6 +- .../tests/unit/test_accelerator_config.py | 12 +- .../tests/unit/test_deployment_version.py | 47 +++++++ 11 files changed, 228 insertions(+), 167 deletions(-) create mode 100644 python/ray/serve/_private/placement_group_utils.py diff --git a/python/ray/serve/_private/common.py b/python/ray/serve/_private/common.py index 39ceedfb63e8..54b7d35f8d24 100644 --- a/python/ray/serve/_private/common.py +++ b/python/ray/serve/_private/common.py @@ -19,7 +19,7 @@ from ray.util.placement_group import PlacementGroup if TYPE_CHECKING: - from ray.serve.config import AcceleratorConfig + pass REPLICA_ID_FULL_ID_STR_PREFIX = "SERVE_REPLICA::" GANG_PG_NAME_PREFIX = "SERVE_GANG::" @@ -894,27 +894,6 @@ class ReplicaQueueLengthInfo: num_ongoing_requests: int -@dataclass(frozen=True) -class CreatePlacementGroupRequest: - """Internal request for creating a per-replica placement group. - - Either ``bundles`` or ``accelerator_config`` must be provided: - - For plain CPU/GPU deployments, the caller provides ``bundles`` and the - default path creates a standard PlacementGroup. - - For accelerator deployments (e.g. TPU), the caller provides - ``accelerator_config`` and the dispatch derives bundles from the - structured config (e.g. 
TPU topology -> per-host bundles). - """ - - bundles: Optional[List[Dict[str, float]]] = None - strategy: str = "PACK" - target_node_id: Optional[str] = None - name: str = "" - runtime_env: Optional[str] = None - bundle_label_selector: Optional[List[Dict[str, str]]] = None - fallback_strategy: Optional[List[Dict[str, Any]]] = None - accelerator_config: Optional["AcceleratorConfig"] = None - lifetime: Optional[str] = "detached" @dataclass diff --git a/python/ray/serve/_private/default_impl.py b/python/ray/serve/_private/default_impl.py index 4bd15b591aff..b741e8187481 100644 --- a/python/ray/serve/_private/default_impl.py +++ b/python/ray/serve/_private/default_impl.py @@ -1,7 +1,5 @@ import asyncio -import logging -from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import ray from ray._common.constants import HEAD_NODE_RESOURCE_NAME @@ -11,7 +9,6 @@ DefaultClusterNodeInfoCache, ) from ray.serve._private.common import ( - CreatePlacementGroupRequest, DeploymentHandleSource, DeploymentID, EndpointInfo, @@ -35,6 +32,11 @@ from ray.serve._private.event_loop_monitoring import EventLoopMonitor from ray.serve._private.grpc_util import gRPCGenericServer from ray.serve._private.handle_options import DynamicHandleOptions, InitHandleOptions +from ray.serve._private.placement_group_utils import ( + CreatePlacementGroupRequest, + ReplicaPlacementGroup, + _create_replica_placement_group, +) from ray.serve._private.router import CurrentLoopRouter, Router, SingletonThreadRouter from ray.serve._private.utils import ( asyncio_grpc_exception_handler, @@ -44,9 +46,8 @@ inside_ray_client_context, resolve_deployment_response, ) -from ray.serve.config import ControllerOptions, TPUAcceleratorConfig -from ray.util.placement_group import PlacementGroup, remove_placement_group -from ray.util.tpu import SlicePlacementGroup, slice_placement_group +from ray.serve.config import ControllerOptions 
+from ray.util.placement_group import PlacementGroup # NOTE: Please read carefully before changing! # @@ -55,105 +56,6 @@ # API modified w/o substantial enough justification -@dataclass -class ReplicaPlacementGroup: - """Internal Serve handle for a replica's placement group(s). - - Wraps the worker PG and any accelerator-specific cleanup hooks so the - controller doesn't need to know whether the underlying request was a - plain CPU/GPU PG or a TPU slice reservation. - """ - - placement_group: Optional[PlacementGroup] - _slice_pg: Optional[SlicePlacementGroup] = None - - def release_reservation_holders(self) -> None: - """Call after ``placement_group.ready()`` resolves successfully. - - Releases any internal reservation-holder PGs (e.g. TPU head PGs) - that were only needed to claim resources during scheduling. No-op - for non-accelerator deployments. - """ - if self._slice_pg is not None: - self._slice_pg.release_head_pgs() - - def shutdown(self) -> None: - """Tear down the replica's PG(s). Idempotent.""" - if self._slice_pg is not None: - self._slice_pg.shutdown() - self._slice_pg = None - self.placement_group = None - elif self.placement_group is not None: - try: - remove_placement_group(self.placement_group) - except Exception: - logger.exception("Failed to remove placement group.") - finally: - self.placement_group = None - - -def _create_replica_placement_group( - request: CreatePlacementGroupRequest, -) -> ReplicaPlacementGroup: - """Internal entry point that supports accelerator-specific dispatch. - - Dispatches on ``request.accelerator_config``: - - TPUAcceleratorConfig: derive bundles from topology via - slice_placement_group; ``request.bundles`` is ignored. - - None: use ``request.bundles`` to create a standard PlacementGroup. 
- - Raises ValueError if neither bundles nor a recognized accelerator - config is provided - this catches users setting an unrecognized - accelerator_config type without explicit bundles, which would - otherwise schedule with no PG at all. - """ - accelerator_config = request.accelerator_config - - if isinstance(accelerator_config, TPUAcceleratorConfig): - slice_pg = _default_create_tpu_placement_group( - tpu_config=accelerator_config, - strategy=request.strategy, - name=request.name, - lifetime=request.lifetime, - bundle_label_selector=request.bundle_label_selector, - ) - return ReplicaPlacementGroup( - placement_group=slice_pg.placement_group, - _slice_pg=slice_pg, - ) - - if request.bundles is None: - raise ValueError( - "CreatePlacementGroupRequest requires either non-None bundles " - "or a recognized accelerator_config. Got accelerator_config=" - f"{type(accelerator_config).__name__ if accelerator_config else None}, " - "bundles=None." - ) - - pg = _default_create_placement_group(request) - return ReplicaPlacementGroup(placement_group=pg) - - -def _default_create_tpu_placement_group( - tpu_config: TPUAcceleratorConfig, - strategy: str, - name: str, - lifetime: Optional[str], - bundle_label_selector: Optional[List[Dict[str, str]]] = None, -) -> SlicePlacementGroup: - return slice_placement_group( - topology=tpu_config.topology, - accelerator_version=tpu_config.accelerator_version, - num_slices=tpu_config.num_slices, - chips_per_vm=tpu_config.chips_per_vm, - resources_per_bundle=tpu_config.resources_per_bundle, - strategy=strategy, - name=name, - lifetime=lifetime, - bundle_label_selector=bundle_label_selector, - ) - - def create_cluster_node_info_cache(gcs_client: GcsClient) -> ClusterNodeInfoCache: return DefaultClusterNodeInfoCache(gcs_client) @@ -163,19 +65,6 @@ def create_cluster_node_info_cache(gcs_client: GcsClient) -> ClusterNodeInfoCach ] -def _default_create_placement_group( - request: CreatePlacementGroupRequest, -) -> PlacementGroup: - return 
ray.util.placement_group( - request.bundles, - request.strategy, - _soft_target_node_id=request.target_node_id, - name=request.name, - lifetime=request.lifetime, - bundle_label_selector=request.bundle_label_selector, - ) - - def create_deployment_scheduler( cluster_node_info_cache: ClusterNodeInfoCache, head_node_id_override: Optional[str] = None, diff --git a/python/ray/serve/_private/deployment_scheduler.py b/python/ray/serve/_private/deployment_scheduler.py index 58f586a4a229..d974690a2fe4 100644 --- a/python/ray/serve/_private/deployment_scheduler.py +++ b/python/ray/serve/_private/deployment_scheduler.py @@ -14,7 +14,6 @@ from ray.serve._private.cluster_node_info_cache import ClusterNodeInfoCache from ray.serve._private.common import ( GANG_PG_NAME_PREFIX, - CreatePlacementGroupRequest, DeploymentID, GangPlacementGroupRequest, GangReservationResult, @@ -27,6 +26,10 @@ RAY_SERVE_USE_PACK_SCHEDULING_STRATEGY, SERVE_LOGGER_NAME, ) +from ray.serve._private.placement_group_utils import ( + CreatePlacementGroupRequest, + ReplicaPlacementGroup, +) from ray.serve.config import AcceleratorConfig from ray.util.placement_group import PlacementGroup from ray.util.scheduling_strategies import ( @@ -681,8 +684,7 @@ def _schedule_replica( accelerator_config=scheduling_request.accelerator_config, ), ) - # Import ReplicaPlacementGroup inline here to avoid circular dependency with default_impl - from ray.serve._private.default_impl import ReplicaPlacementGroup + # Statically imported from placement_group_utils if isinstance(pg_result, ReplicaPlacementGroup): placement_group = pg_result.placement_group @@ -909,9 +911,7 @@ def _prepare_gangs_for_deployment( # Unwrap the ReplicaPlacementGroup to get the underyling PlacementGroup. # Gang scheduling currently does not support accelerator_config (since it's # handled by the specific accelerator backend), so we don't need the - # wrapper. Inline import here is required to avoid circular dependencies. 
- from ray.serve._private.default_impl import ReplicaPlacementGroup - + # wrapper. if isinstance(pg_result, ReplicaPlacementGroup): pg = pg_result.placement_group else: diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index f753a0155c7e..d7a1e4b10351 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -10,15 +10,12 @@ from copy import copy from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple +from typing import Any, Callable, Dict, List, Optional, Set, Tuple import ray from ray import ObjectRef, cloudpickle from ray._common import ray_constants from ray.actor import ActorHandle - -if TYPE_CHECKING: - from ray.serve._private.default_impl import ReplicaPlacementGroup from ray.exceptions import ( RayActorError, RayError, @@ -79,6 +76,7 @@ ) from ray.serve._private.exceptions import DeploymentIsBeingDeletedError from ray.serve._private.long_poll import LongPollHost, LongPollNamespace +from ray.serve._private.placement_group_utils import ReplicaPlacementGroup from ray.serve._private.storage.kv_store import KVStoreBase from ray.serve._private.usage import ServeUsageTag from ray.serve._private.utils import ( @@ -1542,7 +1540,9 @@ def check_stopped(self) -> bool: ray.util.remove_placement_group(self._gang_placement_group) except ValueError: # Already removed by another replica in this gang. - pass + logger.debug( + f"Gang placement group for {self._replica_id} was already removed." + ) # Replicas with accelerator/wrapper PGs handle their own shutdown. 
elif self._replica_pg is not None: diff --git a/python/ray/serve/_private/placement_group_utils.py b/python/ray/serve/_private/placement_group_utils.py new file mode 100644 index 000000000000..fa7ed622ffbe --- /dev/null +++ b/python/ray/serve/_private/placement_group_utils.py @@ -0,0 +1,132 @@ +import logging +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import ray +from ray.util.placement_group import PlacementGroup, remove_placement_group +from ray.util.tpu import SlicePlacementGroup, slice_placement_group +from ray.serve._private.constants import SERVE_LOGGER_NAME + +logger = logging.getLogger(SERVE_LOGGER_NAME) + +# NOTE: Please read carefully before changing! +# +# Similar to `default_impl.py`, methods like `_default_create_placement_group` are +# common extension points and should be treated as a Developer API. + +@dataclass(frozen=True) +class CreatePlacementGroupRequest: + """Internal request for creating a per-replica placement group. + + Either ``bundles`` or ``accelerator_config`` must be provided: + - For plain CPU/GPU deployments, the caller provides ``bundles`` and the + default path creates a standard PlacementGroup. + - For accelerator deployments (e.g. TPU), the caller provides + ``accelerator_config`` and the dispatch derives bundles from the + structured config (e.g. TPU topology -> per-host bundles). + """ + + bundles: Optional[List[Dict[str, float]]] = None + strategy: str = "PACK" + target_node_id: Optional[str] = None + name: str = "" + runtime_env: Optional[str] = None + bundle_label_selector: Optional[List[Dict[str, str]]] = None + fallback_strategy: Optional[List[Dict[str, Any]]] = None + accelerator_config: Optional[Any] = None + + +@dataclass +class ReplicaPlacementGroup: + """Internal Serve handle for a replica's placement group(s). 
+ + Wraps the worker PG and any accelerator-specific cleanup hooks so the + controller doesn't need to know whether the underlying request was a + plain CPU/GPU PG or a TPU slice reservation. + """ + + placement_group: Optional[PlacementGroup] + _slice_pg: Optional[SlicePlacementGroup] = None + + def release_reservation_holders(self) -> None: + """Call after ``placement_group.ready()`` resolves successfully. + + Releases any internal reservation-holder PGs (e.g. TPU head PGs) + that were only needed to claim resources during scheduling. No-op + for non-accelerator deployments. + """ + if self._slice_pg is not None: + self._slice_pg.release_head_pgs() + + def shutdown(self) -> None: + """Tear down the replica's PG(s). Idempotent.""" + if self._slice_pg is not None: + self._slice_pg.shutdown() + self._slice_pg = None + self.placement_group = None + elif self.placement_group is not None: + try: + remove_placement_group(self.placement_group) + except Exception: + logger.exception("Failed to remove placement group.") + finally: + self.placement_group = None + + +def _default_create_placement_group( + request: CreatePlacementGroupRequest, +) -> PlacementGroup: + return ray.util.placement_group( + request.bundles, + request.strategy, + _soft_target_node_id=request.target_node_id, + name=request.name, + lifetime="detached", + bundle_label_selector=request.bundle_label_selector, + ) + + +def _create_replica_placement_group( + request: CreatePlacementGroupRequest, +) -> ReplicaPlacementGroup: + """Internal entry point that supports accelerator-specific dispatch. + + Dispatches on ``request.accelerator_config``: + - TPUAcceleratorConfig: derive bundles from topology via + slice_placement_group; ``request.bundles`` is ignored. + - None: use ``request.bundles`` to create a standard PlacementGroup. 
+ + Raises ValueError if neither bundles nor a recognized accelerator + config is provided - this catches users setting an unrecognized + accelerator_config type without explicit bundles, which would + otherwise schedule with no PG at all. + """ + accelerator_config = request.accelerator_config + + if getattr(accelerator_config, "kind", None) == "tpu": + slice_pg = slice_placement_group( + topology=accelerator_config.topology, + accelerator_version=accelerator_config.accelerator_version, + num_slices=accelerator_config.num_slices, + chips_per_vm=accelerator_config.chips_per_vm, + resources_per_bundle=accelerator_config.resources_per_bundle, + strategy=request.strategy, + name=request.name, + lifetime="detached", + bundle_label_selector=request.bundle_label_selector, + ) + return ReplicaPlacementGroup( + placement_group=slice_pg.placement_group, + _slice_pg=slice_pg, + ) + + if request.bundles is None: + raise ValueError( + "CreatePlacementGroupRequest requires either non-None bundles " + "or a recognized accelerator_config. Got accelerator_config=" + f"{type(accelerator_config).__name__ if accelerator_config else None}, " + "bundles=None." 
+ ) + + pg = _default_create_placement_group(request) + return ReplicaPlacementGroup(placement_group=pg) diff --git a/python/ray/serve/_private/version.py b/python/ray/serve/_private/version.py index a0aea7a2a63d..743755ada42a 100644 --- a/python/ray/serve/_private/version.py +++ b/python/ray/serve/_private/version.py @@ -79,6 +79,8 @@ def requires_actor_restart(self, new_version): or self.max_replicas_per_node != new_version.max_replicas_per_node or self.gang_scheduling_config_hash != new_version.gang_scheduling_config_hash + or self.accelerator_config_hash + != new_version.accelerator_config_hash ) def requires_actor_reconfigure(self, new_version): @@ -124,6 +126,12 @@ def compute_hashes(self): else {} ) self.gang_scheduling_config_hash = crc32(serialized_gang_scheduling_config) + serialized_accelerator_config = _serialize( + self.deployment_config.accelerator_config.model_dump() + if self.deployment_config.accelerator_config is not None + else {} + ) + self.accelerator_config_hash = crc32(serialized_accelerator_config) # Include app-level route prefix in the version hashes so changing # it triggers an in-place reconfigure of running replicas. serialized_route_prefix = _serialize(self.route_prefix) @@ -152,6 +160,7 @@ def compute_hashes(self): ] ) + serialized_gang_scheduling_config + + serialized_accelerator_config ) def to_proto(self) -> bytes: diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index e5e279b2d83f..55183b9b5edb 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -40,6 +40,7 @@ wait_for_interrupt, ) from ray.serve.config import ( + ACCELERATOR_KIND_TPU, AcceleratorConfig, AutoscalingConfig, ControllerOptions, @@ -477,7 +478,7 @@ def _resolve_accelerator_config( return value if isinstance(value, dict): kind = value.get("kind") - if kind == "tpu": + if kind == ACCELERATOR_KIND_TPU: return TPUAcceleratorConfig(**value) raise ValueError( f"Unknown accelerator kind {kind!r}. Supported types: 'tpu'." 
@@ -685,6 +686,10 @@ class MyDeployment: gang_scheduling_config is not DEFAULT.VALUE and gang_scheduling_config is not None ): + # TODO(ryanaoleary@): Revisit this mutual exclusivity restriction once + # Data Parallel (DP) attention or more complex multi-slice gang + # scheduling is supported for TPUs. + # # The only supported accelerator_config currently is for TPU, which utilizes # SlicePlacementGroup internally for atomic scheduling of SPMD workers. This # check can be loosened if additional accelerator configs are added in the diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index 48c3a09f53f4..2af55f6750fd 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -42,10 +42,12 @@ SERVE_LOGGER_NAME, ) from ray.serve._private.utils import validate_ssl_config -from ray.util.annotations import DeveloperAPI, PublicAPI +from ray.util.annotations import PublicAPI logger = logging.getLogger(SERVE_LOGGER_NAME) +ACCELERATOR_KIND_TPU = "tpu" + @PublicAPI(stability="stable") class AutoscalingContext: @@ -709,7 +711,7 @@ def get_target_ongoing_requests(self) -> PositiveFloat: return self.target_ongoing_requests -@DeveloperAPI(stability="alpha") +@PublicAPI(stability="alpha") class AcceleratorConfig(BaseModel): """Base class for structured accelerator configurations. @@ -725,7 +727,7 @@ class AcceleratorConfig(BaseModel): model_config = {"frozen": True, "extra": "forbid"} -@DeveloperAPI(stability="alpha") +@PublicAPI(stability="alpha") class TPUAcceleratorConfig(AcceleratorConfig): """TPU slice specification for a Serve deployment. @@ -744,7 +746,7 @@ class TPUAcceleratorConfig(AcceleratorConfig): >>> config = TPUAcceleratorConfig(topology="4x4", accelerator_version="v6e") """ - kind: Literal["tpu"] = "tpu" + kind: Literal["tpu"] = ACCELERATOR_KIND_TPU topology: str = Field( ..., description="TPU pod topology, e.g. '2x2', '4x4', '2x2x2'." 
diff --git a/python/ray/serve/tests/test_accelerator_config.py b/python/ray/serve/tests/test_accelerator_config.py index 89ccded4db2e..7ef34d276dfd 100644 --- a/python/ray/serve/tests/test_accelerator_config.py +++ b/python/ray/serve/tests/test_accelerator_config.py @@ -6,8 +6,8 @@ import ray from ray import serve from ray.cluster_utils import Cluster -from ray.serve._private.common import CreatePlacementGroupRequest -from ray.serve._private.default_impl import ( +from ray.serve._private.placement_group_utils import ( + CreatePlacementGroupRequest, ReplicaPlacementGroup, _create_replica_placement_group, ) @@ -69,7 +69,6 @@ def test_tpu_accelerator_config_integration(mock_tpu_cluster): target_node_id=None, name="test-tpu-pg", accelerator_config=tpu_config, - lifetime="detached", ) # This should call _create_tpu_placement_group and return a wrapper @@ -104,7 +103,6 @@ def test_tpu_accelerator_config_partial_failure_cleanup(mock_tpu_cluster): target_node_id=None, name="test-tpu-timeout-pg", accelerator_config=tpu_config, - lifetime="detached", ) # Patch remove_placement_group where it is USED (ray.util.tpu) diff --git a/python/ray/serve/tests/unit/test_accelerator_config.py b/python/ray/serve/tests/unit/test_accelerator_config.py index dbc5b12c0089..5f3756c5b769 100644 --- a/python/ray/serve/tests/unit/test_accelerator_config.py +++ b/python/ray/serve/tests/unit/test_accelerator_config.py @@ -4,8 +4,8 @@ import pytest from pydantic import ValidationError -from ray.serve._private.common import CreatePlacementGroupRequest -from ray.serve._private.default_impl import ( +from ray.serve._private.placement_group_utils import ( + CreatePlacementGroupRequest, ReplicaPlacementGroup, _create_replica_placement_group, ) @@ -115,7 +115,7 @@ def test_placement_group_creation_types(with_accelerator): mock_slice_pg = MagicMock() mock_slice_pg.placement_group = mock_pg with patch( - "ray.serve._private.default_impl.slice_placement_group", + 
"ray.serve._private.placement_group_utils.slice_placement_group", return_value=mock_slice_pg, ): result = _create_replica_placement_group(request) @@ -159,7 +159,7 @@ def test_replica_pg_shutdown_idempotent(with_accelerator): adapter = ReplicaPlacementGroup(placement_group=mock_pg) with patch( - "ray.serve._private.default_impl.remove_placement_group" + "ray.serve._private.placement_group_utils.remove_placement_group" ) as mock_remove: adapter.shutdown() mock_remove.assert_called_once_with(mock_pg) @@ -201,7 +201,7 @@ def test_create_replica_placement_group_tpu_ignores_bundles(): mock_slice_pg.placement_group = MagicMock(spec=PlacementGroup) with patch( - "ray.serve._private.default_impl.slice_placement_group", + "ray.serve._private.placement_group_utils.slice_placement_group", return_value=mock_slice_pg, ) as mock_slice_pg_func: result = _create_replica_placement_group(request) @@ -223,7 +223,7 @@ def mock_slice_pg(**kwargs): return mock monkeypatch.setattr( - "ray.serve._private.default_impl.slice_placement_group", + "ray.serve._private.placement_group_utils.slice_placement_group", mock_slice_pg, ) diff --git a/python/ray/serve/tests/unit/test_deployment_version.py b/python/ray/serve/tests/unit/test_deployment_version.py index 6a4779769fba..7e7e3b6d3dff 100644 --- a/python/ray/serve/tests/unit/test_deployment_version.py +++ b/python/ray/serve/tests/unit/test_deployment_version.py @@ -2,6 +2,7 @@ from ray.serve._private.config import DeploymentConfig from ray.serve._private.deployment_state import DeploymentVersion +from ray.serve.config import TPUAcceleratorConfig def test_validation(): @@ -406,6 +407,52 @@ def test_requires_long_poll_broadcast(): assert not v1.requires_long_poll_broadcast(v2) +def test_accelerator_config(): + v1 = DeploymentVersion("1", DeploymentConfig(), {"num_cpus": 0.1}) + v2 = DeploymentVersion( + "1", + DeploymentConfig( + accelerator_config=TPUAcceleratorConfig( + topology="4x4", accelerator_version="v6e" + ) + ), + {"num_cpus": 0.1}, + ) 
+ v3 = DeploymentVersion( + "1", + DeploymentConfig( + accelerator_config=TPUAcceleratorConfig( + topology="4x4", accelerator_version="v6e" + ) + ), + {"num_cpus": 0.1}, + ) + v4 = DeploymentVersion( + "1", + DeploymentConfig( + accelerator_config=TPUAcceleratorConfig( + topology="2x2", accelerator_version="v6e" + ) + ), + {"num_cpus": 0.1}, + ) + + # Changing accelerator_config from None -> TPUAcceleratorConfig triggers restart + assert v1 != v2 + assert hash(v1) != hash(v2) + assert v1.requires_actor_restart(v2) + + # Same TPUAcceleratorConfig does not trigger restart + assert v2 == v3 + assert hash(v2) == hash(v3) + assert not v2.requires_actor_restart(v3) + + # Changing topology from 4x4 -> 2x2 triggers restart + assert v3 != v4 + assert hash(v3) != hash(v4) + assert v3.requires_actor_restart(v4) + + if __name__ == "__main__": import sys From 2f8282d1844599d8d93820bc731c420ca6734e8d Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Thu, 21 May 2026 08:25:38 +0000 Subject: [PATCH 22/26] run linter and remove empty type checking block Signed-off-by: Ryan O'Leary --- python/ray/serve/_private/common.py | 5 +---- python/ray/serve/_private/placement_group_utils.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/python/ray/serve/_private/common.py b/python/ray/serve/_private/common.py index 54b7d35f8d24..50d9a925b1ac 100644 --- a/python/ray/serve/_private/common.py +++ b/python/ray/serve/_private/common.py @@ -1,7 +1,7 @@ import json from dataclasses import asdict, dataclass, field from enum import Enum -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional +from typing import Any, Awaitable, Callable, Dict, List, Optional from starlette.types import Scope @@ -18,9 +18,6 @@ from ray.util.annotations import PublicAPI from ray.util.placement_group import PlacementGroup -if TYPE_CHECKING: - pass - REPLICA_ID_FULL_ID_STR_PREFIX = "SERVE_REPLICA::" GANG_PG_NAME_PREFIX = "SERVE_GANG::" diff --git 
a/python/ray/serve/_private/placement_group_utils.py b/python/ray/serve/_private/placement_group_utils.py index fa7ed622ffbe..1e4e6af9102f 100644 --- a/python/ray/serve/_private/placement_group_utils.py +++ b/python/ray/serve/_private/placement_group_utils.py @@ -3,9 +3,9 @@ from typing import Any, Dict, List, Optional import ray +from ray.serve._private.constants import SERVE_LOGGER_NAME from ray.util.placement_group import PlacementGroup, remove_placement_group from ray.util.tpu import SlicePlacementGroup, slice_placement_group -from ray.serve._private.constants import SERVE_LOGGER_NAME logger = logging.getLogger(SERVE_LOGGER_NAME) From a95f272924cb14d0bfb5a8f8ce724f203ab48771 Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Thu, 21 May 2026 08:30:41 +0000 Subject: [PATCH 23/26] run lint again and remove unneeded comment Signed-off-by: Ryan O'Leary --- python/ray/serve/_private/common.py | 2 -- python/ray/serve/_private/deployment_scheduler.py | 2 -- python/ray/serve/_private/placement_group_utils.py | 1 + python/ray/serve/_private/version.py | 3 +-- python/ray/serve/api.py | 4 +--- 5 files changed, 3 insertions(+), 9 deletions(-) diff --git a/python/ray/serve/_private/common.py b/python/ray/serve/_private/common.py index 50d9a925b1ac..9b18e1bca057 100644 --- a/python/ray/serve/_private/common.py +++ b/python/ray/serve/_private/common.py @@ -891,8 +891,6 @@ class ReplicaQueueLengthInfo: num_ongoing_requests: int - - @dataclass class GangPlacementGroupRequest: """Request to reserve gang placement groups for a deployment.""" diff --git a/python/ray/serve/_private/deployment_scheduler.py b/python/ray/serve/_private/deployment_scheduler.py index d974690a2fe4..3e8dddf8f3bf 100644 --- a/python/ray/serve/_private/deployment_scheduler.py +++ b/python/ray/serve/_private/deployment_scheduler.py @@ -684,8 +684,6 @@ def _schedule_replica( accelerator_config=scheduling_request.accelerator_config, ), ) - # Statically imported from placement_group_utils - if 
isinstance(pg_result, ReplicaPlacementGroup): placement_group = pg_result.placement_group replica_pg = pg_result diff --git a/python/ray/serve/_private/placement_group_utils.py b/python/ray/serve/_private/placement_group_utils.py index 1e4e6af9102f..6795f54a2df9 100644 --- a/python/ray/serve/_private/placement_group_utils.py +++ b/python/ray/serve/_private/placement_group_utils.py @@ -14,6 +14,7 @@ # Similar to `default_impl.py`, methods like `_default_create_placement_group` are # common extension points and should be treated as a Developer API. + @dataclass(frozen=True) class CreatePlacementGroupRequest: """Internal request for creating a per-replica placement group. diff --git a/python/ray/serve/_private/version.py b/python/ray/serve/_private/version.py index 743755ada42a..9a71b8652fd3 100644 --- a/python/ray/serve/_private/version.py +++ b/python/ray/serve/_private/version.py @@ -79,8 +79,7 @@ def requires_actor_restart(self, new_version): or self.max_replicas_per_node != new_version.max_replicas_per_node or self.gang_scheduling_config_hash != new_version.gang_scheduling_config_hash - or self.accelerator_config_hash - != new_version.accelerator_config_hash + or self.accelerator_config_hash != new_version.accelerator_config_hash ) def requires_actor_reconfigure(self, new_version): diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 55183b9b5edb..7c618ba13d46 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -480,9 +480,7 @@ def _resolve_accelerator_config( kind = value.get("kind") if kind == ACCELERATOR_KIND_TPU: return TPUAcceleratorConfig(**value) - raise ValueError( - f"Unknown accelerator kind {kind!r}. Supported types: 'tpu'." - ) + raise ValueError(f"Unknown accelerator kind {kind!r}. Supported types: 'tpu'.") raise TypeError( f"accelerator_config must be a dict or AcceleratorConfig, got {type(value)}." 
) From 26ae3e27f6deb6e13e6b61017b2e5fa5f9f95c9b Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Thu, 21 May 2026 08:55:35 +0000 Subject: [PATCH 24/26] move constant to constants.py, remove release_reservation_holders and change default to SPREAD strategy Signed-off-by: Ryan O'Leary --- python/ray/serve/_private/constants.py | 2 + .../serve/_private/deployment_scheduler.py | 8 +- python/ray/serve/_private/deployment_state.py | 5 - .../serve/_private/placement_group_utils.py | 4 +- python/ray/serve/config.py | 3 +- .../tests/unit/test_deployment_scheduler.py | 96 +++++++++++++++++++ 6 files changed, 108 insertions(+), 10 deletions(-) diff --git a/python/ray/serve/_private/constants.py b/python/ray/serve/_private/constants.py index 952e1aff90f4..bcd5c3dafe15 100644 --- a/python/ray/serve/_private/constants.py +++ b/python/ray/serve/_private/constants.py @@ -43,6 +43,8 @@ #: Ray namespace used for all Serve actors SERVE_NAMESPACE = "serve" +ACCELERATOR_KIND_TPU = "tpu" + DEFAULT_HTTP_HOST = os.environ.get("RAY_SERVE_DEFAULT_HTTP_HOST") #: HTTP Port diff --git a/python/ray/serve/_private/deployment_scheduler.py b/python/ray/serve/_private/deployment_scheduler.py index 3e8dddf8f3bf..2106be1cfdc9 100644 --- a/python/ray/serve/_private/deployment_scheduler.py +++ b/python/ray/serve/_private/deployment_scheduler.py @@ -21,6 +21,7 @@ ) from ray.serve._private.config import ReplicaConfig from ray.serve._private.constants import ( + ACCELERATOR_KIND_TPU, RAY_SERVE_HIGH_PRIORITY_CUSTOM_RESOURCES, RAY_SERVE_USE_COMPACT_SCHEDULING_STRATEGY, RAY_SERVE_USE_PACK_SCHEDULING_STRATEGY, @@ -671,7 +672,12 @@ def _schedule_replica( placement_group_strategy = ( scheduling_request.placement_group_strategy if scheduling_request.placement_group_strategy - else "PACK" + else ( + "SPREAD" + if getattr(scheduling_request.accelerator_config, "kind", None) + == ACCELERATOR_KIND_TPU + else "PACK" + ) ) try: pg_result = self._create_placement_group_fn( diff --git 
a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index d7a1e4b10351..262524da94b3 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -1436,8 +1436,6 @@ def check_ready(self) -> Tuple[ReplicaStartupStatus, Optional[str]]: try: # TODO(simon): fully implement reconfigure for Java replicas. if self._is_cross_language: - if self._replica_pg is not None: - self._replica_pg.release_reservation_holders() return ReplicaStartupStatus.SUCCEEDED, None # todo: The replica's userconfig whitch java client created @@ -1474,9 +1472,6 @@ def check_ready(self) -> Tuple[ReplicaStartupStatus, Optional[str]]: ) return ReplicaStartupStatus.FAILED, repr(e) - if self._replica_pg is not None: - self._replica_pg.release_reservation_holders() - return ReplicaStartupStatus.SUCCEEDED, None @property diff --git a/python/ray/serve/_private/placement_group_utils.py b/python/ray/serve/_private/placement_group_utils.py index 6795f54a2df9..818fc71c902c 100644 --- a/python/ray/serve/_private/placement_group_utils.py +++ b/python/ray/serve/_private/placement_group_utils.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional import ray -from ray.serve._private.constants import SERVE_LOGGER_NAME +from ray.serve._private.constants import ACCELERATOR_KIND_TPU, SERVE_LOGGER_NAME from ray.util.placement_group import PlacementGroup, remove_placement_group from ray.util.tpu import SlicePlacementGroup, slice_placement_group @@ -104,7 +104,7 @@ def _create_replica_placement_group( """ accelerator_config = request.accelerator_config - if getattr(accelerator_config, "kind", None) == "tpu": + if getattr(accelerator_config, "kind", None) == ACCELERATOR_KIND_TPU: slice_pg = slice_placement_group( topology=accelerator_config.topology, accelerator_version=accelerator_config.accelerator_version, diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index 2af55f6750fd..274d2b7538ad 100644 
--- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -27,6 +27,7 @@ # Import types needed for AutoscalingContext from ray.serve._private.common import DeploymentID, ReplicaID, TimeSeries from ray.serve._private.constants import ( + ACCELERATOR_KIND_TPU, DEFAULT_AUTOSCALING_POLICY_NAME, DEFAULT_GRPC_PORT, DEFAULT_HTTP_HOST, @@ -46,8 +47,6 @@ logger = logging.getLogger(SERVE_LOGGER_NAME) -ACCELERATOR_KIND_TPU = "tpu" - @PublicAPI(stability="stable") class AutoscalingContext: diff --git a/python/ray/serve/tests/unit/test_deployment_scheduler.py b/python/ray/serve/tests/unit/test_deployment_scheduler.py index 07e62d6e20e7..a34ccd41eef8 100644 --- a/python/ray/serve/tests/unit/test_deployment_scheduler.py +++ b/python/ray/serve/tests/unit/test_deployment_scheduler.py @@ -793,6 +793,102 @@ def set_scheduling_strategy(actor_handle, *args, **kwargs): assert scheduling_strategy == "some_default" +def test_placement_group_strategy_defaulting(): + """Validate that placement group strategy defaults to SPREAD for TPU configs and PACK for standard.""" + d_id = DeploymentID("strategy_test", "app1") + cluster_node_info_cache = MockClusterNodeInfoCache() + captured_requests = [] + + def mock_create_pg(request): + captured_requests.append(request) + return default_impl.ReplicaPlacementGroup( + placement_group=MockPlacementGroup(request) + ) + + scheduler = default_impl.create_deployment_scheduler( + cluster_node_info_cache, + head_node_id_override="fake-head-node-id", + create_placement_group_fn_override=mock_create_pg, + ) + scheduler.on_deployment_created(d_id, SpreadDeploymentSchedulingPolicy()) + scheduler.on_deployment_deployed(d_id, rconfig(ray_actor_options={"num_cpus": 1})) + + # Case 1: TPU Accelerator Config is set, placement_group_strategy is not. + # Expect strategy defaults to "SPREAD". 
+ r0_id = ReplicaID(unique_id="r0", deployment_id=d_id) + acc_config = TPUAcceleratorConfig(topology="2x2", accelerator_version="v6e") + req_tpu = ReplicaSchedulingRequest( + replica_id=r0_id, + actor_def=MockActorClass(), + actor_resources={"CPU": 1}, + placement_group_bundles=None, + accelerator_config=acc_config, + placement_group_strategy=None, + actor_options={"name": "r0"}, + actor_init_args=(), + on_scheduled=lambda *args, **kwargs: None, + ) + scheduler._pending_replicas[d_id][r0_id] = req_tpu + scheduler._schedule_replica( + scheduling_request=req_tpu, + default_scheduling_strategy="some_default", + target_node_id=None, + target_labels=None, + ) + assert len(captured_requests) == 1 + assert captured_requests[0].strategy == "SPREAD" + + captured_requests.clear() + + # Case 2: TPU Accelerator Config is set, and placement_group_strategy is explicitly provided. + # Expect strategy is respected. + req_tpu_explicit = ReplicaSchedulingRequest( + replica_id=r0_id, + actor_def=MockActorClass(), + actor_resources={"CPU": 1}, + placement_group_bundles=None, + accelerator_config=acc_config, + placement_group_strategy="STRICT_PACK", + actor_options={"name": "r0"}, + actor_init_args=(), + on_scheduled=lambda *args, **kwargs: None, + ) + scheduler._pending_replicas[d_id][r0_id] = req_tpu_explicit + scheduler._schedule_replica( + scheduling_request=req_tpu_explicit, + default_scheduling_strategy="some_default", + target_node_id=None, + target_labels=None, + ) + assert len(captured_requests) == 1 + assert captured_requests[0].strategy == "STRICT_PACK" + + captured_requests.clear() + + # Case 3: Standard GPU/CPU config (bundles set), placement_group_strategy is not. + # Expect strategy defaults to "PACK". 
+ req_std = ReplicaSchedulingRequest( + replica_id=r0_id, + actor_def=MockActorClass(), + actor_resources={"CPU": 1}, + placement_group_bundles=[{"CPU": 1}], + accelerator_config=None, + placement_group_strategy=None, + actor_options={"name": "r0"}, + actor_init_args=(), + on_scheduled=lambda *args, **kwargs: None, + ) + scheduler._pending_replicas[d_id][r0_id] = req_std + scheduler._schedule_replica( + scheduling_request=req_std, + default_scheduling_strategy="some_default", + target_node_id=None, + target_labels=None, + ) + assert len(captured_requests) == 1 + assert captured_requests[0].strategy == "PACK" + + def test_downscale_multiple_deployments(): """Test to make sure downscale prefers replicas without node id and then replicas on a node with fewest replicas of all deployments. From 234f19b1c49f0f84405611fcb06e3e546ee036a9 Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Thu, 21 May 2026 09:04:28 +0000 Subject: [PATCH 25/26] remove unused function Signed-off-by: Ryan O'Leary --- python/ray/serve/_private/placement_group_utils.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/python/ray/serve/_private/placement_group_utils.py b/python/ray/serve/_private/placement_group_utils.py index 818fc71c902c..dbfd86342c26 100644 --- a/python/ray/serve/_private/placement_group_utils.py +++ b/python/ray/serve/_private/placement_group_utils.py @@ -49,16 +49,6 @@ class ReplicaPlacementGroup: placement_group: Optional[PlacementGroup] _slice_pg: Optional[SlicePlacementGroup] = None - def release_reservation_holders(self) -> None: - """Call after ``placement_group.ready()`` resolves successfully. - - Releases any internal reservation-holder PGs (e.g. TPU head PGs) - that were only needed to claim resources during scheduling. No-op - for non-accelerator deployments. - """ - if self._slice_pg is not None: - self._slice_pg.release_head_pgs() - def shutdown(self) -> None: """Tear down the replica's PG(s). 
Idempotent.""" if self._slice_pg is not None: From 3d1bc7c70d7131dbf289aa798f7d45c124c72f45 Mon Sep 17 00:00:00 2001 From: Ryan O'Leary Date: Thu, 21 May 2026 18:14:40 +0000 Subject: [PATCH 26/26] fix missing import, resolve circular dependency Signed-off-by: Ryan O'Leary --- python/ray/serve/_private/test_utils.py | 2 +- python/ray/serve/_private/version.py | 6 +- python/ray/serve/api.py | 19 +---- python/ray/serve/config.py | 16 ++++ python/ray/serve/deployment.py | 19 +++++ python/ray/serve/schema.py | 38 +++++++++ .../tests/unit/test_deployment_scheduler.py | 2 +- python/ray/serve/tests/unit/test_schema.py | 80 +++++++++++++++++++ 8 files changed, 159 insertions(+), 23 deletions(-) diff --git a/python/ray/serve/_private/test_utils.py b/python/ray/serve/_private/test_utils.py index 3f2686ad7112..56f5f9322020 100644 --- a/python/ray/serve/_private/test_utils.py +++ b/python/ray/serve/_private/test_utils.py @@ -26,7 +26,6 @@ from ray.actor import ActorHandle from ray.serve._private.client import ServeControllerClient from ray.serve._private.common import ( - CreatePlacementGroupRequest, DeploymentID, DeploymentStatus, ReplicaID, @@ -45,6 +44,7 @@ ReplicaStartupStatus, ReplicaState, ) +from ray.serve._private.placement_group_utils import CreatePlacementGroupRequest from ray.serve._private.proxy import DRAINING_MESSAGE from ray.serve._private.replica_result import ReplicaResult from ray.serve._private.request_router import ( diff --git a/python/ray/serve/_private/version.py b/python/ray/serve/_private/version.py index 9a71b8652fd3..2171006c4778 100644 --- a/python/ray/serve/_private/version.py +++ b/python/ray/serve/_private/version.py @@ -125,10 +125,10 @@ def compute_hashes(self): else {} ) self.gang_scheduling_config_hash = crc32(serialized_gang_scheduling_config) - serialized_accelerator_config = _serialize( - self.deployment_config.accelerator_config.model_dump() + serialized_accelerator_config = ( + 
self.deployment_config.accelerator_config.model_dump_json().encode("utf-8") if self.deployment_config.accelerator_config is not None - else {} + else b"" ) self.accelerator_config_hash = crc32(serialized_accelerator_config) # Include app-level route prefix in the version hashes so changing diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 7c618ba13d46..bdf93b40839d 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -40,7 +40,6 @@ wait_for_interrupt, ) from ray.serve.config import ( - ACCELERATOR_KIND_TPU, AcceleratorConfig, AutoscalingConfig, ControllerOptions, @@ -49,7 +48,7 @@ HTTPOptions, ProxyLocation, RequestRouterConfig, - TPUAcceleratorConfig, + _resolve_accelerator_config, gRPCOptions, ) from ray.serve.context import ( @@ -470,22 +469,6 @@ async def __del__(self): return decorator -def _resolve_accelerator_config( - value: Union[Dict, AcceleratorConfig, None], -) -> Optional[AcceleratorConfig]: - - if value is None or isinstance(value, AcceleratorConfig): - return value - if isinstance(value, dict): - kind = value.get("kind") - if kind == ACCELERATOR_KIND_TPU: - return TPUAcceleratorConfig(**value) - raise ValueError(f"Unknown accelerator kind {kind!r}. Supported types: 'tpu'.") - raise TypeError( - f"accelerator_config must be a dict or AcceleratorConfig, got {type(value)}." - ) - - @PublicAPI(stability="stable") def deployment( _func_or_class: Optional[Callable] = None, diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index 274d2b7538ad..e61261d8acdb 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -1228,3 +1228,19 @@ def _validate_runtime_failure_policy(cls, v): "RESTART_REPLICA policy is not yet implemented. File a GitHub issue if you need this feature." 
) return v + + +def _resolve_accelerator_config( + value: Union[Dict, AcceleratorConfig, None], +) -> Optional[AcceleratorConfig]: + + if value is None or isinstance(value, AcceleratorConfig): + return value + if isinstance(value, dict): + kind = value.get("kind") + if kind == ACCELERATOR_KIND_TPU: + return TPUAcceleratorConfig(**value) + raise ValueError(f"Unknown accelerator kind {kind!r}. Supported types: 'tpu'.") + raise TypeError( + f"accelerator_config must be a dict or AcceleratorConfig, got {type(value)}." + ) diff --git a/python/ray/serve/deployment.py b/python/ray/serve/deployment.py index 783b413d1b3c..4b55aa71a42f 100644 --- a/python/ray/serve/deployment.py +++ b/python/ray/serve/deployment.py @@ -13,9 +13,11 @@ from ray.serve._private.usage import ServeUsageTag from ray.serve._private.utils import DEFAULT, Default from ray.serve.config import ( + AcceleratorConfig, AutoscalingConfig, DeploymentActorConfig, GangSchedulingConfig, + _resolve_accelerator_config, ) from ray.serve.schema import DeploymentSchema, LoggingConfig, RayActorOptionsSchema from ray.util.annotations import PublicAPI @@ -257,6 +259,9 @@ def options( deployment_actors: Default[ Optional[List[Union[Dict, DeploymentActorConfig]]] ] = DEFAULT.VALUE, + accelerator_config: Default[ + Union[Dict, AcceleratorConfig, None] + ] = DEFAULT.VALUE, ) -> "Deployment": """Return a copy of this deployment with updated options. 
@@ -408,6 +413,18 @@ def options( if gang_scheduling_config is not DEFAULT.VALUE: new_deployment_config.gang_scheduling_config = gang_scheduling_config + if accelerator_config is not DEFAULT.VALUE: + if accelerator_config is not None: + accelerator_config = _resolve_accelerator_config(accelerator_config) + new_deployment_config.accelerator_config = accelerator_config + + ac = new_deployment_config.accelerator_config + gc = new_deployment_config.gang_scheduling_config + if ac is not None and gc is not None: + raise ValueError( + "Cannot specify both `accelerator_config` and `gang_scheduling_config`." + ) + if deployment_actors is not DEFAULT.VALUE: new_deployment_config.deployment_actors = deployment_actors @@ -513,6 +530,7 @@ def deployment_to_schema(d: Deployment) -> DeploymentSchema: "gang_scheduling_config": d._deployment_config.gang_scheduling_config, "deployment_actors": d._deployment_config.deployment_actors, "rolling_update_percentage": d._deployment_config.rolling_update_percentage, + "accelerator_config": d._deployment_config.accelerator_config, } # Let non-user-configured options be set to defaults. 
If the schema @@ -577,6 +595,7 @@ def schema_to_deployment(s: DeploymentSchema) -> Deployment: gang_scheduling_config=s.gang_scheduling_config, deployment_actors=s.deployment_actors, rolling_update_percentage=s.rolling_update_percentage, + accelerator_config=s.accelerator_config, ) deployment_config.user_configured_option_names = ( s._get_user_configured_option_names() diff --git a/python/ray/serve/schema.py b/python/ray/serve/schema.py index 5d2489d89cbb..1f63c4424ad0 100644 --- a/python/ray/serve/schema.py +++ b/python/ray/serve/schema.py @@ -38,6 +38,7 @@ from ray.serve._private.deployment_info import DeploymentInfo from ray.serve._private.utils import DEFAULT, validate_ssl_config from ray.serve.config import ( + AcceleratorConfig, AutoscalingConfig, AutoscalingPolicy, ControllerOptions, @@ -45,6 +46,7 @@ GangSchedulingConfig, ProxyLocation, RequestRouterConfig, + _resolve_accelerator_config, ) from ray.util.annotations import PublicAPI @@ -478,6 +480,10 @@ class DeploymentSchema(BaseModel): gt=0.0, le=1.0, ) + accelerator_config: Optional[Union[Dict, AcceleratorConfig]] = Field( + default=DEFAULT.VALUE, + description="Structured accelerator configuration for the deployment replicas.", + ) @model_validator(mode="before") @classmethod @@ -502,6 +508,20 @@ def validate_num_replicas_and_autoscaling_config(cls, values): return values + @model_validator(mode="before") + @classmethod + def validate_accelerator_config(cls, values): + accelerator_config = values.get("accelerator_config", None) + if accelerator_config in [None, DEFAULT.VALUE]: + return values + + if isinstance(accelerator_config, dict): + values["accelerator_config"] = _resolve_accelerator_config( + accelerator_config + ) + + return values + @model_validator(mode="before") @classmethod def validate_gang_scheduling_config(cls, values): @@ -626,6 +646,21 @@ def validate_placement_group_strategy_and_gang_scheduling_config(self): return self + @model_validator(mode="after") + def 
validate_accelerator_config_and_gang_scheduling_config(self): + accelerator_config = self.accelerator_config + gang_scheduling_config = self.gang_scheduling_config + + if accelerator_config not in [ + DEFAULT.VALUE, + None, + ] and gang_scheduling_config not in [DEFAULT.VALUE, None]: + raise ValueError( + "Cannot specify both `accelerator_config` and `gang_scheduling_config`." + ) + + return self + @model_validator(mode="after") def validate_max_queued_requests(self): max_queued_requests = self.max_queued_requests @@ -689,6 +724,9 @@ def _deployment_info_to_schema(name: str, info: DeploymentInfo) -> DeploymentSch info.deployment_config.gang_scheduling_config.model_dump() ) + if info.deployment_config.accelerator_config is not None: + schema.accelerator_config = info.deployment_config.accelerator_config + if info.deployment_config.deployment_actors is not None: deployment_actors = [] for cfg in info.deployment_config.deployment_actors: diff --git a/python/ray/serve/tests/unit/test_deployment_scheduler.py b/python/ray/serve/tests/unit/test_deployment_scheduler.py index a34ccd41eef8..287cf8420099 100644 --- a/python/ray/serve/tests/unit/test_deployment_scheduler.py +++ b/python/ray/serve/tests/unit/test_deployment_scheduler.py @@ -12,7 +12,6 @@ from ray.serve._private import default_impl from ray.serve._private.common import ( GANG_PG_NAME_PREFIX, - CreatePlacementGroupRequest, DeploymentID, GangPlacementGroupRequest, ReplicaID, @@ -32,6 +31,7 @@ SpreadDeploymentSchedulingPolicy, ) from ray.serve._private.deployment_state import DeploymentStateManager +from ray.serve._private.placement_group_utils import CreatePlacementGroupRequest from ray.serve._private.test_utils import ( MockActorClass, MockClusterNodeInfoCache, diff --git a/python/ray/serve/tests/unit/test_schema.py b/python/ray/serve/tests/unit/test_schema.py index d3d73c66b836..3bd4d35714e1 100644 --- a/python/ray/serve/tests/unit/test_schema.py +++ b/python/ray/serve/tests/unit/test_schema.py @@ -18,6 +18,7 @@ 
GangPlacementStrategy, GangRuntimeFailurePolicy, GangSchedulingConfig, + TPUAcceleratorConfig, ) from ray.serve.deployment import Deployment, deployment_to_schema, schema_to_deployment from ray.serve.schema import ( @@ -1219,6 +1220,85 @@ def test_schema_to_deployment_gang_scheduling_config_from_dict(): assert dep.num_replicas == 6 +def test_accelerator_config_deployment_schema_roundtrip(): + # Ensure deployment_to_schema -> schema_to_deployment preserves accelerator_config + accelerator_config = TPUAcceleratorConfig(topology="4x4", accelerator_version="v6e") + dc = DeploymentConfig.from_default( + num_replicas=4, + accelerator_config=accelerator_config, + ) + dc.user_configured_option_names = {"num_replicas", "accelerator_config"} + + rc = ReplicaConfig.create(deployment_def="", init_args=(), init_kwargs={}) + dep = Deployment( + name="TpuDep", + deployment_config=dc, + replica_config=rc, + _internal=True, + ) + + schema = deployment_to_schema(dep) + assert isinstance(schema.accelerator_config, TPUAcceleratorConfig) + assert schema.accelerator_config.topology == "4x4" + assert schema.accelerator_config.accelerator_version == "v6e" + assert schema.num_replicas == 4 + + dep2 = schema_to_deployment(schema) + ac2 = dep2._deployment_config.accelerator_config + assert isinstance(ac2, TPUAcceleratorConfig) + assert ac2.topology == "4x4" + assert ac2.accelerator_version == "v6e" + assert dep2.num_replicas == 4 + + +def test_schema_to_deployment_accelerator_config_from_dict(): + # Ensure schema_to_deployment works when accelerator_config comes from a dict + schema = DeploymentSchema.model_validate( + { + "name": "TpuDep", + "num_replicas": 2, + "accelerator_config": { + "kind": "tpu", + "topology": "2x2", + "accelerator_version": "v6e", + }, + } + ) + + assert isinstance(schema.accelerator_config, TPUAcceleratorConfig) + assert schema.accelerator_config.topology == "2x2" + + dep = schema_to_deployment(schema) + ac = dep._deployment_config.accelerator_config + assert 
isinstance(ac, TPUAcceleratorConfig) + assert ac.topology == "2x2" + assert ac.accelerator_version == "v6e" + assert dep.num_replicas == 2 + + +def test_mutual_exclusivity_accelerator_and_gang(): + # Cannot specify both accelerator_config and gang_scheduling_config + with pytest.raises(ValueError) as e: + DeploymentSchema.model_validate( + { + "name": "TpuDep", + "num_replicas": 2, + "accelerator_config": { + "kind": "tpu", + "topology": "2x2", + "accelerator_version": "v6e", + }, + "gang_scheduling_config": { + "gang_size": 2, + }, + } + ) + assert ( + "Cannot specify both `accelerator_config` and `gang_scheduling_config`" + in str(e.value) + ) + + def test_deployment_actors_deployment_schema_roundtrip(): """Ensure deployment_to_schema -> schema_to_deployment preserves deployment_actors.""" actor_config = DeploymentActorConfig(