ENH Add automatic device selection#1141
Conversation
BenjaminBossan
left a comment
There was a problem hiding this comment.
Thanks for creating this PR to support device='auto' in skorch. It already looks quite complete. I have a small comment, please check.
Moreover, I think we need to add a more complete test than the ones you've added. By that, I mean something like this test:
skorch/skorch/tests/test_net.py
Lines 338 to 347 in 5db0ddc
The reason is that only checking if the model weight is on the correct device is not enough. There are other places where there could be an error due to mismatched devices, so having a test that is more end-to-end is important.
What I also wondered about while checking this PR is if we can take the opportunity to be more open to non-CUDA devices like XPU. It won't be easy to test, as the CI doesn't have access to those machines and I also don't have one at home. But it could still be a nice addition. LMK what you think.
Regardless of whether that is added, I think it's a good idea to use a more general wording in the documentation, e.g. replacing "If set to 'auto', CUDA is used" with "If set to 'auto', hardware acceleration like CUDA is being used". That way, we don't need to rewrite all these sections if we eventually add support for other devices.
|
Thanks for the review. I pushed 441d2d1 with the pylint directive moved back to to_tensor, an end-to-end fit/predict test for device=auto, and broader docs wording around hardware acceleration. On XPU and other accelerators, I would prefer to keep this PR focused on the CUDA-to-CPU fallback and handle additional accelerators in a follow-up once we agree on the detection order. |
BenjaminBossan
left a comment
There was a problem hiding this comment.
Thanks for the updates. I still found a few parts that could be improved, but otherwise this looks good.
Regarding the support for other devices, I'm fine with only adding CUDA for now. The rest can be worked on in the future.
| state_dict = {} | ||
| with safe_open(f_name, framework='pt', device=self.device) as f: | ||
| with safe_open( | ||
| f_name, framework='pt', |
There was a problem hiding this comment.
| f_name, framework='pt', | |
| f_name, | |
| framework='pt', |
| return isinstance(x, Data) | ||
|
|
||
|
|
||
| def _check_device_auto(device): |
There was a problem hiding this comment.
Let's rename the function to just _check_device, as it may do more than checking for auto in the future. Also, let's add a short docstring.
| y_pred = net.predict(X) | ||
| assert accuracy_score(y, y_pred) > ACCURACY_EXPECTED | ||
|
|
||
| def test_device_auto_fit_predict(self, net_cls, module_cls, data): |
There was a problem hiding this comment.
Let's parametrize this function similar to the other tests that you added. Also please check the device of the modules. Finally, let's include a call to y_forward = net.forward(X, device='auto') and check the device of the returned tensor.
|
Thanks again for the follow-up review. I pushed
I also reran the focused and broader affected tests locally:
|
Summary
Add support for
device='auto'so skorch selects CUDA when it is available and falls back to CPU otherwise.This centralizes the resolution in
skorch.utilsand applies it to tensor/device movement, map-location handling, module initialization, and safetensors loading.Closes #576.
Testing
.venv/bin/python -m pytest skorch/tests/test_utils.py::TestToTensor::test_device_setting_auto skorch/tests/test_utils.py::TestToDevice::test_check_device_auto skorch/tests/test_utils.py::TestToDevice::test_get_map_location_auto skorch/tests/test_net.py::TestNeuralNet::test_device_auto.venv/bin/python -m pytest skorch/tests/test_utils.py skorch/tests/test_net.py::TestNeuralNet::test_device_torch_device skorch/tests/test_net.py::TestNeuralNet::test_pickle_save_load_device_is_nonegit diff --check