Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add `device='auto'` to select hardware acceleration like CUDA when
available, and CPU otherwise.

### Changed

### Fixed
Expand Down
5 changes: 3 additions & 2 deletions docs/user/neuralnet.rst
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,9 @@ As the name suggests, this determines which computation device should
be used. If set to ``'cuda'``, the incoming data will be transferred to
CUDA before being passed to the PyTorch :class:`~torch.nn.Module`. The
device parameter adheres to the general syntax of the PyTorch device
parameter. If you want to prevent skorch from handling the device, set
``device=None``.
parameter. If set to ``'auto'``, hardware acceleration like CUDA is being
used if available, and CPU otherwise. If you want to prevent skorch from
handling the device, set ``device=None``.

initialize()
^^^^^^^^^^^^
Expand Down
15 changes: 12 additions & 3 deletions skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from skorch.utils import FirstStepAccumulator
from skorch.utils import TeeGenerator
from skorch.utils import _check_f_arguments
from skorch.utils import _check_device
from skorch.utils import check_is_fitted
from skorch.utils import duplicate_items
from skorch.utils import get_map_location
Expand Down Expand Up @@ -218,8 +219,10 @@ class NeuralNet(BaseEstimator):
device : str, torch.device, or None (default='cpu')
The compute device to be used. If set to 'cuda' in order to use
GPU acceleration, data in torch tensors will be pushed to cuda
tensors before being sent to the module. If set to None, then
all compute devices will be left unmodified.
tensors before being sent to the module. If set to 'auto',
hardware acceleration like CUDA is being used if available, and
CPU otherwise. If set to None, then all compute devices will be
left unmodified.

compile : bool (default=False)
If set to ``True``, compile all modules using ``torch.compile``. For this
Expand Down Expand Up @@ -2719,6 +2722,9 @@ def _check_device(self, requested_device, map_device):
warnings.warn(msg, DeviceWarning)
return map_device

if requested_device == 'auto':
return map_device

type_1 = torch.device(requested_device)
type_2 = torch.device(map_device)
if type_1 != type_2:
Expand Down Expand Up @@ -2796,7 +2802,10 @@ def _get_state_dict(f_name):

if isinstance(f_name, (str, os.PathLike)):
state_dict = {}
with safe_open(f_name, framework='pt', device=self.device) as f:
with safe_open(
f_name,
framework='pt',
device=_check_device(self.device)) as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key)
else:
Expand Down
43 changes: 43 additions & 0 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,34 @@ def test_net_learns(self, net_cls, module_cls, data):
y_pred = net.predict(X)
assert accuracy_score(y, y_pred) > ACCURACY_EXPECTED

@pytest.mark.parametrize('cuda_available, expected', [
(False, 'cpu'),
(True, 'cuda'),
])
def test_device_auto_fit_predict(
self, net_cls, module_cls, data, cuda_available, expected):
if cuda_available and not torch.cuda.is_available():
pytest.skip()

X, y = data
with patch('torch.cuda.is_available', lambda *_: cuda_available):
net = net_cls(
module_cls,
max_epochs=10,
lr=0.1,
device='auto',
)
net.fit(X, y)
y_pred = net.predict(X)
y_forward = net.forward(X, device='auto')

assert accuracy_score(y, y_pred) > ACCURACY_EXPECTED
assert all(
param.device.type == expected
for _, param in net.get_all_learnable_params()
)
assert y_forward.device.type == expected

def test_forward(self, net_fit, data):
X = data[0]
n = len(X)
Expand Down Expand Up @@ -468,6 +496,21 @@ def test_device_torch_device(self, net_cls, module_cls, device):
net = net.initialize()
assert net.module_.sequential[0].weight.device.type.startswith(device)

@pytest.mark.parametrize('cuda_available, expected', [
(False, 'cpu'),
(True, 'cuda'),
])
def test_device_auto(
self, net_cls, module_cls, cuda_available, expected):
if cuda_available and not torch.cuda.is_available():
pytest.skip()

with patch('torch.cuda.is_available', lambda *_: cuda_available):
net = net_cls(module=module_cls, device='auto')
net = net.initialize()

assert net.module_.sequential[0].weight.device.type == expected

@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device")
@pytest.mark.parametrize(
'save_dev, cuda_available, load_dev, expect_warning',
Expand Down
49 changes: 49 additions & 0 deletions skorch/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test for utils.py"""

from copy import deepcopy
from unittest.mock import patch

import numpy as np
import pytest
Expand Down Expand Up @@ -33,6 +34,21 @@ def test_device_setting_cuda(self, to_tensor):
t = to_tensor(t, device='cpu')
assert t.device.type == 'cpu'

@pytest.mark.parametrize('cuda_available, expected', [
(False, 'cpu'),
(True, 'cuda'),
])
def test_device_setting_auto(
self, to_tensor, cuda_available, expected):
if cuda_available and not torch.cuda.is_available():
pytest.skip()

x = np.ones((2, 3, 4))
with patch('torch.cuda.is_available', lambda *_: cuda_available):
t = to_tensor(x, device='auto')

assert t.device.type == expected

def tensors_equal(self, x, y):
""""Test that tensors in diverse containers are equal."""
if isinstance(x, PackedSequence):
Expand Down Expand Up @@ -247,6 +263,10 @@ def check_device_type(self, tensor, device_input, prev_device):
if None is device_input:
assert tensor.device.type == prev_device

elif device_input == 'auto':
expected = 'cuda' if torch.cuda.is_available() else 'cpu'
assert tensor.device.type == expected

else:
assert tensor.device.type == device_input

Expand All @@ -271,6 +291,35 @@ def test_check_device_torch_tensor(self, to_device, x, device_from, device_to):
x = to_device(x, device=device_to)
self.check_device_type(x, device_to, prev_device)

@pytest.mark.parametrize('cuda_available, expected', [
(False, 'cpu'),
(True, 'cuda'),
])
def test_check_device_auto(
self, to_device, x, cuda_available, expected):
if cuda_available and not torch.cuda.is_available():
pytest.skip()

with patch('torch.cuda.is_available', lambda *_: cuda_available):
x = to_device(x, device='auto')

assert x.device.type == expected

@pytest.mark.parametrize('cuda_available, expected', [
(False, 'cpu'),
(True, 'cuda'),
])
def test_get_map_location_auto(self, cuda_available, expected):
if cuda_available and not torch.cuda.is_available():
pytest.skip()

from skorch.utils import get_map_location

with patch('torch.cuda.is_available', lambda *_: cuda_available):
map_location = get_map_location('auto')

assert map_location.type == expected

@pytest.mark.parametrize('device_from, device_to', [
('cpu', 'cpu'),
('cpu', 'cuda'),
Expand Down
17 changes: 15 additions & 2 deletions skorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ def is_geometric_data_type(x):
return isinstance(x, Data)


def _check_device(device):
"""Resolve special device shortcuts."""
if device == 'auto':
return 'cuda' if torch.cuda.is_available() else 'cpu'
Comment thread
BenjaminBossan marked this conversation as resolved.
return device


# pylint: disable=not-callable
def to_tensor(X, device, accept_sparse=False):
"""Turn input data to torch tensor.
Expand All @@ -77,7 +84,8 @@ def to_tensor(X, device, accept_sparse=False):
device : str, torch.device
The compute device to be used. If set to 'cuda', data in torch
tensors will be pushed to cuda tensors before being sent to the
module.
module. If set to 'auto', hardware acceleration like CUDA is
being used if available, and CPU otherwise.

accept_sparse : bool (default=False)
Whether to accept scipy sparse matrices as input. If False,
Expand All @@ -89,6 +97,7 @@ def to_tensor(X, device, accept_sparse=False):
output : torch Tensor

"""
device = _check_device(device)
to_tensor_ = partial(to_tensor, device=device)

if is_torch_data_type(X):
Expand Down Expand Up @@ -185,9 +194,12 @@ def to_device(X, device):

device : str, torch.device
The compute device to be used. If device=None, return the input
unmodified
unmodified. If device='auto', hardware acceleration like CUDA
is being used if available, and CPU otherwise.

"""
device = _check_device(device)

if device is None:
return X

Expand Down Expand Up @@ -562,6 +574,7 @@ def get_map_location(target_device, fallback_device='cpu'):
"""
if target_device is None:
target_device = fallback_device
target_device = _check_device(target_device)

map_location = torch.device(target_device)

Expand Down
Loading