Skip to content

Commit 220a1fa

Browse files
authored
Kubernetes: add AMD GPU support (#3178)
In addition, Pod.spec.affinity has been fixed (the inner structure was incorrect) Part-of: #3126
1 parent d501a93 commit 220a1fa

3 files changed

Lines changed: 208 additions & 102 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ dependencies = [
3232
"python-multipart>=0.0.16",
3333
"filelock",
3434
"psutil",
35-
"gpuhunt==0.1.8",
35+
"gpuhunt==0.1.10",
3636
"argcomplete>=3.5.0",
3737
"ignore-python>=0.2.0",
3838
"orjson",

src/dstack/_internal/core/backends/kubernetes/compute.py

Lines changed: 160 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from enum import Enum
66
from typing import List, Optional, Tuple
77

8-
from gpuhunt import KNOWN_NVIDIA_GPUS, AcceleratorVendor
8+
from gpuhunt import KNOWN_AMD_GPUS, KNOWN_NVIDIA_GPUS, AcceleratorVendor
99
from kubernetes import client
1010

1111
from dstack._internal.core.backends.base.compute import (
@@ -59,19 +59,31 @@
5959
logger = get_logger(__name__)
6060

6161
JUMP_POD_SSH_PORT = 22
62-
63-
NVIDIA_GPU_NAME_TO_GPU_INFO = {gpu.name: gpu for gpu in KNOWN_NVIDIA_GPUS}
64-
NVIDIA_GPU_NAMES = NVIDIA_GPU_NAME_TO_GPU_INFO.keys()
62+
DUMMY_REGION = "-"
6563

6664
NVIDIA_GPU_RESOURCE = "nvidia.com/gpu"
67-
NVIDIA_GPU_COUNT_LABEL = f"{NVIDIA_GPU_RESOURCE}.count"
68-
NVIDIA_GPU_PRODUCT_LABEL = f"{NVIDIA_GPU_RESOURCE}.product"
6965
NVIDIA_GPU_NODE_TAINT = NVIDIA_GPU_RESOURCE
66+
NVIDIA_GPU_PRODUCT_LABEL = f"{NVIDIA_GPU_RESOURCE}.product"
67+
68+
AMD_GPU_RESOURCE = "amd.com/gpu"
69+
AMD_GPU_NODE_TAINT = AMD_GPU_RESOURCE
70+
# The oldest but still supported label format, the safest option, see the commit message:
71+
# https://github.com/ROCm/k8s-device-plugin/commit/c0b0231b391a56bc9da4f362d561e25e960d7a48
72+
# E.g., beta.amd.com/gpu.device-id.74b5=4 - A node with four MI300X VF (0x74b5) GPUs
73+
# We cannot rely on the beta.amd.com/gpu.product-name.* label, as it may be missing, see the issue:
74+
# https://github.com/ROCm/k8s-device-plugin/issues/112
75+
AMD_GPU_DEVICE_ID_LABEL_PREFIX = f"beta.{AMD_GPU_RESOURCE}.device-id."
7076

7177
# Taints we know and tolerate when creating our objects, e.g., the jump pod.
72-
TOLERATED_NODE_TAINTS = (NVIDIA_GPU_NODE_TAINT,)
78+
TOLERATED_NODE_TAINTS = (NVIDIA_GPU_NODE_TAINT, AMD_GPU_NODE_TAINT)
7379

74-
DUMMY_REGION = "-"
80+
NVIDIA_GPU_NAME_TO_GPU_INFO = {gpu.name: gpu for gpu in KNOWN_NVIDIA_GPUS}
81+
NVIDIA_GPU_NAMES = NVIDIA_GPU_NAME_TO_GPU_INFO.keys()
82+
83+
AMD_GPU_DEVICE_ID_TO_GPU_INFO = {
84+
device_id: gpu_info for gpu_info in KNOWN_AMD_GPUS for device_id in gpu_info.device_ids
85+
}
86+
AMD_GPU_NAME_TO_DEVICE_IDS = {gpu.name: gpu.device_ids for gpu in KNOWN_AMD_GPUS}
7587

7688

7789
class Operator(str, Enum):
@@ -112,21 +124,15 @@ def get_offers_by_requirements(
112124
nodes = get_value(node_list, ".items", list[client.V1Node], required=True)
113125
for node in nodes:
114126
try:
115-
labels = get_value(node, ".metadata.labels", dict[str, str]) or {}
116127
name = get_value(node, ".metadata.name", str, required=True)
117-
cpus = _parse_cpu(
118-
get_value(node, ".status.allocatable['cpu']", str, required=True)
119-
)
120128
cpu_arch = normalize_arch(
121129
get_value(node, ".status.node_info.architecture", str)
122130
).to_cpu_architecture()
123-
memory_mib = _parse_memory(
124-
get_value(node, ".status.allocatable['memory']", str, required=True)
125-
)
126-
gpus, _ = _get_gpus_from_node_labels(labels)
127-
disk_size_mib = _parse_memory(
128-
get_value(node, ".status.allocatable['ephemeral-storage']", str, required=True)
129-
)
131+
allocatable = get_value(node, ".status.allocatable", dict[str, str], required=True)
132+
cpus = _parse_cpu(allocatable["cpu"])
133+
memory_mib = _parse_memory(allocatable["memory"])
134+
disk_size_mib = _parse_memory(allocatable["ephemeral-storage"])
135+
gpus = _get_node_gpus(node)
130136
except (AttributeError, KeyError, ValueError) as e:
131137
logger.exception("Failed to process node: %s: %s", type(e).__name__, e)
132138
continue
@@ -218,59 +224,18 @@ def run_job(
218224
"GPU is requested but the offer has no GPUs:"
219225
f" {gpu_spec=} {instance_offer=}",
220226
)
221-
offer_gpu = offer_gpus[0]
222-
matching_gpu_label_values: set[str] = set()
223-
# We cannot generate an expected GPU label value from the Gpu model instance
224-
# as the actual values may have additional components (socket, memory type, etc.)
225-
# that we don't preserve in the Gpu model, e.g., "NVIDIA-H100-80GB-HBM3".
226-
# Moreover, a single Gpu may match multiple label values.
227-
# As a workaround, we iterate and process all node labels once again (we already
228-
# processed them in `get_offers_by_requirements()`).
229-
node_list = call_api_method(
230-
self.api.list_node,
231-
client.V1NodeList,
232-
)
233-
nodes = get_value(node_list, ".items", list[client.V1Node], required=True)
234-
for node in nodes:
235-
labels = get_value(node, ".metadata.labels", dict[str, str])
236-
if not labels:
237-
continue
238-
gpus, gpu_label_value = _get_gpus_from_node_labels(labels)
239-
if not gpus or gpu_label_value is None:
240-
continue
241-
if gpus[0] == offer_gpu:
242-
matching_gpu_label_values.add(gpu_label_value)
243-
if not matching_gpu_label_values:
244-
raise ComputeError(
245-
f"GPU is requested but no matching GPU labels found: {gpu_spec=}"
246-
)
247-
logger.debug(
248-
"Requesting %d GPU(s), node labels: %s", gpu_min, matching_gpu_label_values
249-
)
250-
# TODO: support other GPU vendors
251-
resources_requests[NVIDIA_GPU_RESOURCE] = str(gpu_min)
252-
resources_limits[NVIDIA_GPU_RESOURCE] = str(gpu_min)
253-
node_affinity = client.V1NodeAffinity(
254-
required_during_scheduling_ignored_during_execution=[
255-
client.V1NodeSelectorTerm(
256-
match_expressions=[
257-
client.V1NodeSelectorRequirement(
258-
key=NVIDIA_GPU_PRODUCT_LABEL,
259-
operator=Operator.IN,
260-
values=list(matching_gpu_label_values),
261-
),
262-
],
263-
),
264-
],
227+
gpu_resource, node_affinity, node_taint = _get_pod_spec_parameters_for_gpu(
228+
self.api, offer_gpus[0]
265229
)
230+
logger.debug("Requesting GPU resource: %s=%d", gpu_resource, gpu_min)
231+
resources_requests[gpu_resource] = resources_limits[gpu_resource] = str(gpu_min)
266232
# It should be NoSchedule, but we also add NoExecute toleration just in case.
267233
for effect in [TaintEffect.NO_SCHEDULE, TaintEffect.NO_EXECUTE]:
268234
tolerations.append(
269235
client.V1Toleration(
270-
key=NVIDIA_GPU_NODE_TAINT, operator=Operator.EXISTS, effect=effect
236+
key=node_taint, operator=Operator.EXISTS, effect=effect
271237
)
272238
)
273-
274239
if (memory_min := resources_spec.memory.min) is not None:
275240
resources_requests["memory"] = _render_memory(memory_min)
276241
if (
@@ -332,7 +297,9 @@ def run_job(
332297
volume_mounts=volume_mounts,
333298
)
334299
],
335-
affinity=node_affinity,
300+
affinity=client.V1Affinity(
301+
node_affinity=node_affinity,
302+
),
336303
tolerations=tolerations,
337304
volumes=volumes_,
338305
),
@@ -551,34 +518,144 @@ def _render_memory(memory: Memory) -> str:
551518
return f"{float(memory)}Gi"
552519

553520

554-
def _get_gpus_from_node_labels(labels: dict[str, str]) -> tuple[list[Gpu], Optional[str]]:
521+
def _get_node_gpus(node: client.V1Node) -> list[Gpu]:
522+
node_name = get_value(node, ".metadata.name", str, required=True)
523+
allocatable = get_value(node, ".status.allocatable", dict[str, str], required=True)
524+
labels = get_value(node, ".metadata.labels", dict[str, str]) or {}
525+
for gpu_resource, gpu_getter in (
526+
(NVIDIA_GPU_RESOURCE, _get_nvidia_gpu_from_node_labels),
527+
(AMD_GPU_RESOURCE, _get_amd_gpu_from_node_labels),
528+
):
529+
_gpu_count = allocatable.get(gpu_resource)
530+
if not _gpu_count:
531+
continue
532+
gpu_count = int(_gpu_count)
533+
if gpu_count < 1:
534+
continue
535+
gpu = gpu_getter(labels)
536+
if gpu is None:
537+
logger.warning(
538+
"Node %s: GPU resource found, but failed to detect its model: %s=%d",
539+
node_name,
540+
gpu_resource,
541+
gpu_count,
542+
)
543+
return []
544+
return [gpu] * gpu_count
545+
logger.debug("Node %s: no GPU resource found", node_name)
546+
return []
547+
548+
549+
def _get_nvidia_gpu_from_node_labels(labels: dict[str, str]) -> Optional[Gpu]:
555550
# We rely on https://github.com/NVIDIA/k8s-device-plugin/tree/main/docs/gpu-feature-discovery
556551
# to detect gpus. Note that "nvidia.com/gpu.product" is not a short gpu name like "T4" or
557552
# "A100" but a product name like "Tesla-T4" or "A100-SXM4-40GB".
558553
# Thus, we convert the product name to a known gpu name.
559-
# TODO: support other GPU vendors
560-
gpu_count = labels.get(NVIDIA_GPU_COUNT_LABEL)
561554
gpu_product = labels.get(NVIDIA_GPU_PRODUCT_LABEL)
562-
if gpu_count is None or gpu_product is None:
563-
return [], None
564-
gpu_count = int(gpu_count)
565-
gpu_name = None
566-
for known_gpu_name in NVIDIA_GPU_NAMES:
567-
if known_gpu_name.lower() in gpu_product.lower().split("-"):
568-
gpu_name = known_gpu_name
555+
if gpu_product is None:
556+
return None
557+
for gpu_name in NVIDIA_GPU_NAMES:
558+
if gpu_name.lower() in gpu_product.lower().split("-"):
569559
break
570-
if gpu_name is None:
571-
return [], None
560+
else:
561+
return None
572562
gpu_info = NVIDIA_GPU_NAME_TO_GPU_INFO[gpu_name]
573563
gpu_memory = gpu_info.memory * 1024
574564
# A100 may come in two variants
575565
if "40GB" in gpu_product:
576566
gpu_memory = 40 * 1024
577-
gpus = [
578-
Gpu(vendor=AcceleratorVendor.NVIDIA, name=gpu_name, memory_mib=gpu_memory)
579-
for _ in range(gpu_count)
580-
]
581-
return gpus, gpu_product
567+
return Gpu(vendor=AcceleratorVendor.NVIDIA, name=gpu_name, memory_mib=gpu_memory)
568+
569+
570+
def _get_amd_gpu_from_node_labels(labels: dict[str, str]) -> Optional[Gpu]:
571+
# (AMDGPUInfo.name, AMDGPUInfo.memory) pairs
572+
gpus: set[tuple[str, int]] = set()
573+
for label in labels:
574+
if not label.startswith(AMD_GPU_DEVICE_ID_LABEL_PREFIX):
575+
continue
576+
_, _, _device_id = label.rpartition(".")
577+
device_id = int(_device_id, 16)
578+
gpu_info = AMD_GPU_DEVICE_ID_TO_GPU_INFO.get(device_id)
579+
if gpu_info is None:
580+
logger.warning("Unknown AMD GPU device id: %X", device_id)
581+
continue
582+
gpus.add((gpu_info.name, gpu_info.memory))
583+
if not gpus:
584+
return None
585+
if len(gpus) == 1:
586+
gpu_name, gpu_memory_gib = next(iter(gpus))
587+
return Gpu(vendor=AcceleratorVendor.AMD, name=gpu_name, memory_mib=gpu_memory_gib * 1024)
588+
logger.warning("Multiple AMD GPU models detected: %s, ignoring all GPUs", gpus)
589+
return None
590+
591+
592+
def _get_pod_spec_parameters_for_gpu(
593+
api: client.CoreV1Api, gpu: Gpu
594+
) -> tuple[str, client.V1NodeAffinity, str]:
595+
gpu_vendor = gpu.vendor
596+
assert gpu_vendor is not None
597+
if gpu_vendor == AcceleratorVendor.NVIDIA:
598+
node_affinity = _get_nvidia_gpu_node_affinity(api, gpu)
599+
return NVIDIA_GPU_RESOURCE, node_affinity, NVIDIA_GPU_NODE_TAINT
600+
if gpu_vendor == AcceleratorVendor.AMD:
601+
node_affinity = _get_amd_gpu_node_affinity(gpu)
602+
return AMD_GPU_RESOURCE, node_affinity, AMD_GPU_NODE_TAINT
603+
raise ComputeError(f"Unsupported GPU vendor: {gpu_vendor}")
604+
605+
606+
def _get_nvidia_gpu_node_affinity(api: client.CoreV1Api, gpu: Gpu) -> client.V1NodeAffinity:
607+
matching_gpu_label_values: set[str] = set()
608+
# We cannot generate an expected GPU label value from the Gpu model instance
609+
# as the actual values may have additional components (socket, memory type, etc.)
610+
# that we don't preserve in the Gpu model, e.g., "NVIDIA-H100-80GB-HBM3".
611+
# Moreover, a single Gpu may match multiple label values.
612+
# As a workaround, we iterate and process all node labels once again (we already
613+
# processed them in `get_offers_by_requirements()`).
614+
node_list = call_api_method(api.list_node, client.V1NodeList)
615+
nodes = get_value(node_list, ".items", list[client.V1Node], required=True)
616+
for node in nodes:
617+
labels = get_value(node, ".metadata.labels", dict[str, str]) or {}
618+
if _get_nvidia_gpu_from_node_labels(labels) == gpu:
619+
matching_gpu_label_values.add(labels[NVIDIA_GPU_PRODUCT_LABEL])
620+
if not matching_gpu_label_values:
621+
raise ComputeError(f"NVIDIA GPU is requested but no matching GPU labels found: {gpu=}")
622+
logger.debug("Selecting nodes by labels %s for NVIDIA %s", matching_gpu_label_values, gpu.name)
623+
return client.V1NodeAffinity(
624+
required_during_scheduling_ignored_during_execution=client.V1NodeSelector(
625+
node_selector_terms=[
626+
client.V1NodeSelectorTerm(
627+
match_expressions=[
628+
client.V1NodeSelectorRequirement(
629+
key=NVIDIA_GPU_PRODUCT_LABEL,
630+
operator=Operator.IN,
631+
values=list(matching_gpu_label_values),
632+
),
633+
],
634+
),
635+
],
636+
),
637+
)
638+
639+
640+
def _get_amd_gpu_node_affinity(gpu: Gpu) -> client.V1NodeAffinity:
641+
device_ids = AMD_GPU_NAME_TO_DEVICE_IDS.get(gpu.name)
642+
if device_ids is None:
643+
raise ComputeError(f"AMD GPU is requested but no matching device ids found: {gpu=}")
644+
return client.V1NodeAffinity(
645+
required_during_scheduling_ignored_during_execution=client.V1NodeSelector(
646+
node_selector_terms=[
647+
client.V1NodeSelectorTerm(
648+
match_expressions=[
649+
client.V1NodeSelectorRequirement(
650+
key=f"{AMD_GPU_DEVICE_ID_LABEL_PREFIX}{device_id:x}",
651+
operator=Operator.EXISTS,
652+
),
653+
],
654+
)
655+
for device_id in device_ids
656+
],
657+
),
658+
)
582659

583660

584661
def _continue_setup_jump_pod(
Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,53 @@
1-
from dstack._internal.core.backends.kubernetes.compute import _get_gpus_from_node_labels
1+
import logging
2+
3+
import pytest
4+
from gpuhunt import AcceleratorVendor
5+
6+
from dstack._internal.core.backends.kubernetes.compute import (
7+
_get_amd_gpu_from_node_labels,
8+
_get_nvidia_gpu_from_node_labels,
9+
)
210
from dstack._internal.core.models.instances import Gpu
311

412

5-
class TestGetGPUsFromNodeLabels:
13+
class TestGetNvidiaGPUFromNodeLabels:
14+
def test_returns_none_if_no_labels(self):
15+
assert _get_nvidia_gpu_from_node_labels({}) is None
16+
17+
def test_returns_correct_memory_for_different_A100(self):
18+
assert _get_nvidia_gpu_from_node_labels(
19+
{"nvidia.com/gpu.product": "A100-SXM4-40GB"}
20+
) == Gpu(vendor=AcceleratorVendor.NVIDIA, name="A100", memory_mib=40 * 1024)
21+
22+
assert _get_nvidia_gpu_from_node_labels(
23+
{"nvidia.com/gpu.product": "A100-SXM4-80GB"}
24+
) == Gpu(vendor=AcceleratorVendor.NVIDIA, name="A100", memory_mib=80 * 1024)
25+
26+
27+
class TestGetAMDGPUFromNodeLabels:
628
def test_returns_no_gpus_if_no_labels(self):
7-
assert _get_gpus_from_node_labels({}) == ([], None)
29+
assert _get_amd_gpu_from_node_labels({}) is None
830

9-
def test_returns_no_gpus_if_missing_labels(self):
10-
assert _get_gpus_from_node_labels({"nvidia.com/gpu.count": 1}) == ([], None)
31+
def test_returns_known_gpu(self):
32+
assert _get_amd_gpu_from_node_labels({"beta.amd.com/gpu.device-id.74b5": "4"}) == Gpu(
33+
vendor=AcceleratorVendor.AMD, name="MI300X", memory_mib=192 * 1024
34+
)
1135

12-
def test_returns_correct_memory_for_different_A100(self):
13-
assert _get_gpus_from_node_labels(
14-
{
15-
"nvidia.com/gpu.count": 1,
16-
"nvidia.com/gpu.product": "A100-SXM4-40GB",
17-
}
18-
) == ([Gpu(name="A100", memory_mib=40 * 1024)], "A100-SXM4-40GB")
19-
assert _get_gpus_from_node_labels(
20-
{
21-
"nvidia.com/gpu.count": 1,
22-
"nvidia.com/gpu.product": "A100-SXM4-80GB",
23-
}
24-
) == ([Gpu(name="A100", memory_mib=80 * 1024)], "A100-SXM4-80GB")
36+
def test_returns_known_gpu_if_multiple_device_ids_match_the_same_gpu(self):
37+
# 4x AMD Instinct MI300X VF + 4x AMD Instinct MI300X
38+
labels = {"beta.amd.com/gpu.device-id.74b5": "4", "beta.amd.com/gpu.device-id.74a1": "4"}
39+
assert _get_amd_gpu_from_node_labels(labels) == Gpu(
40+
vendor=AcceleratorVendor.AMD, name="MI300X", memory_mib=192 * 1024
41+
)
42+
43+
def test_returns_none_if_device_id_is_unknown(self, caplog: pytest.LogCaptureFixture):
44+
caplog.set_level(logging.WARNING)
45+
assert _get_amd_gpu_from_node_labels({"beta.amd.com/gpu.device-id.ffff": "4"}) is None
46+
assert "Unknown AMD GPU device id: FFFF" in caplog.text
47+
48+
def test_returns_none_if_multiple_gpu_models(self, caplog: pytest.LogCaptureFixture):
49+
caplog.set_level(logging.WARNING)
50+
# 4x AMD Instinct MI300X VF + 4x AMD Instinct MI325X
51+
labels = {"beta.amd.com/gpu.device-id.74b5": "4", "beta.amd.com/gpu.device-id.74a5": "4"}
52+
assert _get_amd_gpu_from_node_labels(labels) is None
53+
assert "Multiple AMD GPU models detected" in caplog.text

0 commit comments

Comments
 (0)