Skip to content

Commit 641c8cf

Browse files
committed
train: fallback to CPU on unsupported CUDA capability
1 parent c8a7695 commit 641c8cf

2 files changed

Lines changed: 22 additions & 0 deletions

File tree

src/training/trainer_runtime.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ def resolve_trainer_hw() -> tuple[str, int, str]:
2424
requested_devices = cfg_int("trainer_devices")
2525
strategy = cfg_str("trainer_strategy")
2626
if torch.cuda.is_available():
27+
capability = torch.cuda.get_device_capability(0)
28+
if int(capability[0]) < 7:
29+
log(
30+
"Detected CUDA device with compute capability "
31+
f"sm_{int(capability[0])}{int(capability[1])}, unsupported by current torch build. "
32+
"Falling back to CPU.",
33+
)
34+
return "cpu", 1, "auto"
2735
available = max(1, torch.cuda.device_count())
2836
devices = min(requested_devices, available)
2937
if requested_devices > available:

tests/test_training_trainer_runtime.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def test_resolve_trainer_hw_downgrades_ddp_spawn_when_one_gpu_available(self) ->
4343
CONFIG["trainer_strategy"] = "ddp_spawn"
4444
with (
4545
patch("training.trainer_runtime.torch.cuda.is_available", return_value=True),
46+
patch("training.trainer_runtime.torch.cuda.get_device_capability", return_value=(8, 0)),
4647
patch("training.trainer_runtime.torch.cuda.device_count", return_value=1),
4748
):
4849
accelerator, devices, strategy = resolve_trainer_hw()
@@ -55,13 +56,26 @@ def test_resolve_trainer_hw_keeps_ddp_spawn_when_two_gpus_available(self) -> Non
5556
CONFIG["trainer_strategy"] = "ddp_spawn"
5657
with (
5758
patch("training.trainer_runtime.torch.cuda.is_available", return_value=True),
59+
patch("training.trainer_runtime.torch.cuda.get_device_capability", return_value=(8, 0)),
5860
patch("training.trainer_runtime.torch.cuda.device_count", return_value=2),
5961
):
6062
accelerator, devices, strategy = resolve_trainer_hw()
6163
self.assertEqual(accelerator, "gpu")
6264
self.assertEqual(devices, 2)
6365
self.assertEqual(strategy, "ddp_spawn")
6466

67+
def test_resolve_trainer_hw_falls_back_to_cpu_for_unsupported_cuda_capability(self) -> None:
    """resolve_trainer_hw() reports ("cpu", 1, "auto") when CUDA is present but the
    device's compute capability major version is below 7 (e.g. sm_60)."""
    CONFIG["trainer_devices"] = 1
    CONFIG["trainer_strategy"] = "auto"
    # Simulate a machine whose GPU is visible but too old for the torch build.
    cuda_present = patch("training.trainer_runtime.torch.cuda.is_available", return_value=True)
    legacy_capability = patch("training.trainer_runtime.torch.cuda.get_device_capability", return_value=(6, 0))
    with cuda_present, legacy_capability:
        accelerator, devices, strategy = resolve_trainer_hw()
    self.assertEqual(accelerator, "cpu")
    self.assertEqual(devices, 1)
    self.assertEqual(strategy, "auto")
78+
6579

6680
if __name__ == "__main__":
6781
unittest.main()

0 commit comments

Comments (0)