diff --git a/CHANGELOG.md b/CHANGELOG.md index f4125aab5b53..a0c12305cdcd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,12 +7,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). ### Added +- Added `weights_only` parameter to `torch_load` for explicit control over safe deserialization ([#10666](https://github.com/pyg-team/pytorch_geometric/pull/10666)) + ### Changed - Dropped support for TorchScript in `GATConv` and `GATv2Conv` for correctness ([#10596](https://github.com/pyg-team/pytorch_geometric/pull/10596)) ### Deprecated +- Deprecated the implicit `weights_only=False` fallback in `torch_load`; pass `weights_only` explicitly ([#10666](https://github.com/pyg-team/pytorch_geometric/pull/10666)) - Deprecated support for `torch-spline-conv` in favor of `pyg-lib>=0.6.0` ([#10622](https://github.com/pyg-team/pytorch_geometric/pull/10622)) ### Removed diff --git a/test/io/test_fs.py b/test/io/test_fs.py index d9227c6a230b..e39ad880dc3f 100644 --- a/test/io/test_fs.py +++ b/test/io/test_fs.py @@ -128,5 +128,29 @@ def test_torch_save_load(tmp_fs_path): path = osp.join(tmp_fs_path, 'x.pt') fs.torch_save(x, path) - out = fs.torch_load(path) + out = fs.torch_load(path, weights_only=True) assert torch.equal(x, out) + + +@pytest.mark.skipif( + not torch_geometric.typing.WITH_PT24, + reason='weights_only requires PyTorch >= 2.4', +) +def test_torch_load_fallback_warning(tmp_fs_path): + # Save an object that cannot be loaded with weights_only=True: + path = osp.join(tmp_fs_path, 'data.pt') + fs.torch_save(object(), path) + + with pytest.warns(FutureWarning, match='weights_only'): + fs.torch_load(path) + + +@pytest.mark.skipif( + not torch_geometric.typing.WITH_PT24, + reason='weights_only requires PyTorch >= 2.4', +) +def test_torch_load_weights_only_false(tmp_fs_path): + path = osp.join(tmp_fs_path, 'data.pt') + fs.torch_save(object(), path) + out = fs.torch_load(path, weights_only=False) + assert isinstance(out, object) diff --git 
a/torch_geometric/data/in_memory_dataset.py b/torch_geometric/data/in_memory_dataset.py index bcfcba02ee2e..c186250f45c6 100644 --- a/torch_geometric/data/in_memory_dataset.py +++ b/torch_geometric/data/in_memory_dataset.py @@ -126,9 +126,14 @@ def save(cls, data_list: Sequence[BaseData], path: str) -> None: data, slices = cls.collate(data_list) fs.torch_save((data.to_dict(), slices, data.__class__), path) - def load(self, path: str, data_cls: Type[BaseData] = Data) -> None: + def load( + self, + path: str, + data_cls: Type[BaseData] = Data, + weights_only: Optional[bool] = None, + ) -> None: r"""Loads the dataset from the file path :obj:`path`.""" - out = fs.torch_load(path) + out = fs.torch_load(path, weights_only=weights_only) assert isinstance(out, tuple) assert len(out) == 2 or len(out) == 3 if len(out) == 2: # Backward compatibility. diff --git a/torch_geometric/io/fs.py b/torch_geometric/io/fs.py index c88a04ad5a9a..4f68bde27883 100644 --- a/torch_geometric/io/fs.py +++ b/torch_geometric/io/fs.py @@ -214,8 +214,38 @@ def torch_save(data: Any, path: str) -> None: f.write(buffer.getvalue()) -def torch_load(path: str, map_location: Any = None) -> Any: +def torch_load( + path: str, + map_location: Any = None, + weights_only: Optional[bool] = None, +) -> Any: + r"""Load a PyTorch file from a given path using :func:`fsspec`. + + Args: + path (str): The path to the file to load. + map_location: A simplified version of :attr:`torch.load`'s + :attr:`map_location`. + weights_only (bool, optional): If :obj:`True`, only weights will be + loaded and an error will be raised on failure (*i.e.*, no + fallback to :obj:`weights_only=False`). + If :obj:`False`, :obj:`weights_only=False` will be used + directly. + If :obj:`None` (default), the current fallback behavior is + preserved: first tries :obj:`weights_only=True`, then falls + back to :obj:`weights_only=False` on + :class:`~pickle.UnpicklingError`. 
(default: :obj:`None`) + """ if torch_geometric.typing.WITH_PT24: + if weights_only is True: + with fsspec.open(path, 'rb') as f: + return torch.load(f, map_location, weights_only=True) + + if weights_only is False: + with fsspec.open(path, 'rb') as f: + return torch.load(f, map_location, weights_only=False) + + # Default behavior (weights_only=None): try weights_only=True + # first, then fall back to weights_only=False on failure. try: with fsspec.open(path, 'rb') as f: return torch.load(f, map_location, weights_only=True) @@ -230,10 +260,21 @@ def torch_load(path: str, map_location: Any = None) -> Any: warnings.warn( f"{warn_msg} Please use " f"`torch.serialization.{match.group()}` to " - f"allowlist this global.", stacklevel=2) + f"allowlist this global if you trust it.", + stacklevel=2) else: warnings.warn(warn_msg, stacklevel=2) + warnings.warn( + "Falling back to `weights_only=False` because the file at " + f"'{path}' could not be loaded with `weights_only=True`. " + "In a future release, this fallback will be removed. Pass " + "`weights_only=False` explicitly if you need to load " + "custom Python objects, allowlist doesn't work, and you " + "trust the source.", + FutureWarning, + stacklevel=2, + ) with fsspec.open(path, 'rb') as f: return torch.load(f, map_location, weights_only=False) else: