Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).

### Changed

- Improved runtime of `to_dense_batch` in both eager and `torch.compile` ([#10660](https://github.com/pyg-team/pytorch_geometric/pull/10660))
- Dropped support for TorchScript in `GATConv` and `GATv2Conv` for correctness ([#10596](https://github.com/pyg-team/pytorch_geometric/pull/10596))

### Deprecated
Expand Down
29 changes: 26 additions & 3 deletions test/utils/test_to_dense_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ def test_to_dense_batch_disable_dynamic_shapes():

with set_experimental_mode(True, 'disable_dynamic_shapes'):
with pytest.raises(ValueError, match="'batch_size' needs to be set"):
out, mask = to_dense_batch(x, batch, max_num_nodes=6)
to_dense_batch(x, batch, max_num_nodes=6)
with pytest.raises(ValueError, match="'max_num_nodes' needs to be"):
out, mask = to_dense_batch(x, batch, batch_size=4)
to_dense_batch(x, batch, batch_size=4)
with pytest.raises(ValueError, match="'batch_size' needs to be set"):
out, mask = to_dense_batch(x)
to_dense_batch(x)

out, mask = to_dense_batch(x, batch_size=1, max_num_nodes=6)
assert out.size() == (1, 6, 2)
Expand All @@ -90,6 +90,29 @@ def test_to_dense_batch_disable_dynamic_shapes():
assert mask.size() == (3, 10)


def test_to_dense_batch_overflow():
    """Check that nodes exceeding ``max_num_nodes`` are dropped per graph.

    Three graphs hold 2, 1 and 3 nodes respectively. With
    ``max_num_nodes=2`` the third node of the last graph does not fit and
    must be discarded instead of spilling into another slot; the single-node
    graph is padded with one zero row.
    """
    # Six nodes with two features each: [1, 2], [3, 4], ..., [11, 12].
    x = torch.tensor([[float(i), float(i + 1)] for i in range(1, 12, 2)])
    # Node-to-graph assignment: graph 0 -> 2 nodes, 1 -> 1 node, 2 -> 3 nodes.
    batch = torch.tensor([0, 0, 1, 2, 2, 2])

    dense, present = to_dense_batch(x, batch, max_num_nodes=2, batch_size=3)

    # The last graph keeps only its first two nodes; [11, 12] is dropped.
    want_dense = torch.tensor([
        [[1.0, 2.0], [3.0, 4.0]],
        [[5.0, 6.0], [0.0, 0.0]],
        [[7.0, 8.0], [9.0, 10.0]],
    ])
    # Only the padded slot of the single-node graph is marked as absent.
    want_present = [
        [True, True],
        [True, False],
        [True, True],
    ]

    assert torch.equal(dense, want_dense)
    assert present.tolist() == want_present


@onlyFullTest
def test_to_dense_batch_jit():
@torch.jit.script
Expand Down
33 changes: 17 additions & 16 deletions torch_geometric/utils/_to_dense_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
import torch
from torch import Tensor

from torch_geometric.experimental import (
disable_dynamic_shapes,
is_experimental_mode_enabled,
)
from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.utils import cumsum, scatter


Expand All @@ -28,6 +25,11 @@ def to_dense_batch(
N_{\max}}` is returned, holding information about the existence of
fake-nodes in the dense representation.

.. note::
When ``batch_size`` or ``max_num_nodes`` is not provided, this
function triggers a host-device synchronization to compute the value
from the input tensor ``batch``.

Args:
x (Tensor): Node feature matrix
:math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
Expand Down Expand Up @@ -107,30 +109,29 @@ def to_dense_batch(
dim_size=batch_size, reduce='sum')
cum_nodes = cumsum(num_nodes)

filter_nodes = False
dynamic_shapes_disabled = is_experimental_mode_enabled(
'disable_dynamic_shapes')

if max_num_nodes is None:
max_num_nodes = int(num_nodes.max())
elif not dynamic_shapes_disabled and num_nodes.max() > max_num_nodes:
filter_nodes = True

tmp = torch.arange(batch.size(0), device=x.device) - cum_nodes[batch]
idx = tmp + (batch * max_num_nodes)
if filter_nodes:
mask = tmp < max_num_nodes
x, idx = x[mask], idx[mask]

size = [batch_size * max_num_nodes] + list(x.size())[1:]
# Redirect overflow rows (tmp >= max_num_nodes) to a "trash" slot at the
# end of the flat buffer. This avoids data-dependent boolean indexing.
valid = tmp < max_num_nodes
trash_idx = batch_size * max_num_nodes # index of the extra slot
idx = torch.where(valid, idx, trash_idx)

flat_size = batch_size * max_num_nodes + 1
size = [flat_size] + list(x.size())[1:]
out = torch.as_tensor(fill_value, device=x.device, dtype=x.dtype)
out = out.repeat(size)
out[idx] = x
out = out[:batch_size * max_num_nodes] # drop the trash slot
out = out.view([batch_size, max_num_nodes] + list(x.size())[1:])

mask = torch.zeros(batch_size * max_num_nodes, dtype=torch.bool,
device=x.device)
mask = torch.zeros(flat_size, dtype=torch.bool, device=x.device)
mask[idx] = 1
mask = mask[:batch_size * max_num_nodes]
mask = mask.view(batch_size, max_num_nodes)

return out, mask
Loading