ENH Add automatic device selection #1141

Open
adit24dhaya wants to merge 3 commits into skorch-dev:master from adit24dhaya:add-auto-device-selection

Conversation

@adit24dhaya

Summary

Add support for device='auto' so skorch selects CUDA when it is available and falls back to CPU otherwise.

This centralizes the resolution in skorch.utils and applies it to tensor/device movement, map-location handling, module initialization, and safetensors loading.
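
The resolution described above can be sketched roughly as follows. This is a minimal illustration, not the code from the PR: `resolve_device` and its `cuda_available` parameter are hypothetical names, and the only real API assumed is `torch.cuda.is_available`.

```python
def resolve_device(device, cuda_available=None):
    """Hypothetical helper: map 'auto' to 'cuda' or 'cpu'.

    Any other value (a device string, a torch.device, or None) is
    returned unchanged, so existing behavior is preserved.
    """
    # Only the literal string 'auto' triggers detection.
    if not (isinstance(device, str) and device == 'auto'):
        return device

    if cuda_available is None:
        # Imported lazily so the availability check can be stubbed in tests.
        import torch
        cuda_available = torch.cuda.is_available

    return 'cuda' if cuda_available() else 'cpu'
```

Keeping the check in a single function means every call site (tensor movement, map location, module initialization, safetensors loading) can share the same fallback rule.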

Closes #576.

Testing

  • .venv/bin/python -m pytest skorch/tests/test_utils.py::TestToTensor::test_device_setting_auto skorch/tests/test_utils.py::TestToDevice::test_check_device_auto skorch/tests/test_utils.py::TestToDevice::test_get_map_location_auto skorch/tests/test_net.py::TestNeuralNet::test_device_auto
  • .venv/bin/python -m pytest skorch/tests/test_utils.py skorch/tests/test_net.py::TestNeuralNet::test_device_torch_device skorch/tests/test_net.py::TestNeuralNet::test_pickle_save_load_device_is_none
  • git diff --check

adit24dhaya marked this pull request as ready for review on May 17, 2026, 12:32

@BenjaminBossan (Collaborator) left a comment

Thanks for creating this PR to support device='auto' in skorch. It already looks quite complete. I have a small comment, please check.

Moreover, I think we need to add a more complete test than the ones you've added. By that, I mean something like this test:

    def test_net_learns(self, net_cls, module_cls, data):
        X, y = data
        net = net_cls(
            module_cls,
            max_epochs=10,
            lr=0.1,
        )
        net.fit(X, y)
        y_pred = net.predict(X)
        assert accuracy_score(y, y_pred) > ACCURACY_EXPECTED

The reason is that checking only whether the model weights are on the correct device is not enough. Mismatched devices could cause errors in other places, so having a more end-to-end test is important.

What I also wondered about while checking this PR is if we can take the opportunity to be more open to non-CUDA devices like XPU. It won't be easy to test, as the CI doesn't have access to those machines and I also don't have one at home. But it could still be a nice addition. LMK what you think.

Regardless of whether that is added, I think it's a good idea to use a more general wording in the documentation, e.g. replacing "If set to 'auto', CUDA is used" with "If set to 'auto', hardware acceleration like CUDA is being used". That way, we don't need to rewrite all these sections if we eventually add support for other devices.

Comment thread skorch/utils.py
@adit24dhaya (Author) commented May 18, 2026

Thanks for the review. I pushed 441d2d1 with the pylint directive moved back to to_tensor, an end-to-end fit/predict test for device='auto', and broader docs wording around hardware acceleration.

On XPU and other accelerators, I would prefer to keep this PR focused on the CUDA-to-CPU fallback and handle additional accelerators in a follow-up once we agree on the detection order.
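
To make the "detection order" question concrete, here is one possible shape for a follow-up, sketched under loud assumptions: `pick_device` and the probe list are hypothetical and not part of skorch or this PR. The order of the list is the detection order, so agreeing on the order amounts to agreeing on the list.

```python
def pick_device(probes, fallback='cpu'):
    """Return the name of the first device whose probe reports availability.

    ``probes`` is an ordered list of (device_name, is_available) pairs.
    Hypothetical helper for illustration only.
    """
    for name, is_available in probes:
        if is_available():
            return name
    # No accelerator was detected, so fall back to the CPU.
    return fallback


# Stubbed probes for demonstration; in practice these might be callables
# such as torch.cuda.is_available (an assumption, not taken from the PR).
probes = [('cuda', lambda: False), ('xpu', lambda: True)]
chosen = pick_device(probes)
```

With probes passed in as plain callables, the ordering logic can be tested on CI machines that have no accelerator at all.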

@BenjaminBossan (Collaborator) left a comment

Thanks for the updates. I still found a few parts that could be improved, but otherwise this looks good.

Regarding the support for other devices, I'm fine with only adding CUDA for now. The rest can be worked on in the future.

Comment thread skorch/net.py Outdated
         state_dict = {}
-        with safe_open(f_name, framework='pt', device=self.device) as f:
+        with safe_open(
+            f_name, framework='pt',

Suggested change
-            f_name, framework='pt',
+            f_name,
+            framework='pt',

Comment thread skorch/utils.py Outdated
    return isinstance(x, Data)


def _check_device_auto(device):

Let's rename the function to just _check_device, as it may do more than checking for auto in the future. Also, let's add a short docstring.

Comment thread skorch/tests/test_net.py Outdated
        y_pred = net.predict(X)
        assert accuracy_score(y, y_pred) > ACCURACY_EXPECTED

    def test_device_auto_fit_predict(self, net_cls, module_cls, data):

Let's parametrize this function similar to the other tests that you added. Also please check the device of the modules. Finally, let's include a call to y_forward = net.forward(X, device='auto') and check the device of the returned tensor.

@adit24dhaya (Author) commented

Thanks again for the follow-up review. I pushed 566a92c with the requested changes:

  • renamed _check_device_auto to _check_device and added a short docstring
  • split the safe_open arguments across lines
  • parametrized and strengthened the device='auto' end-to-end test to check module devices and forward(..., device='auto') output device

I also reran the focused and broader affected tests locally:

  • .venv/bin/python -m pytest skorch/tests/test_net.py::TestNeuralNet::test_device_auto_fit_predict skorch/tests/test_net.py::TestNeuralNet::test_device_auto skorch/tests/test_utils.py::TestToTensor::test_device_setting_auto skorch/tests/test_utils.py::TestToDevice::test_check_device_auto skorch/tests/test_utils.py::TestToDevice::test_get_map_location_auto
  • .venv/bin/python -m pytest skorch/tests/test_utils.py skorch/tests/test_net.py::TestNeuralNet::test_device_auto_fit_predict skorch/tests/test_net.py::TestNeuralNet::test_device_auto skorch/tests/test_net.py::TestNeuralNet::test_device_torch_device skorch/tests/test_net.py::TestNeuralNet::test_pickle_save_load_device_is_none
  • git diff --check

Development

Successfully merging this pull request may close these issues.

device='auto' should select the fastest possible device
