1010class EnEquivariantNetworkBlock (MessagePassing ):
1111 """
1212 Implementation of the E(n) Equivariant Graph Neural Network block.
13-
1413 This block is used to perform message-passing between nodes and edges in a
1514 graph neural network, following the scheme proposed by Satorras et al. in
1615 2021. It serves as an inner block in a larger graph neural network
1716 architecture.
18-
1917 The message between two nodes connected by an edge is computed by applying a
2018 linear transformation to the sender node features and the edge features,
2119 together with the squared euclidean distance between the sender and
2220 recipient node positions, followed by a non-linear activation function.
2321 Messages are then aggregated using an aggregation scheme (e.g., sum, mean,
2422 min, max, or product).
25-
2623 The update step is performed by applying another MLP to the concatenation of
2724 the incoming messages and the node features. Here, also the node
2825 positions are updated by adding the incoming messages divided by the
2926 degree of the recipient node.
30-
3127 .. seealso::
32-
3328 **Original reference** Satorras, V. G., Hoogeboom, E., Welling, M.
3429 (2021). *E(n) Equivariant Graph Neural Networks.*
3530 In International Conference on Machine Learning.
@@ -51,7 +46,6 @@ def __init__(
5146 ):
5247 """
5348 Initialization of the :class:`EnEquivariantNetworkBlock` class.
54-
5549 :param int node_feature_dim: The dimension of the node features.
5650 :param int edge_feature_dim: The dimension of the edge features.
5751 :param int pos_dim: The dimension of the position features.
@@ -102,14 +96,24 @@ def __init__(
10296 )
10397
10498 # Layer for updating the node features
105- self .update_net = FeedForward (
99+ self .update_feat_net = FeedForward (
106100 input_dimensions = node_feature_dim + pos_dim ,
107101 output_dimensions = node_feature_dim ,
108102 inner_size = hidden_dim ,
109103 n_layers = n_update_layers ,
110104 func = activation ,
111105 )
112106
107+ # Layer for updating the node positions
108+ # The output dimension is set to 1 for equivariant updates
109+ self .update_pos_net = FeedForward (
110+ input_dimensions = pos_dim ,
111+ output_dimensions = 1 ,
112+ inner_size = hidden_dim ,
113+ n_layers = n_update_layers ,
114+ func = activation ,
115+ )
116+
113117 def forward (self , x , pos , edge_index , edge_attr = None ):
114118 """
115119 Forward pass of the block, triggering the message-passing routine.
@@ -143,22 +147,62 @@ def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
143147 :param edge_attr: The edge attributes.
144148 :type edge_attr: torch.Tensor | LabelTensor
145149 :return: The message to be passed.
146- :rtype: torch.Tensor
150+ :rtype: tuple( torch.Tensor, torch.Tensor)
147151 """
148- dist = torch .norm (pos_i - pos_j , dim = - 1 , keepdim = True ) ** 2
152+ # Compute the euclidean distance between the sender and recipient nodes
153+ diff = pos_i - pos_j
154+ dist = torch .norm (diff , dim = - 1 , keepdim = True ) ** 2
155+
156+ # Compute the message input
149157 if edge_attr is None :
150158 input_ = torch .cat ((x_i , x_j , dist ), dim = - 1 )
151159 else :
152160 input_ = torch .cat ((x_i , x_j , dist , edge_attr ), dim = - 1 )
153161
154- return self .message_net (input_ )
162+ # Compute the messages and their equivariant counterpart
163+ m_ij = self .message_net (input_ )
164+ message = diff * self .update_pos_net (m_ij )
165+
166+ return message , m_ij
155167
156- def update (self , message , x , pos , edge_index ):
168+ def aggregate (self , inputs , index , ptr = None , dim_size = None ):
169+ """
170+ Aggregate the messages at the nodes during message passing.
171+
172+ This method receives a tuple of tensors corresponding to the messages
173+ to be aggregated. Both messages are aggregated separately according to
174+ the specified aggregation scheme.
175+
176+ :param tuple(torch.Tensor) inputs: Tuple containing two messages to
177+ aggregate.
178+ :param index: The indices of target nodes for each message. This tensor
179+ specifies which node each message is aggregated into.
180+ :type index: torch.Tensor | LabelTensor
181+ :param ptr: Optional tensor to specify the slices of messages for each
182+ node (used in some aggregation strategies). Default is None.
183+ :type ptr: torch.Tensor | LabelTensor
184+ :param int dim_size: Optional size of the output dimension, i.e.,
185+ number of nodes. Default is None.
186+ :return: Tuple of aggregated tensors corresponding to (aggregated
187+ messages for position updates, aggregated messages for feature
188+ updates).
189+ :rtype: tuple(torch.Tensor, torch.Tensor)
190+ """
191+ # Unpack the messages from the inputs
192+ message , m_ij = inputs
193+
194+ # Aggregate messages as usual using self.aggr method
195+ agg_message = super ().aggregate (message , index , ptr , dim_size )
196+ agg_m_ij = super ().aggregate (m_ij , index , ptr , dim_size )
197+
198+ return agg_message , agg_m_ij
199+
200+ def update (self , aggregated_inputs , x , pos , edge_index ):
157201 """
158202 Update the node features and the node coordinates with the received
159203 messages.
160204
161- :param torch.Tensor message : The message to be passed.
205+ :param tuple( torch.Tensor) aggregated_inputs : The messages to be passed.
162206 :param x: The node features.
163207 :type x: torch.Tensor | LabelTensor
164208 :param pos: The euclidean coordinates of the nodes.
@@ -167,10 +211,14 @@ def update(self, message, x, pos, edge_index):
167211 :return: The updated node features and node positions.
168212 :rtype: tuple(torch.Tensor, torch.Tensor)
169213 """
170- # Update the node features
171- x = self .update_net (torch .cat ((x , message ), dim = - 1 ))
214+ # aggregated_inputs is tuple (agg_message, agg_m_ij)
215+ agg_message , agg_m_ij = aggregated_inputs
216+
217+ # Update node features with aggregated m_ij
218+ x = self .update_feat_net (torch .cat ((x , agg_m_ij ), dim = - 1 ))
219+
220+ # Degree for normalization of position updates
221+ c = degree (edge_index [1 ], pos .shape [0 ]).unsqueeze (- 1 ).clamp (min = 1 )
222+ pos = pos + agg_message / c
172223
173- # Update the node positions
174- c = degree (edge_index [0 ], pos .shape [0 ]).unsqueeze (- 1 )
175- pos = pos + message / c
176224 return x , pos
0 commit comments