Skip to content

Commit a30c48f

Browse files
committed
remove torch-scatter dependency
1 parent 3ca2730 commit a30c48f

File tree

4 files changed

+69
-16
lines changed

4 files changed

+69
-16
lines changed

cosmic/cosmo.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
from torch import nn
3-
from torch_scatter import scatter_add, scatter_mean, scatter_softmax
3+
4+
from .utilities import scatter_add, scatter_mean, scatter_softmax
45

56
"""
67
Cosmo can be implemented with various filter functions. The underlying principle is always to compute the filter under transformation of a local reference frame (hood_coords) which is derived from neighboring input points. The forward signature of the layer is always the same and inputs can be obtained from a Lift2D or Lift3D module.

cosmic/lift.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from types import SimpleNamespace
22

33
import torch
4-
from torch_scatter import scatter_max, scatter_mean, scatter_softmax, scatter_sum
54

65
from .utilities import *
76

@@ -124,15 +123,12 @@ def __init__(self, agg="mean"):
124123

125124
def __call__(self, features, index, size, return_index=False):
126125
if self.agg == "sum":
127-
return scatter_sum(features, index, dim_size=size, dim=0)
126+
return scatter_sum(features, index, size)
128127
elif self.agg == "mean":
129-
return scatter_mean(features, index, dim_size=size, dim=0)
128+
return scatter_mean(features, index, size)
130129
elif self.agg == "max":
131-
val, idx = scatter_max(features, index, dim_size=size, dim=0)
132-
if return_index:
133-
return val, idx
134-
else:
135-
return val
130+
result = scatter_max(features, index, size)
131+
return result[0] if not return_index else result
136132
elif self.agg == "softmax":
137-
a = scatter_softmax(features, index, dim_size=size, dim=0)
138-
return scatter_sum(a * features, index, dim_size=size, dim=0)
133+
a = scatter_softmax(features, index, size)
134+
return scatter_sum(a * features, index, size)

cosmic/utilities.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,64 @@ def filter_bases(bases, minimum_angle, coords, triangles):
131131
valid = angles_deg >= minimum_angle
132132
bases[~valid] = torch.nan
133133
return bases
134+
135+
136+
def scatter_sum(src, index, size):
137+
assert index.dtype == torch.long
138+
assert index.shape[0] == src.shape[0]
139+
out_shape = (size,) + src.shape[1:]
140+
out = torch.zeros(out_shape, dtype=src.dtype, device=src.device)
141+
out.index_add_(0, index, src)
142+
return out
143+
144+
145+
def scatter_mean(src, index, size):
146+
sum_out = scatter_sum(src, index, size)
147+
count = scatter_sum(torch.ones_like(src), index, size)
148+
return sum_out / count.clamp_min(1)
149+
150+
151+
def scatter_max(src, index, size):
152+
assert src.is_floating_point() and index.dtype == torch.long
153+
N = src.size(0)
154+
idx_exp = index.view(N, *([1] * (src.ndim - 1))).expand_as(src)
155+
vals = torch.full(
156+
(size,) + src.shape[1:],
157+
torch.finfo(src.dtype).min,
158+
dtype=src.dtype,
159+
device=src.device,
160+
)
161+
vals.scatter_reduce_(0, idx_exp, src, reduce="amax")
162+
pos = (
163+
torch.arange(N, device=src.device)
164+
.view(N, *([1] * (src.ndim - 1)))
165+
.expand_as(src)
166+
)
167+
pos_mask = torch.where(src == vals[index], pos, torch.full_like(pos, N))
168+
argmax = torch.full_like(vals, N, dtype=torch.long)
169+
argmax.scatter_reduce_(0, idx_exp, pos_mask, reduce="amin")
170+
argmax[argmax == N] = -1
171+
count = torch.zeros_like(vals, dtype=torch.long)
172+
count.scatter_reduce_(
173+
0, idx_exp, torch.ones_like(src, dtype=torch.long), reduce="sum"
174+
)
175+
vals[count == 0] = 0
176+
return vals, argmax
177+
178+
179+
def scatter_softmax(src, index, size, eps=1e-12):
180+
assert src.is_floating_point()
181+
assert index.dtype == torch.long
182+
N = src.shape[0]
183+
idx_exp = index.view(N, *([1] * (src.ndim - 1))).expand_as(src)
184+
max_vals = torch.full(
185+
(size,) + src.shape[1:],
186+
torch.finfo(src.dtype).min,
187+
dtype=src.dtype,
188+
device=src.device,
189+
)
190+
max_vals.scatter_reduce_(0, idx_exp, src, reduce="amax")
191+
ex = (src - max_vals[index]).exp()
192+
sum_per_group = torch.zeros_like(max_vals)
193+
sum_per_group.scatter_reduce_(0, idx_exp, ex, reduce="sum")
194+
return ex / sum_per_group[index].clamp_min(eps)

pyproject.toml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,3 @@ include = ["cosmic", "cosmic.*"]
3030

3131
[tool.setuptools_scm]
3232

33-
[project.optional-dependencies]
34-
scatter = [
35-
"torch_scatter>=2.1",
36-
]
37-

0 commit comments

Comments
 (0)