Skip to content

Commit 79a29e5

Browse files
fix egnn + equivariance test
1 parent 6e5e3dd commit 79a29e5

2 files changed

Lines changed: 61 additions & 5 deletions

File tree

pina/model/block/message_passing/en_equivariant_network_block.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,24 @@ def __init__(
102102
)
103103

104104
# Layer for updating the node features
105-
self.update_net = FeedForward(
105+
self.update_feat_net = FeedForward(
106106
input_dimensions=node_feature_dim + pos_dim,
107107
output_dimensions=node_feature_dim,
108108
inner_size=hidden_dim,
109109
n_layers=n_update_layers,
110110
func=activation,
111111
)
112112

113+
# Layer for updating the node positions
114+
# The output dimension is set to 1 for equivariant updates
115+
self.update_pos_net = FeedForward(
116+
input_dimensions=pos_dim,
117+
output_dimensions=1,
118+
inner_size=hidden_dim,
119+
n_layers=n_update_layers,
120+
func=activation,
121+
)
122+
113123
def forward(self, x, pos, edge_index, edge_attr=None):
114124
"""
115125
Forward pass of the block, triggering the message-passing routine.
@@ -145,13 +155,21 @@ def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
145155
:return: The message to be passed.
146156
:rtype: torch.Tensor
147157
"""
148-
dist = torch.norm(pos_i - pos_j, dim=-1, keepdim=True) ** 2
158+
# Compute the euclidean distance between the sender and recipient nodes
159+
diff = pos_i - pos_j
160+
dist = torch.norm(diff, dim=-1, keepdim=True) ** 2
161+
162+
# Compute the message input
149163
if edge_attr is None:
150164
input_ = torch.cat((x_i, x_j, dist), dim=-1)
151165
else:
152166
input_ = torch.cat((x_i, x_j, dist, edge_attr), dim=-1)
153167

154-
return self.message_net(input_)
168+
# Compute the messages and save them for feature update
169+
self._m_ij = self.message_net(input_)
170+
171+
# Rescale the message by the euclidean distance
172+
return diff * self.update_pos_net(self._m_ij)
155173

156174
def update(self, message, x, pos, edge_index):
157175
"""
@@ -167,10 +185,14 @@ def update(self, message, x, pos, edge_index):
167185
:return: The updated node features and node positions.
168186
:rtype: tuple(torch.Tensor, torch.Tensor)
169187
"""
188+
# Sum the incoming messages for each node (m_i = sum_j m_ij)
189+
m_sum = torch.zeros(x.size(0), self._m_ij.shape[-1], device=x.device)
190+
m_sum.index_add_(0, edge_index[1], self._m_ij)
191+
170192
# Update the node features
171-
x = self.update_net(torch.cat((x, message), dim=-1))
193+
x = self.update_feat_net(torch.cat((x, m_sum), dim=-1))
172194

173195
# Update the node positions
174-
c = degree(edge_index[0], pos.shape[0]).unsqueeze(-1)
196+
c = degree(edge_index[1], pos.shape[0]).unsqueeze(-1)
175197
pos = pos + message / c
176198
return x, pos

tests/test_messagepassing/test_equivariant_network_block.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,37 @@ def test_backward(edge_feature_dim):
128128
loss.backward()
129129
assert x.grad.shape == x.shape
130130
assert pos.grad.shape == pos.shape
131+
132+
133+
def test_equivariance():
134+
135+
# Graph to be fully connected and undirected
136+
edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T
137+
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
138+
139+
# Random rotation (det(rotation) should be 1)
140+
rotation = torch.linalg.qr(torch.rand(pos.shape[-1], pos.shape[-1])).Q
141+
if torch.det(rotation) < 0:
142+
rotation[:, 0] *= -1
143+
144+
# Random translation
145+
translation = torch.rand(1, pos.shape[-1])
146+
147+
model = EnEquivariantNetworkBlock(
148+
node_feature_dim=x.shape[1],
149+
edge_feature_dim=0,
150+
pos_dim=pos.shape[1],
151+
hidden_dim=64,
152+
n_message_layers=2,
153+
n_update_layers=2,
154+
).eval()
155+
156+
_, pos1 = model(edge_index=edge_index, x=x, pos=pos)
157+
_, pos2 = model(
158+
edge_index=edge_index, x=x, pos=pos @ rotation.T + translation
159+
)
160+
161+
# Transform model output
162+
pos1_transformed = (pos1 @ rotation.T) + translation
163+
164+
assert torch.allclose(pos2, pos1_transformed, atol=1e-5)

0 commit comments

Comments
 (0)