Skip to content
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).

### Fixed

- Fixed batched `OnDiskDataset` access for subsets and split-specific indices via `DataLoader` ([#10674](https://github.com/pyg-team/pytorch_geometric/pull/10674))
- Fixed MovieLens dataset incompatibility with `sentence-transformers>=5.0.0` ([#10668](https://github.com/pyg-team/pytorch_geometric/pull/10668))
- Removed an unnecessary device synchronization in `torch_geometric.utils.softmax` ([#10499](https://github.com/pyg-team/pytorch_geometric/pull/10499))
- Fixed loading of legacy HuggingFace BERT checkpoints ([#10631](https://github.com/pyg-team/pytorch_geometric/pull/10631))
Expand Down
60 changes: 60 additions & 0 deletions test/data/test_on_disk_dataset.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os.path as osp
from typing import Any, Dict

import numpy as np
import torch

from torch_geometric.data import Data, OnDiskDataset
from torch_geometric.loader import DataLoader
from torch_geometric.testing import withPackage


Expand Down Expand Up @@ -109,3 +111,61 @@ def deserialize(self, mapping: Dict[str, Any]) -> Any:
assert out.num_nodes == data.num_nodes

dataset.close()


@withPackage('sqlite3')
def test_index_select_multi_get(tmp_path):
    """Check batched access on subsets created through ``index_select``.

    Integer indexing, ``multi_get``, ``__getitems__`` and ``DataLoader``
    batches are expected to address subset-local positions, while ``get``
    keeps addressing the raw database index.
    """
    dataset = OnDiskDataset(tmp_path)
    try:
        data_list = [Data(x=torch.tensor([i])) for i in range(10)]
        dataset.extend(data_list)

        subset = dataset.index_select([5, 6, 7, 8, 9])
        nested_subset = subset.index_select([1, 3])

        # `subset[0]` is subset-local, `subset.get(0)` is the raw index 0.
        assert torch.equal(subset[0].x, data_list[5].x)
        assert torch.equal(subset.get(0).x, data_list[0].x)

        # Integer positions and an equivalent boolean mask select the same
        # entries of the subset.
        out_list = subset.multi_get([0, 2, 4])
        assert [int(data.x.item()) for data in out_list] == [5, 7, 9]

        out_list = subset.multi_get(
            np.array([True, False, True, False, True]))
        assert [int(data.x.item()) for data in out_list] == [5, 7, 9]

        # Empty selections of every supported container yield an empty list.
        assert subset.multi_get([]) == []
        assert subset.multi_get(torch.tensor([], dtype=torch.long)) == []
        assert subset.multi_get(torch.tensor([], dtype=torch.bool)) == []
        assert subset.multi_get(np.zeros(5, dtype=bool)) == []

        # A subset of a subset resolves through both index layers.
        out_list = nested_subset.__getitems__([0, 1])
        assert torch.equal(out_list[0].x, data_list[6].x)
        assert torch.equal(out_list[1].x, data_list[8].x)

        loader = DataLoader(subset, batch_size=3, shuffle=False)
        batch = next(iter(loader))
        assert batch.x.view(-1).tolist() == [5, 6, 7]
    finally:
        # Close the dataset even when one of the assertions above fails.
        dataset.close()


@withPackage('sqlite3')
def test_direct_indices_multi_get(tmp_path):
    """Check batched access when ``_indices`` is assigned directly.

    Mirrors the ``index_select`` test, but sets ``dataset._indices`` on the
    dataset itself instead of creating a subset object.
    """
    dataset = OnDiskDataset(tmp_path)
    try:
        data_list = [Data(x=torch.tensor([i])) for i in range(10)]
        dataset.extend(data_list)

        dataset._indices = [5, 6, 7, 8, 9]

        # `dataset[0]` goes through `_indices`, `get(0)` is the raw index 0.
        assert torch.equal(dataset[0].x, data_list[5].x)
        assert torch.equal(dataset.get(0).x, data_list[0].x)

        out_list = dataset.multi_get([0, 1, 2])
        assert [int(data.x.item()) for data in out_list] == [5, 6, 7]

        out_list = dataset.__getitems__([0, 1, 2])
        assert [int(data.x.item()) for data in out_list] == [5, 6, 7]

        loader = DataLoader(dataset, batch_size=3, shuffle=False)
        batch = next(iter(loader))
        assert batch.x.view(-1).tolist() == [5, 6, 7]
    finally:
        # Close the dataset even when one of the assertions above fails.
        dataset.close()
59 changes: 55 additions & 4 deletions torch_geometric/data/on_disk_dataset.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import os
from typing import Any, Callable, Iterable, List, Optional, Sequence, Union

import numpy as np
import torch
from torch import Tensor

from torch_geometric.data import Database, RocksDatabase, SQLiteDatabase
from torch_geometric.data.data import BaseData
from torch_geometric.data.database import Schema
from torch_geometric.data.dataset import Dataset

# Index containers accepted by `OnDiskDataset._resolve_indices` and
# `OnDiskDataset.multi_get`: iterables of ints, tensors, NumPy arrays, slices
# and ranges.
IndexType = Union[Iterable[int], Tensor, np.ndarray, slice, range]


class OnDiskDataset(Dataset):
r"""Dataset base class for creating large graph datasets which do not
Expand Down Expand Up @@ -137,17 +141,63 @@ def extend(
self.db.multi_insert(range(start, end), data_list, batch_size)
self._numel += (end - start)

def _resolve_indices(
self,
indices: IndexType,
) -> List[int]:
base_indices = self.indices()

if isinstance(indices, slice):
indices = base_indices[indices]

if isinstance(indices, Tensor):
indices = indices.flatten().tolist()
elif isinstance(indices, np.ndarray):
indices = indices.flatten().tolist()

return [int(idx) for idx in indices]

if isinstance(indices, Tensor):
if indices.dtype == torch.bool:
indices = indices.flatten().nonzero(as_tuple=False).flatten()
indices = indices.flatten().tolist()
elif isinstance(indices, np.ndarray):
if indices.dtype == bool:
indices = indices.flatten().nonzero()[0]
indices = indices.flatten().tolist()
elif isinstance(indices, range):
indices = list(indices)
elif not isinstance(indices, Sequence):
indices = list(indices)

return [int(base_indices[idx]) for idx in indices]

def get(self, idx: int) -> BaseData:
    r"""Gets the data object at the raw database index :obj:`idx`.

    Note that subset-aware integer access is handled by
    :meth:`Dataset.__getitem__`, which resolves :obj:`self.indices()[idx]`
    before calling :meth:`get`.

    Args:
        idx (int): The raw database index to read.

    Returns:
        The result of :meth:`deserialize` applied to the database row
        stored at :obj:`idx`.
    """
    # No lookup through `self.indices()` here: `idx` goes to the database
    # unchanged.
    return self.deserialize(self.db.get(idx))

def multi_get(
self,
indices: Union[Iterable[int], Tensor, slice, range],
indices: IndexType,
batch_size: Optional[int] = None,
) -> List[BaseData]:
r"""Gets a list of data objects from the specified indices."""
if len(indices) == 1:
r"""Gets a list of data objects from the specified subset-local
indices.

In contrast to :meth:`get`, batched access is expected to follow the
same subset semantics as ``[self[idx] for idx in indices]``. As such,
indices are first resolved through :obj:`self.indices()`.
"""
indices = self._resolve_indices(indices)

if len(indices) == 0:
data_list = []
elif len(indices) == 1:
data_list = [self.db.get(indices[0])]
else:
data_list = self.db.multi_get(indices, batch_size)
Expand All @@ -158,6 +208,7 @@ def multi_get(
return data_list

def __getitems__(self, indices: List[int]) -> List[BaseData]:
r"""Gets a list of data objects for batched subset-aware access."""
return self.multi_get(indices)

def len(self) -> int:
Expand Down