Skip to content

Commit 6507ac4

Browse files
authored
Kubernetes: add NVIDIA GPU toleration (#3160)
Part-of: #3126
1 parent 4e7ff02 commit 6507ac4

File tree

1 file changed

+70
-7
lines changed
  • src/dstack/_internal/core/backends/kubernetes

1 file changed

+70
-7
lines changed

src/dstack/_internal/core/backends/kubernetes/compute.py

Lines changed: 70 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import tempfile
33
import threading
44
import time
5+
from enum import Enum
56
from typing import List, Optional, Tuple
67

78
from gpuhunt import KNOWN_NVIDIA_GPUS, AcceleratorVendor
@@ -62,9 +63,28 @@
6263
NVIDIA_GPU_NAME_TO_GPU_INFO = {gpu.name: gpu for gpu in KNOWN_NVIDIA_GPUS}
6364
NVIDIA_GPU_NAMES = NVIDIA_GPU_NAME_TO_GPU_INFO.keys()
6465

66+
# Kubernetes extended-resource name for NVIDIA GPUs; also reused below as the
# base for the GPU node labels and as the GPU node taint key.
NVIDIA_GPU_RESOURCE = "nvidia.com/gpu"
# Node labels published by the NVIDIA GPU feature discovery / device plugin.
NVIDIA_GPU_COUNT_LABEL = f"{NVIDIA_GPU_RESOURCE}.count"
NVIDIA_GPU_PRODUCT_LABEL = f"{NVIDIA_GPU_RESOURCE}.product"
# GPU nodes are commonly tainted with the same key as the extended resource.
NVIDIA_GPU_NODE_TAINT = NVIDIA_GPU_RESOURCE
# Taints we know and tolerate when creating our objects, e.g., the jump pod.
TOLERATED_NODE_TAINTS = (NVIDIA_GPU_NODE_TAINT,)
73+
6574
DUMMY_REGION = "-"
6675

6776

77+
class Operator(str, Enum):
    """Kubernetes selector/toleration operator values used in this module.

    Inherits from ``str`` so members compare equal to, and serialize as,
    the plain API strings (``"Exists"``, ``"In"``).
    """

    EXISTS = "Exists"
    IN = "In"
80+
81+
82+
class TaintEffect(str, Enum):
    """Kubernetes node taint effects.

    Inherits from ``str`` so members compare equal to the raw effect
    strings returned by the Kubernetes API (e.g. ``taint.effect``).
    """

    NO_EXECUTE = "NoExecute"
    NO_SCHEDULE = "NoSchedule"
    PREFER_NO_SCHEDULE = "PreferNoSchedule"
86+
87+
6888
class KubernetesCompute(
6989
ComputeWithFilteredOffersCached,
7090
ComputeWithPrivilegedSupport,
@@ -181,6 +201,7 @@ def run_job(
181201
resources_requests: dict[str, str] = {}
182202
resources_limits: dict[str, str] = {}
183203
node_affinity: Optional[client.V1NodeAffinity] = None
204+
tolerations: list[client.V1Toleration] = []
184205
volumes_: list[client.V1Volume] = []
185206
volume_mounts: list[client.V1VolumeMount] = []
186207

@@ -226,21 +247,28 @@ def run_job(
226247
"Requesting %d GPU(s), node labels: %s", gpu_min, matching_gpu_label_values
227248
)
228249
# TODO: support other GPU vendors
229-
resources_requests["nvidia.com/gpu"] = str(gpu_min)
230-
resources_limits["nvidia.com/gpu"] = str(gpu_min)
250+
resources_requests[NVIDIA_GPU_RESOURCE] = str(gpu_min)
251+
resources_limits[NVIDIA_GPU_RESOURCE] = str(gpu_min)
231252
node_affinity = client.V1NodeAffinity(
232253
required_during_scheduling_ignored_during_execution=[
233254
client.V1NodeSelectorTerm(
234255
match_expressions=[
235256
client.V1NodeSelectorRequirement(
236-
key="nvidia.com/gpu.product",
237-
operator="In",
257+
key=NVIDIA_GPU_PRODUCT_LABEL,
258+
operator=Operator.IN,
238259
values=list(matching_gpu_label_values),
239260
),
240261
],
241262
),
242263
],
243264
)
265+
# It should be NoSchedule, but we also add NoExecute toleration just in case.
266+
for effect in [TaintEffect.NO_SCHEDULE, TaintEffect.NO_EXECUTE]:
267+
tolerations.append(
268+
client.V1Toleration(
269+
key=NVIDIA_GPU_NODE_TAINT, operator=Operator.EXISTS, effect=effect
270+
)
271+
)
244272

245273
if (memory_min := resources_spec.memory.min) is not None:
246274
resources_requests["memory"] = _render_memory(memory_min)
@@ -304,6 +332,7 @@ def run_job(
304332
)
305333
],
306334
affinity=node_affinity,
335+
tolerations=tolerations,
307336
volumes=volumes_,
308337
),
309338
)
@@ -527,8 +556,8 @@ def _get_gpus_from_node_labels(labels: dict[str, str]) -> tuple[list[Gpu], Optio
527556
# "A100" but a product name like "Tesla-T4" or "A100-SXM4-40GB".
528557
# Thus, we convert the product name to a known gpu name.
529558
# TODO: support other GPU vendors
530-
gpu_count = labels.get("nvidia.com/gpu.count")
531-
gpu_product = labels.get("nvidia.com/gpu.product")
559+
gpu_count = labels.get(NVIDIA_GPU_COUNT_LABEL)
560+
gpu_product = labels.get(NVIDIA_GPU_PRODUCT_LABEL)
532561
if gpu_count is None or gpu_product is None:
533562
return [], None
534563
gpu_count = int(gpu_count)
@@ -647,6 +676,39 @@ def _create_jump_pod_service(
647676
namespace=namespace,
648677
name=pod_name,
649678
)
679+
680+
node_list = call_api_method(api.list_node, client.V1NodeList)
681+
nodes = get_value(node_list, ".items", list[client.V1Node], required=True)
682+
# False if we found at least one node without any "hard" taint, that is, if we don't need to
683+
# specify the toleration.
684+
toleration_required = True
685+
# (key, effect) pairs.
686+
tolerated_taints: set[tuple[str, str]] = set()
687+
for node in nodes:
688+
# True if the node has at least one NoExecute or NoSchedule taint.
689+
has_hard_taint = False
690+
taints = get_value(node, ".spec.taints", list[client.V1Taint]) or []
691+
for taint in taints:
692+
effect = get_value(taint, ".effect", str, required=True)
693+
# A "soft" taint, ignore.
694+
if effect == TaintEffect.PREFER_NO_SCHEDULE:
695+
continue
696+
has_hard_taint = True
697+
key = get_value(taint, ".key", str, required=True)
698+
if key in TOLERATED_NODE_TAINTS:
699+
tolerated_taints.add((key, effect))
700+
if not has_hard_taint:
701+
toleration_required = False
702+
break
703+
tolerations: list[client.V1Toleration] = []
704+
if toleration_required:
705+
for key, effect in tolerated_taints:
706+
tolerations.append(
707+
client.V1Toleration(key=key, operator=Operator.EXISTS, effect=effect)
708+
)
709+
if not tolerations:
710+
logger.warning("No appropriate node found, the jump pod may never be scheduled")
711+
650712
commands = _get_jump_pod_commands(authorized_keys=ssh_public_keys)
651713
pod = client.V1Pod(
652714
metadata=client.V1ObjectMeta(
@@ -667,7 +729,8 @@ def _create_jump_pod_service(
667729
)
668730
],
669731
)
670-
]
732+
],
733+
tolerations=tolerations,
671734
),
672735
)
673736
call_api_method(

0 commit comments

Comments
 (0)