22import tempfile
33import threading
44import time
5+ from enum import Enum
56from typing import List , Optional , Tuple
67
78from gpuhunt import KNOWN_NVIDIA_GPUS , AcceleratorVendor
6263NVIDIA_GPU_NAME_TO_GPU_INFO = {gpu .name : gpu for gpu in KNOWN_NVIDIA_GPUS }
6364NVIDIA_GPU_NAMES = NVIDIA_GPU_NAME_TO_GPU_INFO .keys ()
6465
# Extended resource name used by the NVIDIA device plugin; also the base for
# the node labels published by the GPU feature discovery components.
NVIDIA_GPU_RESOURCE = "nvidia.com/gpu"
# Node label holding the number of GPUs on the node.
NVIDIA_GPU_COUNT_LABEL = f"{NVIDIA_GPU_RESOURCE}.count"
# Node label holding the GPU product name, e.g., "Tesla-T4" or "A100-SXM4-40GB".
NVIDIA_GPU_PRODUCT_LABEL = f"{NVIDIA_GPU_RESOURCE}.product"
# Taint key commonly applied to GPU nodes; same string as the resource name.
NVIDIA_GPU_NODE_TAINT = NVIDIA_GPU_RESOURCE

# Taints we know and tolerate when creating our objects, e.g., the jump pod.
TOLERATED_NODE_TAINTS = (NVIDIA_GPU_NODE_TAINT,)

# Placeholder region name used where a real region is not applicable.
DUMMY_REGION = "-"
6675
6776
class Operator(str, Enum):
    """Operator values accepted by Kubernetes selector requirements and tolerations.

    Subclasses ``str`` so members compare equal to, and serialize as, the
    plain string values the Kubernetes API expects.
    """

    EXISTS = "Exists"
    IN = "In"
81+
class TaintEffect(str, Enum):
    """Kubernetes node taint effects.

    Subclasses ``str`` so members compare equal to, and serialize as, the
    plain string values the Kubernetes API expects.
    """

    NO_EXECUTE = "NoExecute"
    NO_SCHEDULE = "NoSchedule"
    PREFER_NO_SCHEDULE = "PreferNoSchedule"
87+
6888class KubernetesCompute (
6989 ComputeWithFilteredOffersCached ,
7090 ComputeWithPrivilegedSupport ,
@@ -181,6 +201,7 @@ def run_job(
181201 resources_requests : dict [str , str ] = {}
182202 resources_limits : dict [str , str ] = {}
183203 node_affinity : Optional [client .V1NodeAffinity ] = None
204+ tolerations : list [client .V1Toleration ] = []
184205 volumes_ : list [client .V1Volume ] = []
185206 volume_mounts : list [client .V1VolumeMount ] = []
186207
@@ -226,21 +247,28 @@ def run_job(
226247 "Requesting %d GPU(s), node labels: %s" , gpu_min , matching_gpu_label_values
227248 )
228249 # TODO: support other GPU vendors
229- resources_requests ["nvidia.com/gpu" ] = str (gpu_min )
230- resources_limits ["nvidia.com/gpu" ] = str (gpu_min )
250+ resources_requests [NVIDIA_GPU_RESOURCE ] = str (gpu_min )
251+ resources_limits [NVIDIA_GPU_RESOURCE ] = str (gpu_min )
231252 node_affinity = client .V1NodeAffinity (
232253 required_during_scheduling_ignored_during_execution = [
233254 client .V1NodeSelectorTerm (
234255 match_expressions = [
235256 client .V1NodeSelectorRequirement (
236- key = "nvidia.com/gpu.product" ,
237- operator = "In" ,
257+ key = NVIDIA_GPU_PRODUCT_LABEL ,
258+ operator = Operator . IN ,
238259 values = list (matching_gpu_label_values ),
239260 ),
240261 ],
241262 ),
242263 ],
243264 )
265+ # It should be NoSchedule, but we also add NoExecute toleration just in case.
266+ for effect in [TaintEffect .NO_SCHEDULE , TaintEffect .NO_EXECUTE ]:
267+ tolerations .append (
268+ client .V1Toleration (
269+ key = NVIDIA_GPU_NODE_TAINT , operator = Operator .EXISTS , effect = effect
270+ )
271+ )
244272
245273 if (memory_min := resources_spec .memory .min ) is not None :
246274 resources_requests ["memory" ] = _render_memory (memory_min )
@@ -304,6 +332,7 @@ def run_job(
304332 )
305333 ],
306334 affinity = node_affinity ,
335+ tolerations = tolerations ,
307336 volumes = volumes_ ,
308337 ),
309338 )
@@ -527,8 +556,8 @@ def _get_gpus_from_node_labels(labels: dict[str, str]) -> tuple[list[Gpu], Optio
527556 # "A100" but a product name like "Tesla-T4" or "A100-SXM4-40GB".
528557 # Thus, we convert the product name to a known gpu name.
529558 # TODO: support other GPU vendors
530- gpu_count = labels .get ("nvidia.com/gpu.count" )
531- gpu_product = labels .get ("nvidia.com/gpu.product" )
559+ gpu_count = labels .get (NVIDIA_GPU_COUNT_LABEL )
560+ gpu_product = labels .get (NVIDIA_GPU_PRODUCT_LABEL )
532561 if gpu_count is None or gpu_product is None :
533562 return [], None
534563 gpu_count = int (gpu_count )
@@ -647,6 +676,39 @@ def _create_jump_pod_service(
647676 namespace = namespace ,
648677 name = pod_name ,
649678 )
679+
680+ node_list = call_api_method (api .list_node , client .V1NodeList )
681+ nodes = get_value (node_list , ".items" , list [client .V1Node ], required = True )
682+ # False if we found at least one node without any "hard" taint, that is, if we don't need to
683+ # specify the toleration.
684+ toleration_required = True
685+ # (key, effect) pairs.
686+ tolerated_taints : set [tuple [str , str ]] = set ()
687+ for node in nodes :
688+ # True if the node has at least one NoExecute or NoSchedule taint.
689+ has_hard_taint = False
690+ taints = get_value (node , ".spec.taints" , list [client .V1Taint ]) or []
691+ for taint in taints :
692+ effect = get_value (taint , ".effect" , str , required = True )
693+ # A "soft" taint, ignore.
694+ if effect == TaintEffect .PREFER_NO_SCHEDULE :
695+ continue
696+ has_hard_taint = True
697+ key = get_value (taint , ".key" , str , required = True )
698+ if key in TOLERATED_NODE_TAINTS :
699+ tolerated_taints .add ((key , effect ))
700+ if not has_hard_taint :
701+ toleration_required = False
702+ break
703+ tolerations : list [client .V1Toleration ] = []
704+ if toleration_required :
705+ for key , effect in tolerated_taints :
706+ tolerations .append (
707+ client .V1Toleration (key = key , operator = Operator .EXISTS , effect = effect )
708+ )
709+ if not tolerations :
710+ logger .warning ("No appropriate node found, the jump pod may never be scheduled" )
711+
650712 commands = _get_jump_pod_commands (authorized_keys = ssh_public_keys )
651713 pod = client .V1Pod (
652714 metadata = client .V1ObjectMeta (
@@ -667,7 +729,8 @@ def _create_jump_pod_service(
667729 )
668730 ],
669731 )
670- ]
732+ ],
733+ tolerations = tolerations ,
671734 ),
672735 )
673736 call_api_method (
0 commit comments