Skip to content

Commit d1f2094

Browse files
fix egnn + equivariance/invariance tests
Co-authored-by: Dario Coscia <dariocos99@gmail.com>
1 parent 2cd5a5e commit d1f2094

4 files changed

Lines changed: 147 additions & 17 deletions

File tree

pina/model/block/message_passing/en_equivariant_network_block.py

Lines changed: 65 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,26 +10,21 @@
1010
class EnEquivariantNetworkBlock(MessagePassing):
1111
"""
1212
Implementation of the E(n) Equivariant Graph Neural Network block.
13-
1413
This block is used to perform message-passing between nodes and edges in a
1514
graph neural network, following the scheme proposed by Satorras et al. in
1615
2021. It serves as an inner block in a larger graph neural network
1716
architecture.
18-
1917
The message between two nodes connected by an edge is computed by applying a
2018
linear transformation to the sender node features and the edge features,
2119
together with the squared euclidean distance between the sender and
2220
recipient node positions, followed by a non-linear activation function.
2321
Messages are then aggregated using an aggregation scheme (e.g., sum, mean,
2422
min, max, or product).
25-
2623
The update step is performed by applying another MLP to the concatenation of
2724
the incoming messages and the node features. Here, also the node
2825
positions are updated by adding the incoming messages divided by the
2926
degree of the recipient node.
30-
3127
.. seealso::
32-
3328
**Original reference** Satorras, V. G., Hoogeboom, E., Welling, M.
3429
(2021). *E(n) Equivariant Graph Neural Networks.*
3530
In International Conference on Machine Learning.
@@ -51,7 +46,6 @@ def __init__(
5146
):
5247
"""
5348
Initialization of the :class:`EnEquivariantNetworkBlock` class.
54-
5549
:param int node_feature_dim: The dimension of the node features.
5650
:param int edge_feature_dim: The dimension of the edge features.
5751
:param int pos_dim: The dimension of the position features.
@@ -102,14 +96,24 @@ def __init__(
10296
)
10397

10498
# Layer for updating the node features
105-
self.update_net = FeedForward(
99+
self.update_feat_net = FeedForward(
106100
input_dimensions=node_feature_dim + pos_dim,
107101
output_dimensions=node_feature_dim,
108102
inner_size=hidden_dim,
109103
n_layers=n_update_layers,
110104
func=activation,
111105
)
112106

107+
# Layer for updating the node positions
108+
# The output dimension is set to 1 for equivariant updates
109+
self.update_pos_net = FeedForward(
110+
input_dimensions=pos_dim,
111+
output_dimensions=1,
112+
inner_size=hidden_dim,
113+
n_layers=n_update_layers,
114+
func=activation,
115+
)
116+
113117
def forward(self, x, pos, edge_index, edge_attr=None):
114118
"""
115119
Forward pass of the block, triggering the message-passing routine.
@@ -143,22 +147,62 @@ def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
143147
:param edge_attr: The edge attributes.
144148
:type edge_attr: torch.Tensor | LabelTensor
145149
:return: The message to be passed.
146-
:rtype: torch.Tensor
150+
:rtype: tuple(torch.Tensor, torch.Tensor)
147151
"""
148-
dist = torch.norm(pos_i - pos_j, dim=-1, keepdim=True) ** 2
152+
# Compute the euclidean distance between the sender and recipient nodes
153+
diff = pos_i - pos_j
154+
dist = torch.norm(diff, dim=-1, keepdim=True) ** 2
155+
156+
# Compute the message input
149157
if edge_attr is None:
150158
input_ = torch.cat((x_i, x_j, dist), dim=-1)
151159
else:
152160
input_ = torch.cat((x_i, x_j, dist, edge_attr), dim=-1)
153161

154-
return self.message_net(input_)
162+
# Compute the messages and their equivariant counterpart
163+
m_ij = self.message_net(input_)
164+
message = diff * self.update_pos_net(m_ij)
165+
166+
return message, m_ij
155167

156-
def update(self, message, x, pos, edge_index):
168+
def aggregate(self, inputs, index, ptr=None, dim_size=None):
169+
"""
170+
Aggregate the messages at the nodes during message passing.
171+
172+
This method receives a tuple of tensors corresponding to the messages
173+
to be aggregated. Both messages are aggregated separately according to
174+
the specified aggregation scheme.
175+
176+
:param tuple(torch.Tensor) inputs: Tuple containing two messages to
177+
aggregate.
178+
:param index: The indices of target nodes for each message. This tensor
179+
specifies which node each message is aggregated into.
180+
:type index: torch.Tensor | LabelTensor
181+
:param ptr: Optional tensor to specify the slices of messages for each
182+
node (used in some aggregation strategies). Default is None.
183+
:type ptr: torch.Tensor | LabelTensor
184+
:param int dim_size: Optional size of the output dimension, i.e.,
185+
number of nodes. Default is None.
186+
:return: Tuple of aggregated tensors corresponding to (aggregated
187+
messages for position updates, aggregated messages for feature
188+
updates).
189+
:rtype: tuple(torch.Tensor, torch.Tensor)
190+
"""
191+
# Unpack the messages from the inputs
192+
message, m_ij = inputs
193+
194+
# Aggregate messages as usual using self.aggr method
195+
agg_message = super().aggregate(message, index, ptr, dim_size)
196+
agg_m_ij = super().aggregate(m_ij, index, ptr, dim_size)
197+
198+
return agg_message, agg_m_ij
199+
200+
def update(self, aggregated_inputs, x, pos, edge_index):
157201
"""
158202
Update the node features and the node coordinates with the received
159203
messages.
160204
161-
:param torch.Tensor message: The message to be passed.
205+
:param tuple(torch.Tensor) aggregated_inputs: The messages to be passed.
162206
:param x: The node features.
163207
:type x: torch.Tensor | LabelTensor
164208
:param pos: The euclidean coordinates of the nodes.
@@ -167,10 +211,14 @@ def update(self, message, x, pos, edge_index):
167211
:return: The updated node features and node positions.
168212
:rtype: tuple(torch.Tensor, torch.Tensor)
169213
"""
170-
# Update the node features
171-
x = self.update_net(torch.cat((x, message), dim=-1))
214+
# aggregated_inputs is tuple (agg_message, agg_m_ij)
215+
agg_message, agg_m_ij = aggregated_inputs
216+
217+
# Update node features with aggregated m_ij
218+
x = self.update_feat_net(torch.cat((x, agg_m_ij), dim=-1))
219+
220+
# Degree for normalization of position updates
221+
c = degree(edge_index[1], pos.shape[0]).unsqueeze(-1).clamp(min=1)
222+
pos = pos + agg_message / c
172223

173-
# Update the node positions
174-
c = degree(edge_index[0], pos.shape[0]).unsqueeze(-1)
175-
pos = pos + message / c
176224
return x, pos

tests/test_messagepassing/test_equivariant_network_block.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,38 @@ def test_backward(edge_feature_dim):
128128
loss.backward()
129129
assert x.grad.shape == x.shape
130130
assert pos.grad.shape == pos.shape
131+
132+
133+
def test_equivariance():
134+
135+
# Graph to be fully connected and undirected
136+
edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T
137+
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
138+
139+
# Random rotation (det(rotation) should be 1)
140+
rotation = torch.linalg.qr(torch.rand(pos.shape[-1], pos.shape[-1])).Q
141+
if torch.det(rotation) < 0:
142+
rotation[:, 0] *= -1
143+
144+
# Random translation
145+
translation = torch.rand(1, pos.shape[-1])
146+
147+
model = EnEquivariantNetworkBlock(
148+
node_feature_dim=x.shape[1],
149+
edge_feature_dim=0,
150+
pos_dim=pos.shape[1],
151+
hidden_dim=64,
152+
n_message_layers=2,
153+
n_update_layers=2,
154+
).eval()
155+
156+
h1, pos1 = model(edge_index=edge_index, x=x, pos=pos)
157+
h2, pos2 = model(
158+
edge_index=edge_index, x=x, pos=pos @ rotation.T + translation
159+
)
160+
161+
# Transform model output
162+
pos1_transformed = (pos1 @ rotation.T) + translation
163+
164+
assert torch.allclose(pos2, pos1_transformed, atol=1e-5)
165+
assert torch.allclose(h1, h2, atol=1e-5)

tests/test_messagepassing/test_radial_field_network_block.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,28 @@ def test_backward():
6565
loss = torch.mean(output_)
6666
loss.backward()
6767
assert x.grad.shape == x.shape
68+
69+
70+
def test_equivariance():
71+
72+
# Graph to be fully connected and undirected
73+
edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T
74+
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
75+
76+
# Random rotation (det(rotation) should be 1)
77+
rotation = torch.linalg.qr(torch.rand(x.shape[-1], x.shape[-1])).Q
78+
if torch.det(rotation) < 0:
79+
rotation[:, 0] *= -1
80+
81+
# Random translation
82+
translation = torch.rand(1, x.shape[-1])
83+
84+
model = RadialFieldNetworkBlock(node_feature_dim=x.shape[1]).eval()
85+
86+
pos1 = model(edge_index=edge_index, x=x)
87+
pos2 = model(edge_index=edge_index, x=x @ rotation.T + translation)
88+
89+
# Transform model output
90+
pos1_transformed = (pos1 @ rotation.T) + translation
91+
92+
assert torch.allclose(pos2, pos1_transformed, atol=1e-5)

tests/test_messagepassing/test_schnet_block.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,25 @@ def test_backward():
7171
loss = torch.mean(output_)
7272
loss.backward()
7373
assert x.grad.shape == x.shape
74+
75+
76+
def test_invariance():
77+
78+
# Graph to be fully connected and undirected
79+
edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T
80+
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
81+
82+
# Random rotation (det(rotation) should be 1)
83+
rotation = torch.linalg.qr(torch.rand(pos.shape[-1], pos.shape[-1])).Q
84+
if torch.det(rotation) < 0:
85+
rotation[:, 0] *= -1
86+
87+
# Random translation
88+
translation = torch.rand(1, pos.shape[-1])
89+
90+
model = SchnetBlock(node_feature_dim=x.shape[1]).eval()
91+
92+
out1 = model(edge_index=edge_index, x=x, pos=pos)
93+
out2 = model(edge_index=edge_index, x=x, pos=pos @ rotation.T + translation)
94+
95+
assert torch.allclose(out1, out2, atol=1e-5)

0 commit comments

Comments
 (0)