diff --git a/CHANGES.md b/CHANGES.md index 8abf3cef..189107e0 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Add `device='auto'` to select hardware acceleration like CUDA when + available, and CPU otherwise. + ### Changed ### Fixed diff --git a/docs/user/neuralnet.rst b/docs/user/neuralnet.rst index 0118b1ab..c1482bf0 100644 --- a/docs/user/neuralnet.rst +++ b/docs/user/neuralnet.rst @@ -298,8 +298,9 @@ As the name suggests, this determines which computation device should be used. If set to ``'cuda'``, the incoming data will be transferred to CUDA before being passed to the PyTorch :class:`~torch.nn.Module`. The device parameter adheres to the general syntax of the PyTorch device -parameter. If you want to prevent skorch from handling the device, set -``device=None``. +parameter. If set to ``'auto'``, hardware acceleration like CUDA is being +used if available, and CPU otherwise. If you want to prevent skorch from +handling the device, set ``device=None``. initialize() ^^^^^^^^^^^^ diff --git a/skorch/net.py b/skorch/net.py index 4d217460..a485e405 100644 --- a/skorch/net.py +++ b/skorch/net.py @@ -48,6 +48,7 @@ from skorch.utils import FirstStepAccumulator from skorch.utils import TeeGenerator from skorch.utils import _check_f_arguments +from skorch.utils import _check_device from skorch.utils import check_is_fitted from skorch.utils import duplicate_items from skorch.utils import get_map_location @@ -218,8 +219,10 @@ class NeuralNet(BaseEstimator): device : str, torch.device, or None (default='cpu') The compute device to be used. If set to 'cuda' in order to use GPU acceleration, data in torch tensors will be pushed to cuda - tensors before being sent to the module. If set to None, then - all compute devices will be left unmodified. + tensors before being sent to the module. If set to 'auto', + hardware acceleration like CUDA is being used if available, and + CPU otherwise. If set to None, then all compute devices will be + left unmodified. compile : bool (default=False) If set to ``True``, compile all modules using ``torch.compile``. For this @@ -2719,6 +2722,9 @@ def _check_device(self, requested_device, map_device): warnings.warn(msg, DeviceWarning) return map_device + if requested_device == 'auto': + return map_device + type_1 = torch.device(requested_device) type_2 = torch.device(map_device) if type_1 != type_2: @@ -2796,7 +2802,10 @@ def _get_state_dict(f_name): if isinstance(f_name, (str, os.PathLike)): state_dict = {} - with safe_open(f_name, framework='pt', device=self.device) as f: + with safe_open( + f_name, + framework='pt', + device=_check_device(self.device)) as f: for key in f.keys(): state_dict[key] = f.get_tensor(key) else: diff --git a/skorch/tests/test_net.py b/skorch/tests/test_net.py index 2ff94490..0b281cb6 100644 --- a/skorch/tests/test_net.py +++ b/skorch/tests/test_net.py @@ -346,6 +346,34 @@ def test_net_learns(self, net_cls, module_cls, data): y_pred = net.predict(X) assert accuracy_score(y, y_pred) > ACCURACY_EXPECTED + @pytest.mark.parametrize('cuda_available, expected', [ + (False, 'cpu'), + (True, 'cuda'), + ]) + def test_device_auto_fit_predict( + self, net_cls, module_cls, data, cuda_available, expected): + if cuda_available and not torch.cuda.is_available(): + pytest.skip() + + X, y = data + with patch('torch.cuda.is_available', lambda *_: cuda_available): + net = net_cls( + module_cls, + max_epochs=10, + lr=0.1, + device='auto', + ) + net.fit(X, y) + y_pred = net.predict(X) + y_forward = net.forward(X, device='auto') + + assert accuracy_score(y, y_pred) > ACCURACY_EXPECTED + assert all( + param.device.type == expected + for _, param in net.get_all_learnable_params() + ) + assert y_forward.device.type == expected + def test_forward(self, net_fit, data): X = data[0] n = len(X) @@ -468,6 +496,21 @@ def test_device_torch_device(self, net_cls, module_cls, device): net = net.initialize() assert net.module_.sequential[0].weight.device.type.startswith(device) + @pytest.mark.parametrize('cuda_available, expected', [ + (False, 'cpu'), + (True, 'cuda'), + ]) + def test_device_auto( + self, net_cls, module_cls, cuda_available, expected): + if cuda_available and not torch.cuda.is_available(): + pytest.skip() + + with patch('torch.cuda.is_available', lambda *_: cuda_available): + net = net_cls(module=module_cls, device='auto') + net = net.initialize() + + assert net.module_.sequential[0].weight.device.type == expected + @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device") @pytest.mark.parametrize( 'save_dev, cuda_available, load_dev, expect_warning', diff --git a/skorch/tests/test_utils.py b/skorch/tests/test_utils.py index 21ff4aac..c81751c7 100644 --- a/skorch/tests/test_utils.py +++ b/skorch/tests/test_utils.py @@ -1,6 +1,7 @@ """Test for utils.py""" from copy import deepcopy +from unittest.mock import patch import numpy as np import pytest @@ -33,6 +34,21 @@ def test_device_setting_cuda(self, to_tensor): t = to_tensor(t, device='cpu') assert t.device.type == 'cpu' + @pytest.mark.parametrize('cuda_available, expected', [ + (False, 'cpu'), + (True, 'cuda'), + ]) + def test_device_setting_auto( + self, to_tensor, cuda_available, expected): + if cuda_available and not torch.cuda.is_available(): + pytest.skip() + + x = np.ones((2, 3, 4)) + with patch('torch.cuda.is_available', lambda *_: cuda_available): + t = to_tensor(x, device='auto') + + assert t.device.type == expected + def tensors_equal(self, x, y): """"Test that tensors in diverse containers are equal.""" if isinstance(x, PackedSequence): @@ -247,6 +263,10 @@ def check_device_type(self, tensor, device_input, prev_device): if None is device_input: assert tensor.device.type == prev_device + elif device_input == 'auto': + expected = 'cuda' if torch.cuda.is_available() else 'cpu' + assert tensor.device.type == expected + else: assert tensor.device.type == device_input @@ -271,6 +291,35 @@ def test_check_device_torch_tensor(self, to_device, x, device_from, device_to): x = to_device(x, device=device_to) self.check_device_type(x, device_to, prev_device) + @pytest.mark.parametrize('cuda_available, expected', [ + (False, 'cpu'), + (True, 'cuda'), + ]) + def test_check_device_auto( + self, to_device, x, cuda_available, expected): + if cuda_available and not torch.cuda.is_available(): + pytest.skip() + + with patch('torch.cuda.is_available', lambda *_: cuda_available): + x = to_device(x, device='auto') + + assert x.device.type == expected + + @pytest.mark.parametrize('cuda_available, expected', [ + (False, 'cpu'), + (True, 'cuda'), + ]) + def test_get_map_location_auto(self, cuda_available, expected): + if cuda_available and not torch.cuda.is_available(): + pytest.skip() + + from skorch.utils import get_map_location + + with patch('torch.cuda.is_available', lambda *_: cuda_available): + map_location = get_map_location('auto') + + assert map_location.type == expected + @pytest.mark.parametrize('device_from, device_to', [ ('cpu', 'cpu'), ('cpu', 'cuda'), diff --git a/skorch/utils.py b/skorch/utils.py index 506537d0..e79ab3c3 100644 --- a/skorch/utils.py +++ b/skorch/utils.py @@ -59,6 +59,13 @@ def is_geometric_data_type(x): return isinstance(x, Data) +def _check_device(device): + """Resolve special device shortcuts.""" + if device == 'auto': + return 'cuda' if torch.cuda.is_available() else 'cpu' + return device + + # pylint: disable=not-callable def to_tensor(X, device, accept_sparse=False): """Turn input data to torch tensor. @@ -77,7 +84,8 @@ def to_tensor(X, device, accept_sparse=False): device : str, torch.device The compute device to be used. If set to 'cuda', data in torch tensors will be pushed to cuda tensors before being sent to the - module. + module. If set to 'auto', hardware acceleration like CUDA is + being used if available, and CPU otherwise. accept_sparse : bool (default=False) Whether to accept scipy sparse matrices as input. If False, @@ -89,6 +97,7 @@ def to_tensor(X, device, accept_sparse=False): output : torch Tensor """ + device = _check_device(device) to_tensor_ = partial(to_tensor, device=device) if is_torch_data_type(X): @@ -185,9 +194,12 @@ def to_device(X, device): device : str, torch.device The compute device to be used. If device=None, return the input - unmodified + unmodified. If device='auto', hardware acceleration like CUDA + is being used if available, and CPU otherwise. """ + device = _check_device(device) + if device is None: return X @@ -562,6 +574,7 @@ def get_map_location(target_device, fallback_device='cpu'): """ if target_device is None: target_device = fallback_device + target_device = _check_device(target_device) map_location = torch.device(target_device)