Skip to content

Commit f13a5f3

Browse files
AleDinveGiovanniCanali
authored andcommitted
add equivariant network block
1 parent cbf4e80 commit f13a5f3

3 files changed

Lines changed: 82 additions & 45 deletions

File tree

pina/model/block/message_passing/egnn_block.py

Lines changed: 79 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,91 +7,131 @@
77

88
class EnEquivariantGraphBlock(MessagePassing):
99
"""
10-
TODO
10+
Implementation of the E(n) Equivariant Graph Neural Network block.
11+
12+
This block is used to perform message-passing between nodes and edges in a
13+
graph neural network, following the scheme proposed by Satorras et al. (2021).
14+
It serves as an inner block in a larger graph neural network architecture.
15+
16+
The message between two nodes connected by an edge is computed by applying a
17+
linear transformation to the sender node features and the edge features,
18+
followed by a non-linear activation function. Messages are then aggregated
19+
using an aggregation scheme (e.g., sum, mean, min, max, or product).
20+
21+
The update step is performed by a simple addition of the incoming messages
22+
to the node features.
23+
24+
.. seealso::
25+
26+
**Original reference** Satorras, V. G., Hoogeboom, E., & Welling, M. (2021, July).
27+
E (n) equivariant graph neural networks.
28+
In International conference on machine learning (pp. 9323-9332). PMLR.
1129
"""
1230

1331
def __init__(
1432
self,
15-
channels_h,
33+
channels_x,
1634
channels_m,
1735
channels_a,
1836
aggr: str = "add",
1937
hidden_channels: int = 64,
2038
**kwargs,
2139
):
2240
"""
23-
TODO
41+
Initialization of the :class:`EnEquivariantGraphBlock` class.
42+
43+
:param int channels_x: The dimension of the node features.
44+
:param int channels_m: The dimension of the Euclidean coordinates (should be =3).
45+
:param int channels_a: The dimension of the edge features.
46+
:param str aggr: The aggregation scheme to use for message passing.
47+
Available options are "add", "mean", "min", "max", "mul".
48+
See :class:`torch_geometric.nn.MessagePassing` for more details.
49+
Default is "add".
50+
:param int hidden_channels_dim: The hidden dimension in each MLPs initialized in the block.
2451
"""
2552
super().__init__(aggr=aggr, **kwargs)
2653

2754
self.phi_e = torch.nn.Sequential(
28-
torch.nn.Linear(2 * channels_h + 1 + channels_a, hidden_channels),
55+
torch.nn.Linear(2 * channels_x + 1 + channels_a, hidden_channels),
2956
torch.nn.LayerNorm(hidden_channels),
3057
torch.nn.SiLU(),
3158
torch.nn.Linear(hidden_channels, channels_m),
3259
torch.nn.LayerNorm(channels_m),
3360
torch.nn.SiLU(),
3461
)
35-
self.phi_x = torch.nn.Sequential(
62+
self.phi_pos = torch.nn.Sequential(
3663
torch.nn.Linear(channels_m, hidden_channels),
3764
torch.nn.LayerNorm(hidden_channels),
3865
torch.nn.SiLU(),
3966
torch.nn.Linear(hidden_channels, 1),
4067
)
41-
self.phi_h = torch.nn.Sequential(
42-
torch.nn.Linear(channels_h + channels_m, hidden_channels),
68+
self.phi_x = torch.nn.Sequential(
69+
torch.nn.Linear(channels_x + channels_m, hidden_channels),
4370
torch.nn.LayerNorm(hidden_channels),
4471
torch.nn.SiLU(),
45-
torch.nn.Linear(hidden_channels, channels_h),
72+
torch.nn.Linear(hidden_channels, channels_x),
4673
)
4774

48-
def forward(self, x, h, edge_attr, edge_index, c=None):
75+
def forward(self, x, pos, edge_attr, edge_index, c=None):
4976
"""
50-
TODO
77+
Forward pass of the block, triggering the message-passing routine.
78+
79+
:param x: The node features.
80+
:type x: torch.Tensor | LabelTensor
81+
:param pos_i: 3D Euclidean coordinates.
82+
:type pos_i: torch.Tensor | LabelTensor
83+
:param torch.Tensor edge_index: The edge indices. In the original formulation,
84+
the messages are aggregated from all nodes, not only from the neighbours.
85+
:return: The updated node features.
86+
:rtype: torch.Tensor
5187
"""
5288
if c is None:
53-
c = degree(edge_index[0], x.shape[0]).unsqueeze(-1)
89+
c = degree(edge_index[0], pos.shape[0]).unsqueeze(-1)
5490
return self.propagate(
55-
edge_index=edge_index, x=x, h=h, edge_attr=edge_attr, c=c
91+
edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr, c=c
5692
)
5793

58-
def message(self, x_i, x_j, h_i, h_j, edge_attr):
94+
def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
5995
"""
60-
TODO
96+
Compute the message to be passed between nodes and edges.
97+
98+
:param x_i: Node features of the sender nodes.
99+
:type x_i: torch.Tensor | LabelTensor
100+
:param pos_i: 3D Euclidean coordinates of the sender nodes.
101+
:type pos_i: torch.Tensor | LabelTensor
102+
:param edge_attr: The edge attributes.
103+
:type edge_attr: torch.Tensor | LabelTensor
104+
:return: The message to be passed.
105+
:rtype: torch.Tensor
61106
"""
62-
mh_ij = self.phi_e(
107+
mpos_ij = self.phi_e(
63108
torch.cat(
64109
[
65-
h_i,
66-
h_j,
67-
torch.norm(x_i - x_j, dim=-1, keepdim=True) ** 2,
110+
x_i,
111+
x_j,
112+
torch.norm(pos_i - pos_j, dim=-1, keepdim=True) ** 2,
68113
edge_attr,
69114
],
70115
dim=-1,
71116
)
72117
)
73-
mx_ij = (x_i - x_j) * self.phi_x(mh_ij)
74-
return torch.cat((mx_ij, mh_ij), dim=-1)
118+
mpos_ij = (pos_i - pos_j) * self.phi_pos(mpos_ij)
119+
return mpos_ij
75120

76-
def update(self, aggr_out, x, h, edge_attr, c):
121+
def update(self, message, x, pos, c):
77122
"""
78-
TODO
79-
"""
80-
m_x, m_h = aggr_out[:, : self.m_len], aggr_out[:, self.m_len :]
81-
h_l1 = self.phi_h(torch.cat([h, m_h], dim=-1))
82-
x_l1 = x + (m_x / c)
83-
return x_l1, h_l1
123+
Update the node features with the received messages.
84124
85-
@property
86-
def edge_function(self):
87-
"""
88-
TODO
89-
"""
90-
return self._edge_function
91-
92-
@property
93-
def attribute_function(self):
94-
"""
95-
TODO
125+
:param torch.Tensor message: The message to be passed.
126+
:param x: The node features.
127+
:type x: torch.Tensor | LabelTensor
128+
:param pos: The 3D Euclidean coordinates of the nodes.
129+
:type pos: torch.Tensor | LabelTensor
130+
:param c: the constant that divides the aggregated message (it should be (M-1), where M is the number of nodes)
131+
:type pos: torch.Tensor
132+
:return: The concatenation of the update position features and the updated node features.
133+
:rtype: torch.Tensor
96134
"""
97-
return self._attribute_function
135+
x = self.phi_x(torch.cat([x, message], dim=-1))
136+
pos = pos + (message / c)
137+
return pos, x

pina/model/block/message_passing/radial_field_network_block.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def __init__(
6262
source node. See :class:`torch_geometric.nn.MessagePassing` for more
6363
details. Default is "source_to_target".
6464
:raises ValueError: If `node_feature_dim` is not a positive integer.
65-
:raises ValueError: If `edge_feature_dim` is not a positive integer.
6665
"""
6766
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
6867

@@ -108,8 +107,7 @@ def message(self, x_j, x_i):
108107
"""
109108
Compute the message to be passed between nodes and edges.
110109
111-
:param x_j: Concatenation of the node position and the
112-
node features of the sender nodes.
110+
:param x_j: Node features of the sender nodes.
113111
:type x_j: torch.Tensor | LabelTensor
114112
:param edge_attr: The edge attributes.
115113
:type edge_attr: torch.Tensor | LabelTensor

pina/model/block/message_passing/schnet_block.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(
4646
flow="source_to_target",
4747
):
4848
"""
49-
Initialization of the :class:`RadialFieldNetworkBlock` class.
49+
Initialization of the :class:`SchnetBlock` class.
5050
5151
:param int node_feature_dim: The dimension of the node features.
5252
:param int edge_feature_dim: The dimension of the edge features.
@@ -129,8 +129,7 @@ def message(self, x_i, pos_i ,pos_j):
129129
"""
130130
Compute the message to be passed between nodes and edges.
131131
132-
:param x_j: Concatenation of the node position and the
133-
node features of the sender nodes.
132+
:param x_j: Node features of the sender nodes.
134133
:type x_j: torch.Tensor | LabelTensor
135134
:param edge_attr: The edge attributes.
136135
:type edge_attr: torch.Tensor | LabelTensor

0 commit comments

Comments
 (0)