File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -93,14 +93,24 @@ def format(self, record: logging.LogRecord) -> str:
9393 if MAX_CONCURRENT_PER_REPO <= 0 :
9494 MAX_CONCURRENT_PER_REPO = None
9595
96- GPU_MAP = {
97- "t4" : modal . gpu . T4 () ,
98- "l4" : modal . gpu . L4 () ,
99- "a100" : modal . gpu . A100 () ,
100- "a100-80gb" : modal . gpu . A100_80GB () ,
101- "h100" : modal . gpu . H100 () ,
96+ GPU_LABEL_TO_ATTR = {
97+ "t4" : "T4" ,
98+ "l4" : "L4" ,
99+ "a100" : " A100" ,
100+ "a100-80gb" : " A100_80GB" ,
101+ "h100" : " H100" ,
102102}
103103
104+
105+ def _get_gpu_config (gpu_key : str ):
106+ attr_name = GPU_LABEL_TO_ATTR .get (gpu_key )
107+ if attr_name is None :
108+ return None
109+ gpu_cls = getattr (modal .gpu , attr_name , None )
110+ if gpu_cls is None :
111+ return None
112+ return gpu_cls ()
113+
104114# =============================================================================
105115# TRUST MODEL
106116# =============================================================================
@@ -488,7 +498,7 @@ async def github_webhook(request: Request):
488498 for label in job_labels :
489499 if label .startswith ("gpu:" ):
490500 gpu_key = label .split (":" , 1 )[1 ].lower ()
491- gpu_config = GPU_MAP . get (gpu_key )
501+ gpu_config = _get_gpu_config (gpu_key )
492502 if not gpu_config :
493503 logger .warning (
494504 "Unknown GPU type requested" ,
You can’t perform that action at this time.
0 commit comments