@@ -102,14 +102,24 @@ def __init__(
102102 )
103103
104104 # Layer for updating the node features
105- self .update_net = FeedForward (
105+ self .update_feat_net = FeedForward (
106106 input_dimensions = node_feature_dim + pos_dim ,
107107 output_dimensions = node_feature_dim ,
108108 inner_size = hidden_dim ,
109109 n_layers = n_update_layers ,
110110 func = activation ,
111111 )
112112
113+ # Layer for updating the node positions
114+ # The output dimension is set to 1 for equivariant updates
115+ self .update_pos_net = FeedForward (
116+ input_dimensions = pos_dim ,
117+ output_dimensions = 1 ,
118+ inner_size = hidden_dim ,
119+ n_layers = n_update_layers ,
120+ func = activation ,
121+ )
122+
113123 def forward (self , x , pos , edge_index , edge_attr = None ):
114124 """
115125 Forward pass of the block, triggering the message-passing routine.
@@ -145,13 +155,21 @@ def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
145155 :return: The message to be passed.
146156 :rtype: torch.Tensor
147157 """
148- dist = torch .norm (pos_i - pos_j , dim = - 1 , keepdim = True ) ** 2
158+ # Compute the euclidean distance between the sender and recipient nodes
159+ diff = pos_i - pos_j
160+ dist = torch .norm (diff , dim = - 1 , keepdim = True ) ** 2
161+
162+ # Compute the message input
149163 if edge_attr is None :
150164 input_ = torch .cat ((x_i , x_j , dist ), dim = - 1 )
151165 else :
152166 input_ = torch .cat ((x_i , x_j , dist , edge_attr ), dim = - 1 )
153167
154- return self .message_net (input_ )
168+ # Compute the messages and save them for feature update
169+ self ._m_ij = self .message_net (input_ )
170+
171+ # Rescale the message by the euclidean distance
172+ return diff * self .update_pos_net (self ._m_ij )
155173
156174 def update (self , message , x , pos , edge_index ):
157175 """
@@ -167,10 +185,14 @@ def update(self, message, x, pos, edge_index):
167185 :return: The updated node features and node positions.
168186 :rtype: tuple(torch.Tensor, torch.Tensor)
169187 """
188+ # Sum the incoming messages for each node (m_i = sum_j m_ij)
189+ m_sum = torch .zeros (x .size (0 ), self ._m_ij .shape [- 1 ], device = x .device )
190+ m_sum .index_add_ (0 , edge_index [1 ], self ._m_ij )
191+
170192 # Update the node features
171- x = self .update_net (torch .cat ((x , message ), dim = - 1 ))
193+ x = self .update_feat_net (torch .cat ((x , m_sum ), dim = - 1 ))
172194
173195 # Update the node positions
174- c = degree (edge_index [0 ], pos .shape [0 ]).unsqueeze (- 1 )
196+ c = degree (edge_index [1 ], pos .shape [0 ]).unsqueeze (- 1 )
175197 pos = pos + message / c
176198 return x , pos
0 commit comments