Skip to content

Commit ea9f11e

Browse files
By default use all gpus available in auto mode (#808)
1 parent ca64d76 commit ea9f11e

7 files changed

Lines changed: 42 additions & 27 deletions

File tree

changelog/808.changed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"auto" device selection now uses all available CUDA GPUs instead of only the first one

src/tabpfn/classifier.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,8 +1506,8 @@ def load_from_fit_state(
15061506
def to(self, device: DevicesSpecification) -> None:
15071507
"""Move the estimator to the given device(s).
15081508
1509-
If "auto": a single device is selected based on availability in the
1510-
following order of priority: "cuda:0", "mps", "cpu".
1509+
If "auto": devices are selected based on availability in the
1510+
following order of priority: all available CUDA GPUs, "mps", "cpu".
15111511
15121512
To manually select a single device: specify a PyTorch device string e.g.
15131513
"cuda:1". See PyTorch's documentation for information about supported

src/tabpfn/inference_tuning.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -90,23 +90,31 @@ class ClassifierEvalMetrics(str, Enum):
9090

9191

9292
METRIC_NAME_TO_OBJECTIVE = {
93-
"f1": lambda y_true, y_pred: -f1_score(
94-
y_true,
95-
y_pred,
96-
average="binary",
97-
zero_division=0,
93+
"f1": lambda y_true, y_pred: (
94+
-f1_score(
95+
y_true,
96+
y_pred,
97+
average="binary",
98+
zero_division=0,
99+
)
98100
),
99-
"accuracy": lambda y_true, y_pred: -accuracy_score(
100-
y_true,
101-
y_pred,
101+
"accuracy": lambda y_true, y_pred: (
102+
-accuracy_score(
103+
y_true,
104+
y_pred,
105+
)
102106
),
103-
"balanced_accuracy": lambda y_true, y_pred: -balanced_accuracy_score(
104-
y_true,
105-
y_pred,
107+
"balanced_accuracy": lambda y_true, y_pred: (
108+
-balanced_accuracy_score(
109+
y_true,
110+
y_pred,
111+
)
106112
),
107-
"roc_auc": lambda y_true, y_pred: -roc_auc_score(
108-
y_true,
109-
y_pred,
113+
"roc_auc": lambda y_true, y_pred: (
114+
-roc_auc_score(
115+
y_true,
116+
y_pred,
117+
)
110118
),
111119
"log_loss": log_loss,
112120
}

src/tabpfn/regressor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,8 +1203,8 @@ def load_from_fit_state(
12031203
def to(self, device: DevicesSpecification) -> None:
12041204
"""Move the estimator to the given device(s).
12051205
1206-
If "auto": a single device is selected based on availability in the
1207-
following order of priority: "cuda:0", "mps", "cpu".
1206+
If "auto": devices are selected based on availability in the
1207+
following order of priority: all available CUDA GPUs, "mps", "cpu".
12081208
12091209
To manually select a single device: specify a PyTorch device string e.g.
12101210
"cuda:1". See PyTorch's documentation for information about supported

src/tabpfn/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def infer_devices(devices: DevicesSpecification) -> tuple[torch.device, ...]:
117117
"""Selects the appropriate PyTorch devices for inference.
118118
119119
If `device` is "auto" then the devices are selected as follows:
120-
1. If CUDA is available and not excluded, returns the first "cuda" device
120+
1. If CUDA is available and not excluded, returns all available "cuda" devices
121121
2. Otherwise, if MPS is available and not excluded, returns the "mps" device
122122
3. Otherwise, returns the "cpu" device
123123
@@ -145,7 +145,9 @@ def infer_devices(devices: DevicesSpecification) -> tuple[torch.device, ...]:
145145

146146
if devices == "auto":
147147
if "cuda" not in exclude_devices and torch.cuda.is_available():
148-
return (torch.device("cuda:0"),)
148+
return tuple(
149+
torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())
150+
)
149151

150152
if _is_mps_supported() and "mps" not in exclude_devices:
151153
return (torch.device("mps"),)

tests/test_config.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,14 @@ def test__parse_config__unused_keys__returns_unused_config(
4747

4848

4949
@dataclass
50-
class FakeConfig(ArchitectureConfig):
51-
a: int = 1
52-
b: FakeSubConfig = field(default_factory=lambda: FakeSubConfig())
50+
class FakeSubConfig:
51+
c: int = 2
5352

5453

5554
@dataclass
56-
class FakeSubConfig:
57-
c: int = 2
55+
class FakeConfig(ArchitectureConfig):
56+
a: int = 1
57+
b: FakeSubConfig = field(default_factory=FakeSubConfig)
5858

5959

6060
class FakeArchitectureModule(ArchitectureModule):

tests/test_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test__infer_devices__auto__single_cuda_gpu_available__selects_it(
3434
assert infer_devices(devices="auto") == (torch.device("cuda:0"),)
3535

3636

37-
def test__infer_devices__auto__multiple_cuda_gpus_available__selects_first(
37+
def test__infer_devices__auto__multiple_cuda_gpus_available__selects_all(
3838
mocker: MagicMock, monkeypatch: pytest.MonkeyPatch
3939
) -> None:
4040
monkeypatch.setenv("TABPFN_EXCLUDE_DEVICES", "")
@@ -43,7 +43,11 @@ def test__infer_devices__auto__multiple_cuda_gpus_available__selects_first(
4343
mock_cuda.device_count.return_value = 3
4444
mocker.patch("torch.backends.mps").is_available.return_value = True
4545

46-
assert infer_devices(devices="auto") == (torch.device("cuda:0"),)
46+
assert infer_devices(devices="auto") == (
47+
torch.device("cuda:0"),
48+
torch.device("cuda:1"),
49+
torch.device("cuda:2"),
50+
)
4751

4852

4953
def test__infer_devices__auto__cuda_and_mps_available_but_excluded__selects_cpu(

0 commit comments

Comments
 (0)