Skip to content

Commit b698f8f

Browse files
committed
fix scatter
1 parent dbc1569 commit b698f8f

File tree

4 files changed

+12
-73
lines changed

4 files changed

+12
-73
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
This repository contains the Cosmo neural network lift and convolution layers. For a usage example and reproduction of the results of the RECOMB 2026 submission "Gaining mechanistic insight from geometric deep learning on molecule structures through equivariant convolution", see https://github.com/BorgwardtLab/RECOMB2026Cosmo.
66

7-
Installation: `pip install cosmic-torch` or `pip install git+https://github.com/BorgwardtLab/Cosmo`
7+
Installation: `pip install cosmic-torch` or `pip install git+https://github.com/BorgwardtLab/Cosmo`. Make sure to before install [torch](https://pytorch.org/get-started/locally/) and [torch-scatter](https://pypi.org/project/torch-scatter/) according to their instructions.
88

99
### Cosmo
1010

cosmic/cosmo.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import torch
22
from torch import nn
3-
4-
from .utilities import scatter_mean, scatter_softmax, scatter_sum
3+
from torch_scatter import scatter_mean, scatter_softmax, scatter_sum
54

65
"""
76
Cosmo can be implemented with various filter functions. The underlying principle is always to compute the filter under transformation of a local reference frame (hood_coords) which is derived from neighboring input points. The forward signature of the layer is always the same and inputs can be obtained from a Lift2D or Lift3D module.
@@ -41,7 +40,7 @@ def forward(
4140
w = self.w[:, nn_idx] # use closest kernel point
4241
f = features[source]
4342
out_channels = torch.einsum("ni,oni->no", f, w) # m x out
44-
features = scatter_sum(out_channels, target, m)
43+
features = scatter_sum(out_channels, target, dim_size=m, dim=0)
4544
return features # Updated features of shape (m, out_channels)
4645

4746

@@ -94,7 +93,7 @@ def forward(
9493
)
9594
f = features[source]
9695
out_channels = torch.einsum("ni,noi->no", f, w) # m x out
97-
features = scatter_mean(out_channels, target, m)
96+
features = scatter_mean(out_channels, target, dim_size=m, dim=0)
9897
return features # Updated features of shape (m, out_channels)
9998

10099

@@ -136,6 +135,6 @@ def forward(
136135
w1 = self.w1(features)
137136
w2 = self.w2(features)
138137
w3 = self.w3(features)
139-
a = scatter_softmax(w1[target] - w2[source] + d, target, m)
140-
features = scatter_sum(a * (w3[source] + d), target, m)
138+
a = scatter_softmax(w1[target] - w2[source] + d, target, dim_size=m, dim=0)
139+
features = scatter_sum(a * (w3[source] + d), target, dim_size=m, dim=0)
141140
return features # Updated features of shape (m, out_channels)

cosmic/lift.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from types import SimpleNamespace
22

33
import torch
4+
from torch_scatter import scatter_max, scatter_mean, scatter_softmax, scatter_sum
45

56
from .utilities import *
67

@@ -123,12 +124,12 @@ def __init__(self, agg="mean"):
123124

124125
def __call__(self, features, index, size, return_index=False):
125126
if self.agg == "sum":
126-
return scatter_sum(features, index, size)
127+
return scatter_sum(features, index, dim_size=size, dim=0)
127128
elif self.agg == "mean":
128-
return scatter_mean(features, index, size)
129+
return scatter_mean(features, index, dim_size=size, dim=0)
129130
elif self.agg == "max":
130-
result = scatter_max(features, index, size)
131+
result = scatter_max(features, index, dim_size=size, dim=0)
131132
return result[0] if not return_index else result
132133
elif self.agg == "softmax":
133-
a = scatter_softmax(features, index, size)
134-
return scatter_sum(a * features, index, size)
134+
a = scatter_softmax(features, index, dim_size=size, dim=0)
135+
return scatter_sum(a * features, index, dim_size=size, dim=0)

cosmic/utilities.py

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -131,64 +131,3 @@ def filter_bases(bases, minimum_angle, coords, triangles):
131131
valid = angles_deg >= minimum_angle
132132
bases[~valid] = torch.nan
133133
return bases
134-
135-
136-
def scatter_sum(src, index, size):
137-
assert index.dtype == torch.long
138-
assert index.shape[0] == src.shape[0]
139-
out_shape = (size,) + src.shape[1:]
140-
out = torch.zeros(out_shape, dtype=src.dtype, device=src.device)
141-
out.index_add_(0, index, src)
142-
return out
143-
144-
145-
def scatter_mean(src, index, size):
146-
sum_out = scatter_sum(src, index, size)
147-
count = scatter_sum(torch.ones_like(src), index, size)
148-
return sum_out / count.clamp_min(1)
149-
150-
151-
def scatter_max(src, index, size):
152-
assert src.is_floating_point() and index.dtype == torch.long
153-
N = src.size(0)
154-
idx_exp = index.view(N, *([1] * (src.ndim - 1))).expand_as(src)
155-
vals = torch.full(
156-
(size,) + src.shape[1:],
157-
torch.finfo(src.dtype).min,
158-
dtype=src.dtype,
159-
device=src.device,
160-
)
161-
vals.scatter_reduce_(0, idx_exp, src, reduce="amax")
162-
pos = (
163-
torch.arange(N, device=src.device)
164-
.view(N, *([1] * (src.ndim - 1)))
165-
.expand_as(src)
166-
)
167-
pos_mask = torch.where(src == vals[index], pos, torch.full_like(pos, N))
168-
argmax = torch.full_like(vals, N, dtype=torch.long)
169-
argmax.scatter_reduce_(0, idx_exp, pos_mask, reduce="amin")
170-
argmax[argmax == N] = -1
171-
count = torch.zeros_like(vals, dtype=torch.long)
172-
count.scatter_reduce_(
173-
0, idx_exp, torch.ones_like(src, dtype=torch.long), reduce="sum"
174-
)
175-
vals[count == 0] = 0
176-
return vals, argmax
177-
178-
179-
def scatter_softmax(src, index, size, eps=1e-12):
180-
assert src.is_floating_point()
181-
assert index.dtype == torch.long
182-
N = src.shape[0]
183-
idx_exp = index.view(N, *([1] * (src.ndim - 1))).expand_as(src)
184-
max_vals = torch.full(
185-
(size,) + src.shape[1:],
186-
torch.finfo(src.dtype).min,
187-
dtype=src.dtype,
188-
device=src.device,
189-
)
190-
max_vals.scatter_reduce_(0, idx_exp, src, reduce="amax")
191-
ex = (src - max_vals[index]).exp()
192-
sum_per_group = torch.zeros_like(max_vals)
193-
sum_per_group.scatter_reduce_(0, idx_exp, ex, reduce="sum")
194-
return ex / sum_per_group[index].clamp_min(eps)

0 commit comments

Comments
 (0)