From 198eea2fefb45d986a3872f125051cbb41768bd9 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sat, 4 Apr 2026 18:12:42 +0000 Subject: [PATCH 1/2] update --- test/utils/test_to_dense_batch.py | 29 ++++++++++++++++++--- torch_geometric/utils/_to_dense_batch.py | 33 ++++++++++++------------ 2 files changed, 43 insertions(+), 19 deletions(-) diff --git a/test/utils/test_to_dense_batch.py b/test/utils/test_to_dense_batch.py index 7c13993a4496..093c4073cfb3 100644 --- a/test/utils/test_to_dense_batch.py +++ b/test/utils/test_to_dense_batch.py @@ -75,11 +75,11 @@ def test_to_dense_batch_disable_dynamic_shapes(): with set_experimental_mode(True, 'disable_dynamic_shapes'): with pytest.raises(ValueError, match="'batch_size' needs to be set"): - out, mask = to_dense_batch(x, batch, max_num_nodes=6) + to_dense_batch(x, batch, max_num_nodes=6) with pytest.raises(ValueError, match="'max_num_nodes' needs to be"): - out, mask = to_dense_batch(x, batch, batch_size=4) + to_dense_batch(x, batch, batch_size=4) with pytest.raises(ValueError, match="'batch_size' needs to be set"): - out, mask = to_dense_batch(x) + to_dense_batch(x) out, mask = to_dense_batch(x, batch_size=1, max_num_nodes=6) assert out.size() == (1, 6, 2) @@ -90,6 +90,29 @@ def test_to_dense_batch_disable_dynamic_shapes(): assert mask.size() == (3, 10) +def test_to_dense_batch_overflow(): + x = torch.tensor([ + [1.0, 2.0], + [3.0, 4.0], + [5.0, 6.0], + [7.0, 8.0], + [9.0, 10.0], + [11.0, 12.0], + ]) + batch = torch.tensor([0, 0, 1, 2, 2, 2]) + + expected = torch.tensor([ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [0.0, 0.0]], + [[7.0, 8.0], [9.0, 10.0]], + ]) + expected_mask = [[True, True], [True, False], [True, True]] + + out, mask = to_dense_batch(x, batch, max_num_nodes=2, batch_size=3) + assert torch.equal(out, expected) + assert mask.tolist() == expected_mask + + @onlyFullTest def test_to_dense_batch_jit(): @torch.jit.script diff --git a/torch_geometric/utils/_to_dense_batch.py b/torch_geometric/utils/_to_dense_batch.py index 7b6cc0552626..b38a119749ac 100644 --- a/torch_geometric/utils/_to_dense_batch.py +++ b/torch_geometric/utils/_to_dense_batch.py @@ -3,10 +3,7 @@ import torch from torch import Tensor -from torch_geometric.experimental import ( - disable_dynamic_shapes, - is_experimental_mode_enabled, -) +from torch_geometric.experimental import disable_dynamic_shapes from torch_geometric.utils import cumsum, scatter @@ -28,6 +25,11 @@ def to_dense_batch( N_{\max}}` is returned, holding information about the existence of fake-nodes in the dense representation. + .. note:: + When ``batch_size`` or ``max_num_nodes`` is not provided, this + function triggers a host-device synchronization to compute the value + from the input tensor ``batch``. + Args: x (Tensor): Node feature matrix :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`. @@ -107,30 +109,29 @@ def to_dense_batch( dim_size=batch_size, reduce='sum') cum_nodes = cumsum(num_nodes) - filter_nodes = False - dynamic_shapes_disabled = is_experimental_mode_enabled( - 'disable_dynamic_shapes') - if max_num_nodes is None: max_num_nodes = int(num_nodes.max()) - elif not dynamic_shapes_disabled and num_nodes.max() > max_num_nodes: - filter_nodes = True tmp = torch.arange(batch.size(0), device=x.device) - cum_nodes[batch] idx = tmp + (batch * max_num_nodes) - if filter_nodes: - mask = tmp < max_num_nodes - x, idx = x[mask], idx[mask] - size = [batch_size * max_num_nodes] + list(x.size())[1:] + # Redirect overflow rows (tmp >= max_num_nodes) to a "trash" slot at the + # end of the flat buffer. This avoids data-dependent boolean indexing. + valid = tmp < max_num_nodes + trash_idx = batch_size * max_num_nodes # index of the extra slot + idx = torch.where(valid, idx, trash_idx) + + flat_size = batch_size * max_num_nodes + 1 + size = [flat_size] + list(x.size())[1:] out = torch.as_tensor(fill_value, device=x.device, dtype=x.dtype) out = out.repeat(size) out[idx] = x + out = out[:batch_size * max_num_nodes] # drop the trash slot out = out.view([batch_size, max_num_nodes] + list(x.size())[1:]) - mask = torch.zeros(batch_size * max_num_nodes, dtype=torch.bool, - device=x.device) + mask = torch.zeros(flat_size, dtype=torch.bool, device=x.device) mask[idx] = 1 + mask = mask[:batch_size * max_num_nodes] mask = mask.view(batch_size, max_num_nodes) return out, mask From 9a8c5501d60e6101c126a7e20cc32ae2a7b32d15 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sat, 4 Apr 2026 21:34:33 +0000 Subject: [PATCH 2/2] update --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f4125aab5b53..86a9906cdd0d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). ### Changed +- Improved runtime of `to_dense_batch` in both eager and `torch.compile` ([#10660](https://github.com/pyg-team/pytorch_geometric/pull/10660)) - Dropped support for TorchScript in `GATConv` and `GATv2Conv` for correctness ([#10596](https://github.com/pyg-team/pytorch_geometric/pull/10596)) ### Deprecated