From 5a4fc440e65811c0f42b9b3b8793e25600025f5d Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Fri, 21 Mar 2025 10:34:52 +0100 Subject: [PATCH 1/8] add buggy egnn block --- .../model/block/message_passing/egnn_block.py | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 pina/model/block/message_passing/egnn_block.py diff --git a/pina/model/block/message_passing/egnn_block.py b/pina/model/block/message_passing/egnn_block.py new file mode 100644 index 000000000..8154aeb8e --- /dev/null +++ b/pina/model/block/message_passing/egnn_block.py @@ -0,0 +1,61 @@ +import torch +import torch.nn as nn +from torch_geometric.nn import MessagePassing +from torch_geometric.utils import degree +from ....utils import check_consistency + + +class EnEquivariantGraphBlock(MessagePassing): + def __init__(self, + channels_h, + channels_m, + channels_a, + aggr: str = 'add', + hidden_channels: int = 64, + **kwargs): + super().__init__(aggr=aggr, **kwargs) + + self.phi_e = nn.Sequential( + nn.Linear(2 * channels_h + 1 + channels_a, hidden_channels), + nn.LayerNorm(hidden_channels), + nn.SiLU(), + nn.Linear(hidden_channels, channels_m), + nn.LayerNorm(channels_m), + nn.SiLU() + ) + self.phi_x = nn.Sequential( + nn.Linear(channels_m, hidden_channels), + nn.LayerNorm(hidden_channels), + nn.SiLU(), + nn.Linear(hidden_channels, 1), + ) + self.phi_h = nn.Sequential( + nn.Linear(channels_h + channels_m, hidden_channels), + nn.LayerNorm(hidden_channels), + nn.SiLU(), + nn.Linear(hidden_channels, channels_h), + ) + + def forward(self, x, h, edge_attr, edge_index, c=None): + if c is None: + c = degree(edge_index[0], x.shape[0]).unsqueeze(-1) + return self.propagate(edge_index=edge_index, x=x, h=h, edge_attr=edge_attr, c=c) + + def message(self, x_i, x_j, h_i, h_j, edge_attr): + mh_ij = self.phi_e(torch.cat([h_i, h_j, torch.norm(x_i - x_j, dim=-1, keepdim=True)**2, edge_attr], dim=-1)) + mx_ij = (x_i - x_j) * self.phi_x(mh_ij) + return torch.cat((mx_ij, mh_ij), dim=-1) + + def update(self, aggr_out, x, h, edge_attr, c): + m_x, m_h = aggr_out[:, :self.m_len], aggr_out[:, self.m_len:] + h_l1 = self.phi_h(torch.cat([h, m_h], dim=-1)) + x_l1 = x + (m_x / c) + return x_l1, h_l1 + + @property + def edge_function(self): + return self._edge_function + + @property + def attribute_function(self): + return self._attribute_function From a7c8c35b7254059ad40fbed6e5c5307c1c2888ef Mon Sep 17 00:00:00 2001 From: giovanni Date: Wed, 9 Apr 2025 15:10:40 +0200 Subject: [PATCH 2/8] add deep tensor network block --- pina/model/block/message_passing/__init__.py | 9 ++ .../deep_tensor_network_block.py | 128 ++++++++++++++++++ .../model/block/message_passing/egnn_block.py | 96 +++++++++---- .../interaction_network_block.py | 10 ++ 4 files changed, 213 insertions(+), 30 deletions(-) create mode 100644 pina/model/block/message_passing/__init__.py create mode 100644 pina/model/block/message_passing/deep_tensor_network_block.py create mode 100644 pina/model/block/message_passing/interaction_network_block.py diff --git a/pina/model/block/message_passing/__init__.py b/pina/model/block/message_passing/__init__.py new file mode 100644 index 000000000..a4b122016 --- /dev/null +++ b/pina/model/block/message_passing/__init__.py @@ -0,0 +1,9 @@ +"""Module for the message passing blocks of the graph neural models.""" + +__all__ = [ + "InteractionNetworkBlock", + "DeepTensorNetworkBlock", +] + +from .interaction_network_block import InteractionNetworkBlock +from .deep_tensor_network_block import DeepTensorNetworkBlock diff --git a/pina/model/block/message_passing/deep_tensor_network_block.py b/pina/model/block/message_passing/deep_tensor_network_block.py new file mode 100644 index 000000000..fe48d8e13 --- /dev/null +++ b/pina/model/block/message_passing/deep_tensor_network_block.py @@ -0,0 +1,128 @@ +"""Module for the Deep Tensor Network block.""" + +import torch +from torch_geometric.nn import MessagePassing + + +class DeepTensorNetworkBlock(MessagePassing): + """ + Implementation of the Deep Tensor Network block. + + This block is used to perform message-passing between nodes and edges in a + graph neural network, following the scheme proposed by Schutt et al. (2017). + It serves as an inner block in a larger graph neural network architecture. + + The message between two nodes connected by an edge is computed by applying a + linear transformation to the sender node features and the edge features, + followed by a non-linear activation function. Messages are then aggregated + using an aggregation scheme (e.g., sum, mean, min, max, or product). + + The update step is performed by a simple addition of the incoming messages + to the node features. + + .. seealso:: + + **Original reference**: Schutt, K., Arbabzadah, F., Chmiela, S. et al. + *Quantum-Chemical Insights from Deep Tensor Neural Networks*. + Nature Communications 8, 13890 (2017). + DOI: `_` + """ + + def __init__( + self, + node_feature_dim, + edge_feature_dim, + activation=torch.nn.Tanh, + aggr="add", + node_dim=-2, + flow="source_to_target", + ): + """ + Initialization of the :class:`AVNOBDeepTensorNetworkBlocklock` class. + + :param int node_feature_dim: The dimension of the node features. + :param int edge_feature_dim: The dimension of the edge features. + :param torch.nn.Module activation: The activation function. + Default is :class:`torch.nn.Tanh`. + :param str aggr: The aggregation scheme to use for message passing. + Available options are "add", "mean", "min", "max", "mul". + See :class:`torch_geometric.nn.MessagePassing` for more details. + Default is "add". + :param int node_dim: The axis along which to propagate. Default is -2. + :param str flow: The direction of message passing. + See :class:`torch_geometric.nn.MessagePassing` for more details. + Default is "source_to_target". + """ + super().__init__(aggr=aggr, node_dim=node_dim, flow=flow) + + self.node_feature_dim = node_feature_dim + self.edge_feature_dim = edge_feature_dim + self.activation = activation + + # Layer for processing node features + self.node_layer = torch.nn.Linear( + in_features=self.node_feature_dim, + out_features=self.node_feature_dim, + bias=True, + ) + + # Layer for processing edge features + self.edge_layer = torch.nn.Linear( + in_features=self.edge_feature_dim, + out_features=self.node_feature_dim, + bias=True, + ) + + # Layer for computing the message + self.message_layer = torch.nn.Linear( + in_features=self.node_feature_dim, + out_features=self.node_feature_dim, + bias=False, + ) + + def forward(self, x, edge_index, edge_attr): + """ + Forward pass of the block. It performs a message-passing operation + between nodes and edges. + + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :param torch.Tensor edge_index: The edge indeces. + :param edge_attr: The edge attributes. + :type edge_attr: torch.Tensor | LabelTensor + :return: The updated node features. + :rtype: torch.Tensor + """ + return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr) + + def message(self, x_j, edge_attr): + """ + Compute the message to be passed between nodes and edges. + + :param x_j: The node features of the sender nodes. + :type x_j: torch.Tensor | LabelTensor + :param edge_attr: The edge attributes. + :type edge_attr: torch.Tensor | LabelTensor + :return: The message to be passed. + :rtype: torch.Tensor + """ + # Process node and edge features + filter_node = self.node_layer(x_j) + filter_edge = self.edge_layer(edge_attr) + + # Compute the message to be passed + message = self.message_layer(filter_node * filter_edge) + + return self.activation(message) + + def update(self, message, x): + """ + Update the node features with the received messages. + + :param torch.Tensor message: The message to be passed. + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :return: The updated node features. + :rtype: torch.Tensor + """ + return x + message diff --git a/pina/model/block/message_passing/egnn_block.py b/pina/model/block/message_passing/egnn_block.py index 8154aeb8e..7c137ac0e 100644 --- a/pina/model/block/message_passing/egnn_block.py +++ b/pina/model/block/message_passing/egnn_block.py @@ -1,61 +1,97 @@ +"""Module for the E(n) Equivariant Graph Neural Network block.""" + import torch -import torch.nn as nn from torch_geometric.nn import MessagePassing from torch_geometric.utils import degree -from ....utils import check_consistency class EnEquivariantGraphBlock(MessagePassing): - def __init__(self, - channels_h, - channels_m, - channels_a, - aggr: str = 'add', - hidden_channels: int = 64, - **kwargs): + """ + TODO + """ + + def __init__( + self, + channels_h, + channels_m, + channels_a, + aggr: str = "add", + hidden_channels: int = 64, + **kwargs, + ): + """ + TODO + """ super().__init__(aggr=aggr, **kwargs) - self.phi_e = nn.Sequential( - nn.Linear(2 * channels_h + 1 + channels_a, hidden_channels), - nn.LayerNorm(hidden_channels), - nn.SiLU(), - nn.Linear(hidden_channels, channels_m), - nn.LayerNorm(channels_m), - nn.SiLU() + self.phi_e = torch.nn.Sequential( + torch.nn.Linear(2 * channels_h + 1 + channels_a, hidden_channels), + torch.nn.LayerNorm(hidden_channels), + torch.nn.SiLU(), + torch.nn.Linear(hidden_channels, channels_m), + torch.nn.LayerNorm(channels_m), + torch.nn.SiLU(), ) - self.phi_x = nn.Sequential( - nn.Linear(channels_m, hidden_channels), - nn.LayerNorm(hidden_channels), - nn.SiLU(), - nn.Linear(hidden_channels, 1), + self.phi_x = torch.nn.Sequential( + torch.nn.Linear(channels_m, hidden_channels), + torch.nn.LayerNorm(hidden_channels), + torch.nn.SiLU(), + torch.nn.Linear(hidden_channels, 1), + ) + self.phi_h = torch.nn.Sequential( + torch.nn.Linear(channels_h + channels_m, hidden_channels), + torch.nn.LayerNorm(hidden_channels), + torch.nn.SiLU(), + torch.nn.Linear(hidden_channels, channels_h), ) - self.phi_h = nn.Sequential( - nn.Linear(channels_h + channels_m, hidden_channels), - nn.LayerNorm(hidden_channels), - nn.SiLU(), - nn.Linear(hidden_channels, channels_h), - ) def forward(self, x, h, edge_attr, edge_index, c=None): + """ + TODO + """ if c is None: c = degree(edge_index[0], x.shape[0]).unsqueeze(-1) - return self.propagate(edge_index=edge_index, x=x, h=h, edge_attr=edge_attr, c=c) + return self.propagate( + edge_index=edge_index, x=x, h=h, edge_attr=edge_attr, c=c + ) def message(self, x_i, x_j, h_i, h_j, edge_attr): - mh_ij = self.phi_e(torch.cat([h_i, h_j, torch.norm(x_i - x_j, dim=-1, keepdim=True)**2, edge_attr], dim=-1)) + """ + TODO + """ + mh_ij = self.phi_e( + torch.cat( + [ + h_i, + h_j, + torch.norm(x_i - x_j, dim=-1, keepdim=True) ** 2, + edge_attr, + ], + dim=-1, + ) + ) mx_ij = (x_i - x_j) * self.phi_x(mh_ij) return torch.cat((mx_ij, mh_ij), dim=-1) def update(self, aggr_out, x, h, edge_attr, c): - m_x, m_h = aggr_out[:, :self.m_len], aggr_out[:, self.m_len:] + """ + TODO + """ + m_x, m_h = aggr_out[:, : self.m_len], aggr_out[:, self.m_len :] h_l1 = self.phi_h(torch.cat([h, m_h], dim=-1)) x_l1 = x + (m_x / c) return x_l1, h_l1 @property def edge_function(self): + """ + TODO + """ return self._edge_function @property def attribute_function(self): + """ + TODO + """ return self._attribute_function diff --git a/pina/model/block/message_passing/interaction_network_block.py b/pina/model/block/message_passing/interaction_network_block.py new file mode 100644 index 000000000..44ecccb27 --- /dev/null +++ b/pina/model/block/message_passing/interaction_network_block.py @@ -0,0 +1,10 @@ +"""Module for the Interaction Network block.""" + +import torch +from torch_geometric.nn import MessagePassing + + +class InteractionNetworkBlock(MessagePassing): + """ + TODO + """ From 9a098218e2ff503d870bd85581d3302596ce5843 Mon Sep 17 00:00:00 2001 From: giovanni Date: Wed, 9 Apr 2025 17:22:29 +0200 Subject: [PATCH 3/8] add interaction network block --- .../deep_tensor_network_block.py | 38 ++++- .../interaction_network_block.py | 160 +++++++++++++++++- 2 files changed, 190 insertions(+), 8 deletions(-) diff --git a/pina/model/block/message_passing/deep_tensor_network_block.py b/pina/model/block/message_passing/deep_tensor_network_block.py index fe48d8e13..950e32f05 100644 --- a/pina/model/block/message_passing/deep_tensor_network_block.py +++ b/pina/model/block/message_passing/deep_tensor_network_block.py @@ -2,6 +2,7 @@ import torch from torch_geometric.nn import MessagePassing +from ....utils import check_consistency class DeepTensorNetworkBlock(MessagePassing): @@ -25,7 +26,7 @@ class DeepTensorNetworkBlock(MessagePassing): **Original reference**: Schutt, K., Arbabzadah, F., Chmiela, S. et al. *Quantum-Chemical Insights from Deep Tensor Neural Networks*. Nature Communications 8, 13890 (2017). - DOI: `_` + DOI: `_`. """ def __init__( @@ -38,7 +39,7 @@ def __init__( flow="source_to_target", ): """ - Initialization of the :class:`AVNOBDeepTensorNetworkBlocklock` class. + Initialization of the :class:`DeepTensorNetworkBlocklock` class. :param int node_feature_dim: The dimension of the node features. :param int edge_feature_dim: The dimension of the edge features. @@ -49,12 +50,36 @@ def __init__( See :class:`torch_geometric.nn.MessagePassing` for more details. Default is "add". :param int node_dim: The axis along which to propagate. Default is -2. - :param str flow: The direction of message passing. - See :class:`torch_geometric.nn.MessagePassing` for more details. - Default is "source_to_target". + :param str flow: The direction of message passing. Available options + are "source_to_target" and "target_to_source". + The "source_to_target" flow means that messages are sent from + the source node to the target node, while the "target_to_source" + flow means that messages are sent from the target node to the + source node. See :class:`torch_geometric.nn.MessagePassing` for more + details. Default is "source_to_target". + :raises ValueError: If `node_feature_dim` is not a positive integer. + :raises ValueError: If `edge_feature_dim` is not a positive integer. """ super().__init__(aggr=aggr, node_dim=node_dim, flow=flow) + # Check consistency + check_consistency(node_feature_dim, int) + check_consistency(edge_feature_dim, int) + + # Check values + if node_feature_dim <= 0: + raise ValueError( + "`node_feature_dim` must be a positive integer," + f" got {node_feature_dim}." + ) + + if edge_feature_dim <= 0: + raise ValueError( + "`edge_feature_dim` must be a positive integer," + f" got {edge_feature_dim}." + ) + + # Initialize parameters self.node_feature_dim = node_feature_dim self.edge_feature_dim = edge_feature_dim self.activation = activation @@ -82,8 +107,7 @@ def __init__( def forward(self, x, edge_index, edge_attr): """ - Forward pass of the block. It performs a message-passing operation - between nodes and edges. + Forward pass of the block, triggering the message-passing routine. :param x: The node features. :type x: torch.Tensor | LabelTensor diff --git a/pina/model/block/message_passing/interaction_network_block.py b/pina/model/block/message_passing/interaction_network_block.py index 44ecccb27..f27169448 100644 --- a/pina/model/block/message_passing/interaction_network_block.py +++ b/pina/model/block/message_passing/interaction_network_block.py @@ -2,9 +2,167 @@ import torch from torch_geometric.nn import MessagePassing +from ....model import FeedForward +from ....utils import check_consistency class InteractionNetworkBlock(MessagePassing): """ - TODO + Implementation of the Interaction Network block. + + This block is used to perform message-passing between nodes and edges in a + graph neural network, following the scheme proposed by Battaglia et al. + (2016). + It serves as an inner block in a larger graph neural network architecture. + + The message between two nodes connected by an edge is computed by applying a + multi-layer perceptron (MLP) to the concatenation of the sender and + recipient node features. Messages are then aggregated using an aggregation + scheme (e.g., sum, mean, min, max, or product). + + The update step is performed by applying another MLP to the concatenation of + the incoming messages and the node features. + + .. seealso:: + + **Original reference**: Battaglia, P. W., et al. (2016). + *Interaction Networks for Learning about Objects, Relations and + Physics*. + In Advances in Neural Information Processing Systems (NeurIPS 2016). + DOI: `_`. """ + + def __init__( + self, + node_feature_dim, + hidden_dim, + n_message_layers=2, + n_update_layers=2, + activation=torch.nn.SiLU, + aggr="add", + node_dim=-2, + flow="source_to_target", + ): + """ + Initialization of the :class:`InteractionNetworkBlock` class. + + :param int node_feature_dim: The dimension of the node features. + :param int hidden_dim: The dimension of the hidden features. + :param int n_message_layers: The number of layers in the message + network. Default is 2. + :param int n_update_layers: The number of layers in the update network. + Default is 2. + :param torch.nn.Module activation: The activation function. + Default is :class:`torch.nn.SiLU`. + :param str aggr: The aggregation scheme to use for message passing. + Available options are "add", "mean", "min", "max", "mul". + See :class:`torch_geometric.nn.MessagePassing` for more details. + Default is "add". + :param int node_dim: The axis along which to propagate. Default is -2. + :param str flow: The direction of message passing. Available options + are "source_to_target" and "target_to_source". + The "source_to_target" flow means that messages are sent from + the source node to the target node, while the "target_to_source" + flow means that messages are sent from the target node to the + source node. See :class:`torch_geometric.nn.MessagePassing` for more + details. Default is "source_to_target". + :raises ValueError: If `node_feature_dim` is not a positive integer. + :raises ValueError: If `hidden_dim` is not a positive integer. + :raises ValueError: If `n_message_layers` is not a positive integer. + :raises ValueError: If `n_update_layers` is not a positive integer. + """ + super().__init__(aggr=aggr, node_dim=node_dim, flow=flow) + + # Check consistency + check_consistency(node_feature_dim, int) + check_consistency(hidden_dim, int) + check_consistency(n_message_layers, int) + check_consistency(n_update_layers, int) + + # Check values + if node_feature_dim <= 0: + raise ValueError( + "`node_feature_dim` must be a positive integer," + f" got {node_feature_dim}." + ) + + if hidden_dim <= 0: + raise ValueError( + "`hidden_dim` must be a positive integer," f" got {hidden_dim}." + ) + + if n_message_layers <= 0: + raise ValueError( + "`n_message_layers` must be a positive integer," + f" got {n_message_layers}." + ) + + if n_update_layers <= 0: + raise ValueError( + "`n_update_layers` must be a positive integer," + f" got {n_update_layers}." + ) + + # Initialize parameters + self.node_feature_dim = node_feature_dim + self.hidden_dim = hidden_dim + self.activation = activation + + # Message network + self.message_net = FeedForward( + input_dimensions=2 * self.node_feature_dim, + output_dimensions=self.hidden_dim, + inner_size=self.hidden_dim, + n_layers=n_message_layers, + func=self.activation, + ) + + # Update network + self.update_net = FeedForward( + input_dimensions=self.node_feature_dim + self.hidden_dim, + output_dimensions=self.hidden_dim, + inner_size=self.node_feature_dim, + n_layers=n_update_layers, + func=self.activation, + ) + + def forward(self, x, edge_index, edge_attr): + """ + Forward pass of the block, triggering the message-passing routine. + + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :param torch.Tensor edge_index: The edge indeces. + :param edge_attr: The edge attributes. + :type edge_attr: torch.Tensor | LabelTensor + :return: The updated node features. + :rtype: torch.Tensor + """ + + # TODO: edge_attr is not used in the message function + return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr) + + def message(self, x_i, x_j): + """ + Compute the message to be passed between nodes and edges. + + :param x_i: The node features of the recipient nodes. + :type x_i: torch.Tensor | LabelTensor + :param x_j: The node features of the sender nodes. + :type x_j: torch.Tensor | LabelTensor + :return: The message to be passed. + :rtype: torch.Tensor + """ + return self.message_net(torch.cat((x_i, x_j), dim=-1)) + + def update(self, message, x): + """ + Update the node features with the received messages. + + :param torch.Tensor message: The message to be passed. + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :return: The updated node features. + :rtype: torch.Tensor + """ + return self.update_net(torch.cat((x, message), dim=-1)) From 0ba352c584fba4a08d71105d437a62d063eae4fb Mon Sep 17 00:00:00 2001 From: AleDinve Date: Thu, 24 Apr 2025 12:59:56 -0400 Subject: [PATCH 4/8] add radial field network block --- .../radial_field_network_block.py | 143 ++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 pina/model/block/message_passing/radial_field_network_block.py diff --git a/pina/model/block/message_passing/radial_field_network_block.py b/pina/model/block/message_passing/radial_field_network_block.py new file mode 100644 index 000000000..4f5982fe3 --- /dev/null +++ b/pina/model/block/message_passing/radial_field_network_block.py @@ -0,0 +1,143 @@ +"""Module for the Radial Field Network block.""" + +import torch +from torch_geometric.nn import MessagePassing +from ....utils import check_consistency + + +class RadialFieldNetworkBlock(MessagePassing): + """ + Implementation of the Radial Field Network block. + + This block is used to perform message-passing between nodes and edges in a + graph neural network, following the scheme proposed by Köhler et al. (2020). + It serves as an inner block in a larger graph neural network architecture. + + The message between two nodes connected by an edge is computed by applying a + linear transformation to the sender node features and the edge features, + followed by a non-linear activation function. Messages are then aggregated + using an aggregation scheme (e.g., sum, mean, min, max, or product). + + The update step is performed by a simple addition of the incoming messages + to the node features. + + .. seealso:: + + **Original reference** Köhler, J., Klein, L., & Noé, F. (2020, November). + Equivariant flows: exact likelihood generative learning for symmetric densities. + In International conference on machine learning (pp. 5361-5370). PMLR. + """ + + + + def __init__( + self, + node_feature_dim, + hidden_dim, + edge_feature_dim, + activation=torch.nn.ReLU, + aggr="add", + node_dim=-2, + flow="source_to_target", + ): + """ + Initialization of the :class:`RadialFieldNetworkBlock` class. + + :param int node_feature_dim: The dimension of the node features. + :param int edge_feature_dim: The dimension of the edge features. + :param torch.nn.Module activation: The activation function. + Default is :class:`torch.nn.Tanh`. + :param str aggr: The aggregation scheme to use for message passing. + Available options are "add", "mean", "min", "max", "mul". + See :class:`torch_geometric.nn.MessagePassing` for more details. + Default is "add". + :param int node_dim: The axis along which to propagate. Default is -2. + :param str flow: The direction of message passing. Available options + are "source_to_target" and "target_to_source". + The "source_to_target" flow means that messages are sent from + the source node to the target node, while the "target_to_source" + flow means that messages are sent from the target node to the + source node. See :class:`torch_geometric.nn.MessagePassing` for more + details. Default is "source_to_target". + :raises ValueError: If `node_feature_dim` is not a positive integer. + :raises ValueError: If `edge_feature_dim` is not a positive integer. + """ + super().__init__(aggr=aggr, node_dim=node_dim, flow=flow) + + # Check consistency + check_consistency(node_feature_dim, int) + check_consistency(edge_feature_dim, int) + + # Check values + if node_feature_dim <= 0: + raise ValueError( + "`node_feature_dim` must be a positive integer," + f" got {node_feature_dim}." + ) + + if edge_feature_dim <= 0: + raise ValueError( + "`edge_feature_dim` must be a positive integer," + f" got {edge_feature_dim}." + ) + + + # Initialize parameters + self.node_feature_dim = node_feature_dim + self.edge_feature_dim = edge_feature_dim + self.hidden_dim = hidden_dim + self.activation = activation + self.layer = lambda i,o: torch.nn.Linear( + in_features=i, + out_features=o, + bias=True, + ) + # Layer for processing node features + self.radial_field = torch.nn.Sequential([self.layer(1,self.hidden_dim), + torch.nn.ReLU, + self.layer(self.hidden_dim,1)] + ) + + + def forward(self, x, edge_index): + """ + Forward pass of the block, triggering the message-passing routine. + + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :param torch.Tensor edge_index: The edge indices. In the original formulation, + the messages are aggregated from all nodes, not only from the neighbours. + :return: The updated node features. + :rtype: torch.Tensor + """ + return self.propagate(edge_index=edge_index, x=x) + + def message(self, x_j, x_i): + """ + Compute the message to be passed between nodes and edges. + + :param x_j: Concatenation of the node position and the + node features of the sender nodes. + :type x_j: torch.Tensor | LabelTensor + :param edge_attr: The edge attributes. + :type edge_attr: torch.Tensor | LabelTensor + :return: The message to be passed. + :rtype: torch.Tensor + """ + r = torch.norm(x_i-x_j)*(x_i-x_j) + + + return self.activation(self.radial_field(r)) + + + def update(self, message, x): + """ + Update the node features with the received messages. + + :param torch.Tensor message: The message to be passed. + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :return: The updated node features. + :rtype: torch.Tensor + """ + return x + message From cbf4e80fdd19a1f3901e47911b5103efb5fdc0fc Mon Sep 17 00:00:00 2001 From: AleDinve Date: Thu, 24 Apr 2025 15:36:58 -0400 Subject: [PATCH 5/8] add schnet block --- .../radial_field_network_block.py | 32 ++-- .../block/message_passing/schnet_block.py | 154 ++++++++++++++++++ 2 files changed, 166 insertions(+), 20 deletions(-) create mode 100644 pina/model/block/message_passing/schnet_block.py diff --git a/pina/model/block/message_passing/radial_field_network_block.py b/pina/model/block/message_passing/radial_field_network_block.py index 4f5982fe3..f7d55a948 100644 --- a/pina/model/block/message_passing/radial_field_network_block.py +++ b/pina/model/block/message_passing/radial_field_network_block.py @@ -1,6 +1,7 @@ """Module for the Radial Field Network block.""" import torch +from ....model import FeedForward from torch_geometric.nn import MessagePassing from ....utils import check_consistency @@ -34,7 +35,8 @@ def __init__( self, node_feature_dim, hidden_dim, - edge_feature_dim, + radial_hidden_dim=16, + n_radial_layers=2, activation=torch.nn.ReLU, aggr="add", node_dim=-2, @@ -66,7 +68,6 @@ def __init__( # Check consistency check_consistency(node_feature_dim, int) - check_consistency(edge_feature_dim, int) # Check values if node_feature_dim <= 0: @@ -75,27 +76,18 @@ def __init__( f" got {node_feature_dim}." ) - if edge_feature_dim <= 0: - raise ValueError( - "`edge_feature_dim` must be a positive integer," - f" got {edge_feature_dim}." - ) - - # Initialize parameters self.node_feature_dim = node_feature_dim - self.edge_feature_dim = edge_feature_dim self.hidden_dim = hidden_dim self.activation = activation - self.layer = lambda i,o: torch.nn.Linear( - in_features=i, - out_features=o, - bias=True, - ) + # Layer for processing node features - self.radial_field = torch.nn.Sequential([self.layer(1,self.hidden_dim), - torch.nn.ReLU, - self.layer(self.hidden_dim,1)] + self.radial_field = FeedForward( + input_dimensions=1, + output_dimensions=1, + inner_size=radial_hidden_dim, + n_layers=n_radial_layers, + func=self.activation, ) @@ -124,10 +116,10 @@ def message(self, x_j, x_i): :return: The message to be passed. :rtype: torch.Tensor """ - r = torch.norm(x_i-x_j)*(x_i-x_j) + r = torch.norm(x_i-x_j) - return self.activation(self.radial_field(r)) + return self.radial_field(r)*(x_i-x_j) def update(self, message, x): diff --git a/pina/model/block/message_passing/schnet_block.py b/pina/model/block/message_passing/schnet_block.py new file mode 100644 index 000000000..955fbbe8e --- /dev/null +++ b/pina/model/block/message_passing/schnet_block.py @@ -0,0 +1,154 @@ +"""Module for the Schnet block.""" + +import torch +from ....model import FeedForward +from torch_geometric.nn import MessagePassing +from ....utils import check_consistency + + +class SchnetBlock(MessagePassing): + """ + Implementation of the Schnet block. + + This block is used to perform message-passing between nodes and edges in a + graph neural network, following the scheme proposed by Schütt et al. (2017). + It serves as an inner block in a larger graph neural network architecture. + + The message between two nodes connected by an edge is computed by applying a + linear transformation to the sender node features and the edge features, + followed by a non-linear activation function. Messages are then aggregated + using an aggregation scheme (e.g., sum, mean, min, max, or product). + + The update step is performed by a simple addition of the incoming messages + to the node features. + + .. seealso:: + + **Original reference** Schütt, K., Kindermans, P. J., Sauceda Felix, H. E., Chmiela, S., Tkatchenko, A., & Müller, K. R. (2017). + Schnet: A continuous-filter convolutional neural network for modeling quantum interactions. + Advances in neural information processing systems, 30. + """ + + + + def __init__( + self, + node_feature_dim, + node_pos_dim, + hidden_dim, + radial_hidden_dim=16, + n_message_layers=2, + n_update_layers=2, + n_radial_layers=2, + activation=torch.nn.ReLU, + aggr="add", + node_dim=-2, + flow="source_to_target", + ): + """ + Initialization of the :class:`RadialFieldNetworkBlock` class. + + :param int node_feature_dim: The dimension of the node features. + :param int edge_feature_dim: The dimension of the edge features. + :param torch.nn.Module activation: The activation function. + Default is :class:`torch.nn.Tanh`. + :param str aggr: The aggregation scheme to use for message passing. + Available options are "add", "mean", "min", "max", "mul". + See :class:`torch_geometric.nn.MessagePassing` for more details. + Default is "add". + :param int node_dim: The axis along which to propagate. Default is -2. + :param str flow: The direction of message passing. Available options + are "source_to_target" and "target_to_source". + The "source_to_target" flow means that messages are sent from + the source node to the target node, while the "target_to_source" + flow means that messages are sent from the target node to the + source node. See :class:`torch_geometric.nn.MessagePassing` for more + details. Default is "source_to_target". + :raises ValueError: If `node_feature_dim` is not a positive integer. + :raises ValueError: If `edge_feature_dim` is not a positive integer. + """ + super().__init__(aggr=aggr, node_dim=node_dim, flow=flow) + + # Check consistency + check_consistency(node_feature_dim, int) + + # Check values + if node_feature_dim <= 0: + raise ValueError( + "`node_feature_dim` must be a positive integer," + f" got {node_feature_dim}." + ) + + + # Initialize parameters + self.node_feature_dim = node_feature_dim + self.node_pos_dim = node_pos_dim + self.hidden_dim = hidden_dim + self.activation = activation + + # Layer for processing node features + self.radial_field = FeedForward( + input_dimensions=1, + output_dimensions=1, + inner_size=radial_hidden_dim, + n_layers=n_radial_layers, + func=self.activation, + ) + + self.update_net = FeedForward( + input_dimensions=self.node_pos_dim + self.hidden_dim, + output_dimensions=self.hidden_dim, + inner_size=self.hidden_dim, + n_layers=n_update_layers, + func=self.activation, + ) + + self.message_net = FeedForward( + input_dimensions=self.node_feature_dim, + output_dimensions=self.node_pos_dim + self.hidden_dim, + inner_size=self.hidden_dim, + n_layers=n_message_layers, + func=self.activation, + ) + + + def forward(self, x, pos, edge_index): + """ + Forward pass of the block, triggering the message-passing routine. + + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :param torch.Tensor edge_index: The edge indices. In the original formulation, + the messages are aggregated from all nodes, not only from the neighbours. + :return: The updated node features. + :rtype: torch.Tensor + """ + return self.propagate(edge_index=edge_index, x=x, pos=pos) + + def message(self, x_i, pos_i ,pos_j): + """ + Compute the message to be passed between nodes and edges. + + :param x_j: Concatenation of the node position and the + node features of the sender nodes. + :type x_j: torch.Tensor | LabelTensor + :param edge_attr: The edge attributes. + :type edge_attr: torch.Tensor | LabelTensor + :return: The message to be passed. + :rtype: torch.Tensor + """ + + return self.radial_field(torch.norm(pos_i-pos_j))*self.message_net(x_i) + + + def update(self, message, pos): + """ + Update the node features with the received messages. + + :param torch.Tensor message: The message to be passed. + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :return: The concatenation of the update position features and the updated node features. + :rtype: torch.Tensor + """ + return self.update_net(torch.cat((pos, message), dim=-1)) From f13a5f31d374d870cb3c87ee96ed9863884baad3 Mon Sep 17 00:00:00 2001 From: AleDinve Date: Sun, 11 May 2025 23:57:19 -0400 Subject: [PATCH 6/8] add equivariant network block --- .../model/block/message_passing/egnn_block.py | 118 ++++++++++++------ .../radial_field_network_block.py | 4 +- .../block/message_passing/schnet_block.py | 5 +- 3 files changed, 82 insertions(+), 45 deletions(-) diff --git a/pina/model/block/message_passing/egnn_block.py b/pina/model/block/message_passing/egnn_block.py index 7c137ac0e..b6a605070 100644 --- a/pina/model/block/message_passing/egnn_block.py +++ b/pina/model/block/message_passing/egnn_block.py @@ -7,12 +7,30 @@ class EnEquivariantGraphBlock(MessagePassing): """ - TODO + Implementation of the E(n) Equivariant Graph Neural Network block. + + This block is used to perform message-passing between nodes and edges in a + graph neural network, following the scheme proposed by Satorras et al. (2021). + It serves as an inner block in a larger graph neural network architecture. + + The message between two nodes connected by an edge is computed by applying a + linear transformation to the sender node features and the edge features, + followed by a non-linear activation function. Messages are then aggregated + using an aggregation scheme (e.g., sum, mean, min, max, or product). + + The update step is performed by a simple addition of the incoming messages + to the node features. + + .. seealso:: + + **Original reference** Satorras, V. G., Hoogeboom, E., & Welling, M. (2021, July). + E (n) equivariant graph neural networks. + In International conference on machine learning (pp. 9323-9332). PMLR. """ def __init__( self, - channels_h, + channels_x, channels_m, channels_a, aggr: str = "add", @@ -20,78 +38,100 @@ def __init__( **kwargs, ): """ - TODO + Initialization of the :class:`EnEquivariantGraphBlock` class. + + :param int channels_x: The dimension of the node features. + :param int channels_m: The dimension of the Euclidean coordinates (should be =3). + :param int channels_a: The dimension of the edge features. + :param str aggr: The aggregation scheme to use for message passing. + Available options are "add", "mean", "min", "max", "mul". + See :class:`torch_geometric.nn.MessagePassing` for more details. + Default is "add". + :param int hidden_channels_dim: The hidden dimension in each MLPs initialized in the block. """ super().__init__(aggr=aggr, **kwargs) self.phi_e = torch.nn.Sequential( - torch.nn.Linear(2 * channels_h + 1 + channels_a, hidden_channels), + torch.nn.Linear(2 * channels_x + 1 + channels_a, hidden_channels), torch.nn.LayerNorm(hidden_channels), torch.nn.SiLU(), torch.nn.Linear(hidden_channels, channels_m), torch.nn.LayerNorm(channels_m), torch.nn.SiLU(), ) - self.phi_x = torch.nn.Sequential( + self.phi_pos = torch.nn.Sequential( torch.nn.Linear(channels_m, hidden_channels), torch.nn.LayerNorm(hidden_channels), torch.nn.SiLU(), torch.nn.Linear(hidden_channels, 1), ) - self.phi_h = torch.nn.Sequential( - torch.nn.Linear(channels_h + channels_m, hidden_channels), + self.phi_x = torch.nn.Sequential( + torch.nn.Linear(channels_x + channels_m, hidden_channels), torch.nn.LayerNorm(hidden_channels), torch.nn.SiLU(), - torch.nn.Linear(hidden_channels, channels_h), + torch.nn.Linear(hidden_channels, channels_x), ) - def forward(self, x, h, edge_attr, edge_index, c=None): + def forward(self, x, pos, edge_attr, edge_index, c=None): """ - TODO + Forward pass of the block, triggering the message-passing routine. + + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :param pos_i: 3D Euclidean coordinates. + :type pos_i: torch.Tensor | LabelTensor + :param torch.Tensor edge_index: The edge indices. In the original formulation, + the messages are aggregated from all nodes, not only from the neighbours. + :return: The updated node features. + :rtype: torch.Tensor """ if c is None: - c = degree(edge_index[0], x.shape[0]).unsqueeze(-1) + c = degree(edge_index[0], pos.shape[0]).unsqueeze(-1) return self.propagate( - edge_index=edge_index, x=x, h=h, edge_attr=edge_attr, c=c + edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr, c=c ) - def message(self, x_i, x_j, h_i, h_j, edge_attr): + def message(self, x_i, x_j, pos_i, pos_j, edge_attr): """ - TODO + Compute the message to be passed between nodes and edges. + + :param x_i: Node features of the sender nodes. + :type x_i: torch.Tensor | LabelTensor + :param pos_i: 3D Euclidean coordinates of the sender nodes. + :type pos_i: torch.Tensor | LabelTensor + :param edge_attr: The edge attributes. + :type edge_attr: torch.Tensor | LabelTensor + :return: The message to be passed. + :rtype: torch.Tensor """ - mh_ij = self.phi_e( + mpos_ij = self.phi_e( torch.cat( [ - h_i, - h_j, - torch.norm(x_i - x_j, dim=-1, keepdim=True) ** 2, + x_i, + x_j, + torch.norm(pos_i - pos_j, dim=-1, keepdim=True) ** 2, edge_attr, ], dim=-1, ) ) - mx_ij = (x_i - x_j) * self.phi_x(mh_ij) - return torch.cat((mx_ij, mh_ij), dim=-1) + mpos_ij = (pos_i - pos_j) * self.phi_pos(mpos_ij) + return mpos_ij - def update(self, aggr_out, x, h, edge_attr, c): + def update(self, message, x, pos, c): """ - TODO - """ - m_x, m_h = aggr_out[:, : self.m_len], aggr_out[:, self.m_len :] - h_l1 = self.phi_h(torch.cat([h, m_h], dim=-1)) - x_l1 = x + (m_x / c) - return x_l1, h_l1 + Update the node features with the received messages. - @property - def edge_function(self): - """ - TODO - """ - return self._edge_function - - @property - def attribute_function(self): - """ - TODO + :param torch.Tensor message: The message to be passed. + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :param pos: The 3D Euclidean coordinates of the nodes. + :type pos: torch.Tensor | LabelTensor + :param c: the constant that divides the aggregated message (it should be (M-1), where M is the number of nodes) + :type pos: torch.Tensor + :return: The concatenation of the update position features and the updated node features. + :rtype: torch.Tensor """ - return self._attribute_function + x = self.phi_x(torch.cat([x, message], dim=-1)) + pos = pos + (message / c) + return pos, x diff --git a/pina/model/block/message_passing/radial_field_network_block.py b/pina/model/block/message_passing/radial_field_network_block.py index f7d55a948..0d3257d48 100644 --- a/pina/model/block/message_passing/radial_field_network_block.py +++ b/pina/model/block/message_passing/radial_field_network_block.py @@ -62,7 +62,6 @@ def __init__( source node. See :class:`torch_geometric.nn.MessagePassing` for more details. Default is "source_to_target". :raises ValueError: If `node_feature_dim` is not a positive integer. - :raises ValueError: If `edge_feature_dim` is not a positive integer. """ super().__init__(aggr=aggr, node_dim=node_dim, flow=flow) @@ -108,8 +107,7 @@ def message(self, x_j, x_i): """ Compute the message to be passed between nodes and edges. - :param x_j: Concatenation of the node position and the - node features of the sender nodes. + :param x_j: Node features of the sender nodes. :type x_j: torch.Tensor | LabelTensor :param edge_attr: The edge attributes. :type edge_attr: torch.Tensor | LabelTensor diff --git a/pina/model/block/message_passing/schnet_block.py b/pina/model/block/message_passing/schnet_block.py index 955fbbe8e..7ee2b129c 100644 --- a/pina/model/block/message_passing/schnet_block.py +++ b/pina/model/block/message_passing/schnet_block.py @@ -46,7 +46,7 @@ def __init__( flow="source_to_target", ): """ - Initialization of the :class:`RadialFieldNetworkBlock` class. + Initialization of the :class:`SchnetBlock` class. :param int node_feature_dim: The dimension of the node features. :param int edge_feature_dim: The dimension of the edge features. @@ -129,8 +129,7 @@ def message(self, x_i, pos_i ,pos_j): """ Compute the message to be passed between nodes and edges. - :param x_j: Concatenation of the node position and the - node features of the sender nodes. + :param x_j: Node features of the sender nodes. :type x_j: torch.Tensor | LabelTensor :param edge_attr: The edge attributes. :type edge_attr: torch.Tensor | LabelTensor From 2cd5a5ef422ce1f2bf42331743ab9ea86b463ae3 Mon Sep 17 00:00:00 2001 From: giovanni Date: Thu, 29 May 2025 23:12:30 +0200 Subject: [PATCH 7/8] fix + tests + doc files --- docs/source/_rst/_code.rst | 12 ++ .../deep_tensor_network_block.rst | 8 + .../en_equivariant_network_block.rst | 8 + .../interaction_network_block.rst | 8 + .../radial_field_network_block.rst | 8 + .../block/message_passing/schnet_block.rst | 8 + pina/model/block/message_passing/__init__.py | 6 + .../deep_tensor_network_block.py | 54 ++---- .../model/block/message_passing/egnn_block.py | 137 -------------- .../en_equivariant_network_block.py | 176 ++++++++++++++++++ .../interaction_network_block.py | 95 ++++------ .../radial_field_network_block.py | 87 ++++----- .../block/message_passing/schnet_block.py | 143 +++++++------- pina/utils.py | 19 ++ .../test_deep_tensor_network_block.py | 59 ++++++ .../test_equivariant_network_block.py | 130 +++++++++++++ .../test_interaction_network_block.py | 84 +++++++++ .../test_radial_field_network_block.py | 67 +++++++ .../test_messagepassing/test_schnet_block.py | 73 ++++++++ tests/test_utils.py | 30 ++- 20 files changed, 862 insertions(+), 350 deletions(-) create mode 100644 docs/source/_rst/model/block/message_passing/deep_tensor_network_block.rst create mode 100644 docs/source/_rst/model/block/message_passing/en_equivariant_network_block.rst create mode 100644 docs/source/_rst/model/block/message_passing/interaction_network_block.rst create mode 100644 docs/source/_rst/model/block/message_passing/radial_field_network_block.rst create mode 100644 docs/source/_rst/model/block/message_passing/schnet_block.rst delete mode 100644 pina/model/block/message_passing/egnn_block.py create mode 100644 pina/model/block/message_passing/en_equivariant_network_block.py create mode 100644 tests/test_messagepassing/test_deep_tensor_network_block.py create mode 100644 tests/test_messagepassing/test_equivariant_network_block.py create mode 100644 tests/test_messagepassing/test_interaction_network_block.py create mode 100644 tests/test_messagepassing/test_radial_field_network_block.py create mode 100644 tests/test_messagepassing/test_schnet_block.py diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index ba059ddbc..957eb6e17 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -122,6 +122,18 @@ Blocks Continuous Convolution Block Orthogonal Block +Message Passing +------------------- + +.. toctree:: + :titlesonly: + + Deep Tensor Network Block + E(n) Equivariant Network Block + Interaction Network Block + Radial Field Network Block + Schnet Block + Reduction and Embeddings -------------------------- diff --git a/docs/source/_rst/model/block/message_passing/deep_tensor_network_block.rst b/docs/source/_rst/model/block/message_passing/deep_tensor_network_block.rst new file mode 100644 index 000000000..30121e5a6 --- /dev/null +++ b/docs/source/_rst/model/block/message_passing/deep_tensor_network_block.rst @@ -0,0 +1,8 @@ +Deep Tensor Network Block +================================== +.. currentmodule:: pina.model.block.message_passing.deep_tensor_network_block + +.. autoclass:: DeepTensorNetworkBlock + :members: + :show-inheritance: + :noindex: diff --git a/docs/source/_rst/model/block/message_passing/en_equivariant_network_block.rst b/docs/source/_rst/model/block/message_passing/en_equivariant_network_block.rst new file mode 100644 index 000000000..e2755c665 --- /dev/null +++ b/docs/source/_rst/model/block/message_passing/en_equivariant_network_block.rst @@ -0,0 +1,8 @@ +E(n) Equivariant Network Block +================================== +.. currentmodule:: pina.model.block.message_passing.en_equivariant_network_block + +.. autoclass:: EnEquivariantNetworkBlock + :members: + :show-inheritance: + :noindex: \ No newline at end of file diff --git a/docs/source/_rst/model/block/message_passing/interaction_network_block.rst b/docs/source/_rst/model/block/message_passing/interaction_network_block.rst new file mode 100644 index 000000000..ffac307e2 --- /dev/null +++ b/docs/source/_rst/model/block/message_passing/interaction_network_block.rst @@ -0,0 +1,8 @@ +Interaction Network Block +================================== +.. currentmodule:: pina.model.block.message_passing.interaction_network_block + +.. autoclass:: InteractionNetworkBlock + :members: + :show-inheritance: + :noindex: \ No newline at end of file diff --git a/docs/source/_rst/model/block/message_passing/radial_field_network_block.rst b/docs/source/_rst/model/block/message_passing/radial_field_network_block.rst new file mode 100644 index 000000000..e05203f33 --- /dev/null +++ b/docs/source/_rst/model/block/message_passing/radial_field_network_block.rst @@ -0,0 +1,8 @@ +Radial Field Network Block +================================== +.. currentmodule:: pina.model.block.message_passing.radial_field_network_block + +.. autoclass:: RadialFieldNetworkBlock + :members: + :show-inheritance: + :noindex: \ No newline at end of file diff --git a/docs/source/_rst/model/block/message_passing/schnet_block.rst b/docs/source/_rst/model/block/message_passing/schnet_block.rst new file mode 100644 index 000000000..c5baa2730 --- /dev/null +++ b/docs/source/_rst/model/block/message_passing/schnet_block.rst @@ -0,0 +1,8 @@ +Schnet Block +================================== +.. currentmodule:: pina.model.block.message_passing.schnet_block + +.. autoclass:: SchnetBlock + :members: + :show-inheritance: + :noindex: \ No newline at end of file diff --git a/pina/model/block/message_passing/__init__.py b/pina/model/block/message_passing/__init__.py index a4b122016..4eed0a611 100644 --- a/pina/model/block/message_passing/__init__.py +++ b/pina/model/block/message_passing/__init__.py @@ -3,7 +3,13 @@ __all__ = [ "InteractionNetworkBlock", "DeepTensorNetworkBlock", + "EnEquivariantNetworkBlock", + "RadialFieldNetworkBlock", + "SchnetBlock", ] from .interaction_network_block import InteractionNetworkBlock from .deep_tensor_network_block import DeepTensorNetworkBlock +from .en_equivariant_network_block import EnEquivariantNetworkBlock +from .radial_field_network_block import RadialFieldNetworkBlock +from .schnet_block import SchnetBlock diff --git a/pina/model/block/message_passing/deep_tensor_network_block.py b/pina/model/block/message_passing/deep_tensor_network_block.py index 950e32f05..a2de3097a 100644 --- a/pina/model/block/message_passing/deep_tensor_network_block.py +++ b/pina/model/block/message_passing/deep_tensor_network_block.py @@ -2,7 +2,7 @@ import torch from torch_geometric.nn import MessagePassing -from ....utils import check_consistency +from ....utils import check_positive_integer class DeepTensorNetworkBlock(MessagePassing): @@ -10,8 +10,9 @@ class DeepTensorNetworkBlock(MessagePassing): Implementation of the Deep Tensor Network block. This block is used to perform message-passing between nodes and edges in a - graph neural network, following the scheme proposed by Schutt et al. (2017). - It serves as an inner block in a larger graph neural network architecture. + graph neural network, following the scheme proposed by Schutt et al. in + 2017. It serves as an inner block in a larger graph neural network + architecture. The message between two nodes connected by an edge is computed by applying a linear transformation to the sender node features and the edge features, @@ -24,9 +25,9 @@ class DeepTensorNetworkBlock(MessagePassing): .. seealso:: **Original reference**: Schutt, K., Arbabzadah, F., Chmiela, S. et al. - *Quantum-Chemical Insights from Deep Tensor Neural Networks*. + (2017). *Quantum-Chemical Insights from Deep Tensor Neural Networks*. Nature Communications 8, 13890 (2017). - DOI: `_`. + DOI: ``_. """ def __init__( @@ -39,7 +40,7 @@ def __init__( flow="source_to_target", ): """ - Initialization of the :class:`DeepTensorNetworkBlocklock` class. + Initialization of the :class:`DeepTensorNetworkBlock` class. :param int node_feature_dim: The dimension of the node features. :param int edge_feature_dim: The dimension of the edge features. @@ -57,51 +58,36 @@ def __init__( flow means that messages are sent from the target node to the source node. See :class:`torch_geometric.nn.MessagePassing` for more details. Default is "source_to_target". - :raises ValueError: If `node_feature_dim` is not a positive integer. - :raises ValueError: If `edge_feature_dim` is not a positive integer. + :raises AssertionError: If `node_feature_dim` is not a positive integer. + :raises AssertionError: If `edge_feature_dim` is not a positive integer. """ super().__init__(aggr=aggr, node_dim=node_dim, flow=flow) - # Check consistency - check_consistency(node_feature_dim, int) - check_consistency(edge_feature_dim, int) - # Check values - if node_feature_dim <= 0: - raise ValueError( - "`node_feature_dim` must be a positive integer," - f" got {node_feature_dim}." - ) - - if edge_feature_dim <= 0: - raise ValueError( - "`edge_feature_dim` must be a positive integer," - f" got {edge_feature_dim}." - ) - - # Initialize parameters - self.node_feature_dim = node_feature_dim - self.edge_feature_dim = edge_feature_dim - self.activation = activation + check_positive_integer(node_feature_dim, strict=True) + check_positive_integer(edge_feature_dim, strict=True) + + # Activation function + self.activation = activation() # Layer for processing node features self.node_layer = torch.nn.Linear( - in_features=self.node_feature_dim, - out_features=self.node_feature_dim, + in_features=node_feature_dim, + out_features=node_feature_dim, bias=True, ) # Layer for processing edge features self.edge_layer = torch.nn.Linear( - in_features=self.edge_feature_dim, - out_features=self.node_feature_dim, + in_features=edge_feature_dim, + out_features=node_feature_dim, bias=True, ) # Layer for computing the message self.message_layer = torch.nn.Linear( - in_features=self.node_feature_dim, - out_features=self.node_feature_dim, + in_features=node_feature_dim, + out_features=node_feature_dim, bias=False, ) diff --git a/pina/model/block/message_passing/egnn_block.py b/pina/model/block/message_passing/egnn_block.py deleted file mode 100644 index b6a605070..000000000 --- a/pina/model/block/message_passing/egnn_block.py +++ /dev/null @@ -1,137 +0,0 @@ -"""Module for the E(n) Equivariant Graph Neural Network block.""" - -import torch -from torch_geometric.nn import MessagePassing -from torch_geometric.utils import degree - - -class EnEquivariantGraphBlock(MessagePassing): - """ - Implementation of the E(n) Equivariant Graph Neural Network block. - - This block is used to perform message-passing between nodes and edges in a - graph neural network, following the scheme proposed by Satorras et al. (2021). - It serves as an inner block in a larger graph neural network architecture. - - The message between two nodes connected by an edge is computed by applying a - linear transformation to the sender node features and the edge features, - followed by a non-linear activation function. Messages are then aggregated - using an aggregation scheme (e.g., sum, mean, min, max, or product). - - The update step is performed by a simple addition of the incoming messages - to the node features. - - .. seealso:: - - **Original reference** Satorras, V. G., Hoogeboom, E., & Welling, M. (2021, July). - E (n) equivariant graph neural networks. - In International conference on machine learning (pp. 9323-9332). PMLR. - """ - - def __init__( - self, - channels_x, - channels_m, - channels_a, - aggr: str = "add", - hidden_channels: int = 64, - **kwargs, - ): - """ - Initialization of the :class:`EnEquivariantGraphBlock` class. - - :param int channels_x: The dimension of the node features. - :param int channels_m: The dimension of the Euclidean coordinates (should be =3). - :param int channels_a: The dimension of the edge features. - :param str aggr: The aggregation scheme to use for message passing. - Available options are "add", "mean", "min", "max", "mul". - See :class:`torch_geometric.nn.MessagePassing` for more details. - Default is "add". - :param int hidden_channels_dim: The hidden dimension in each MLPs initialized in the block. - """ - super().__init__(aggr=aggr, **kwargs) - - self.phi_e = torch.nn.Sequential( - torch.nn.Linear(2 * channels_x + 1 + channels_a, hidden_channels), - torch.nn.LayerNorm(hidden_channels), - torch.nn.SiLU(), - torch.nn.Linear(hidden_channels, channels_m), - torch.nn.LayerNorm(channels_m), - torch.nn.SiLU(), - ) - self.phi_pos = torch.nn.Sequential( - torch.nn.Linear(channels_m, hidden_channels), - torch.nn.LayerNorm(hidden_channels), - torch.nn.SiLU(), - torch.nn.Linear(hidden_channels, 1), - ) - self.phi_x = torch.nn.Sequential( - torch.nn.Linear(channels_x + channels_m, hidden_channels), - torch.nn.LayerNorm(hidden_channels), - torch.nn.SiLU(), - torch.nn.Linear(hidden_channels, channels_x), - ) - - def forward(self, x, pos, edge_attr, edge_index, c=None): - """ - Forward pass of the block, triggering the message-passing routine. - - :param x: The node features. - :type x: torch.Tensor | LabelTensor - :param pos_i: 3D Euclidean coordinates. - :type pos_i: torch.Tensor | LabelTensor - :param torch.Tensor edge_index: The edge indices. In the original formulation, - the messages are aggregated from all nodes, not only from the neighbours. - :return: The updated node features. - :rtype: torch.Tensor - """ - if c is None: - c = degree(edge_index[0], pos.shape[0]).unsqueeze(-1) - return self.propagate( - edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr, c=c - ) - - def message(self, x_i, x_j, pos_i, pos_j, edge_attr): - """ - Compute the message to be passed between nodes and edges. - - :param x_i: Node features of the sender nodes. - :type x_i: torch.Tensor | LabelTensor - :param pos_i: 3D Euclidean coordinates of the sender nodes. - :type pos_i: torch.Tensor | LabelTensor - :param edge_attr: The edge attributes. - :type edge_attr: torch.Tensor | LabelTensor - :return: The message to be passed. - :rtype: torch.Tensor - """ - mpos_ij = self.phi_e( - torch.cat( - [ - x_i, - x_j, - torch.norm(pos_i - pos_j, dim=-1, keepdim=True) ** 2, - edge_attr, - ], - dim=-1, - ) - ) - mpos_ij = (pos_i - pos_j) * self.phi_pos(mpos_ij) - return mpos_ij - - def update(self, message, x, pos, c): - """ - Update the node features with the received messages. - - :param torch.Tensor message: The message to be passed. - :param x: The node features. - :type x: torch.Tensor | LabelTensor - :param pos: The 3D Euclidean coordinates of the nodes. - :type pos: torch.Tensor | LabelTensor - :param c: the constant that divides the aggregated message (it should be (M-1), where M is the number of nodes) - :type pos: torch.Tensor - :return: The concatenation of the update position features and the updated node features. - :rtype: torch.Tensor - """ - x = self.phi_x(torch.cat([x, message], dim=-1)) - pos = pos + (message / c) - return pos, x diff --git a/pina/model/block/message_passing/en_equivariant_network_block.py b/pina/model/block/message_passing/en_equivariant_network_block.py new file mode 100644 index 000000000..fa256e9d5 --- /dev/null +++ b/pina/model/block/message_passing/en_equivariant_network_block.py @@ -0,0 +1,176 @@ +"""Module for the E(n) Equivariant Graph Neural Network block.""" + +import torch +from torch_geometric.nn import MessagePassing +from torch_geometric.utils import degree +from ....utils import check_positive_integer +from ....model import FeedForward + + +class EnEquivariantNetworkBlock(MessagePassing): + """ + Implementation of the E(n) Equivariant Graph Neural Network block. + + This block is used to perform message-passing between nodes and edges in a + graph neural network, following the scheme proposed by Satorras et al. in + 2021. It serves as an inner block in a larger graph neural network + architecture. + + The message between two nodes connected by an edge is computed by applying a + linear transformation to the sender node features and the edge features, + together with the squared euclidean distance between the sender and + recipient node positions, followed by a non-linear activation function. + Messages are then aggregated using an aggregation scheme (e.g., sum, mean, + min, max, or product). + + The update step is performed by applying another MLP to the concatenation of + the incoming messages and the node features. Here, also the node + positions are updated by adding the incoming messages divided by the + degree of the recipient node. + + .. seealso:: + + **Original reference** Satorras, V. G., Hoogeboom, E., Welling, M. + (2021). *E(n) Equivariant Graph Neural Networks.* + In International Conference on Machine Learning. + DOI: ``_. + """ + + def __init__( + self, + node_feature_dim, + edge_feature_dim, + pos_dim, + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + activation=torch.nn.SiLU, + aggr="add", + node_dim=-2, + flow="source_to_target", + ): + """ + Initialization of the :class:`EnEquivariantNetworkBlock` class. + + :param int node_feature_dim: The dimension of the node features. + :param int edge_feature_dim: The dimension of the edge features. + :param int pos_dim: The dimension of the position features. + :param int hidden_dim: The dimension of the hidden features. + Default is 64. + :param int n_message_layers: The number of layers in the message + network. Default is 2. + :param int n_update_layers: The number of layers in the update network. + Default is 2. + :param torch.nn.Module activation: The activation function. + Default is :class:`torch.nn.SiLU`. + :param str aggr: The aggregation scheme to use for message passing. + Available options are "add", "mean", "min", "max", "mul". + See :class:`torch_geometric.nn.MessagePassing` for more details. + Default is "add". + :param int node_dim: The axis along which to propagate. Default is -2. + :param str flow: The direction of message passing. Available options + are "source_to_target" and "target_to_source". + The "source_to_target" flow means that messages are sent from + the source node to the target node, while the "target_to_source" + flow means that messages are sent from the target node to the + source node. See :class:`torch_geometric.nn.MessagePassing` for more + details. Default is "source_to_target". + :raises AssertionError: If `node_feature_dim` is not a positive integer. + :raises AssertionError: If `edge_feature_dim` is a negative integer. + :raises AssertionError: If `pos_dim` is not a positive integer. + :raises AssertionError: If `hidden_dim` is not a positive integer. + :raises AssertionError: If `n_message_layers` is not a positive integer. + :raises AssertionError: If `n_update_layers` is not a positive integer. + """ + super().__init__(aggr=aggr, node_dim=node_dim, flow=flow) + + # Check values + check_positive_integer(node_feature_dim, strict=True) + check_positive_integer(edge_feature_dim, strict=False) + check_positive_integer(pos_dim, strict=True) + check_positive_integer(hidden_dim, strict=True) + check_positive_integer(n_message_layers, strict=True) + check_positive_integer(n_update_layers, strict=True) + + # Layer for computing the message + self.message_net = FeedForward( + input_dimensions=2 * node_feature_dim + edge_feature_dim + 1, + output_dimensions=pos_dim, + inner_size=hidden_dim, + n_layers=n_message_layers, + func=activation, + ) + + # Layer for updating the node features + self.update_net = FeedForward( + input_dimensions=node_feature_dim + pos_dim, + output_dimensions=node_feature_dim, + inner_size=hidden_dim, + n_layers=n_update_layers, + func=activation, + ) + + def forward(self, x, pos, edge_index, edge_attr=None): + """ + Forward pass of the block, triggering the message-passing routine. + + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :param pos: The euclidean coordinates of the nodes. + :type pos: torch.Tensor | LabelTensor + :param torch.Tensor edge_index: The edge indices. + :param edge_attr: The edge attributes. Default is None. + :type edge_attr: torch.Tensor | LabelTensor + :return: The updated node features and node positions. + :rtype: tuple(torch.Tensor, torch.Tensor) + """ + return self.propagate( + edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr + ) + + def message(self, x_i, x_j, pos_i, pos_j, edge_attr): + """ + Compute the message to be passed between nodes and edges. + + :param x_i: The node features of the recipient nodes. + :type x_i: torch.Tensor | LabelTensor + :param x_j: The node features of the sender nodes. + :type x_j: torch.Tensor | LabelTensor + :param pos_i: The node coordinates of the recipient nodes. + :type pos_i: torch.Tensor | LabelTensor + :param pos_j: The node coordinates of the sender nodes. + :type pos_j: torch.Tensor | LabelTensor + :param edge_attr: The edge attributes. + :type edge_attr: torch.Tensor | LabelTensor + :return: The message to be passed. + :rtype: torch.Tensor + """ + dist = torch.norm(pos_i - pos_j, dim=-1, keepdim=True) ** 2 + if edge_attr is None: + input_ = torch.cat((x_i, x_j, dist), dim=-1) + else: + input_ = torch.cat((x_i, x_j, dist, edge_attr), dim=-1) + + return self.message_net(input_) + + def update(self, message, x, pos, edge_index): + """ + Update the node features and the node coordinates with the received + messages. + + :param torch.Tensor message: The message to be passed. + :param x: The node features. + :type x: torch.Tensor | LabelTensor + :param pos: The euclidean coordinates of the nodes. + :type pos: torch.Tensor | LabelTensor + :param torch.Tensor edge_index: The edge indices. + :return: The updated node features and node positions. + :rtype: tuple(torch.Tensor, torch.Tensor) + """ + # Update the node features + x = self.update_net(torch.cat((x, message), dim=-1)) + + # Update the node positions + c = degree(edge_index[0], pos.shape[0]).unsqueeze(-1) + pos = pos + message / c + return x, pos diff --git a/pina/model/block/message_passing/interaction_network_block.py b/pina/model/block/message_passing/interaction_network_block.py index f27169448..7c6eb03f6 100644 --- a/pina/model/block/message_passing/interaction_network_block.py +++ b/pina/model/block/message_passing/interaction_network_block.py @@ -2,8 +2,8 @@ import torch from torch_geometric.nn import MessagePassing +from ....utils import check_positive_integer from ....model import FeedForward -from ....utils import check_consistency class InteractionNetworkBlock(MessagePassing): @@ -11,9 +11,9 @@ class InteractionNetworkBlock(MessagePassing): Implementation of the Interaction Network block. This block is used to perform message-passing between nodes and edges in a - graph neural network, following the scheme proposed by Battaglia et al. - (2016). - It serves as an inner block in a larger graph neural network architecture. + graph neural network, following the scheme proposed by Battaglia et al. in + 2016. It serves as an inner block in a larger graph neural network + architecture. The message between two nodes connected by an edge is computed by applying a multi-layer perceptron (MLP) to the concatenation of the sender and @@ -29,13 +29,14 @@ class InteractionNetworkBlock(MessagePassing): *Interaction Networks for Learning about Objects, Relations and Physics*. In Advances in Neural Information Processing Systems (NeurIPS 2016). - DOI: `_`. + DOI: ``_. """ def __init__( self, node_feature_dim, - hidden_dim, + edge_feature_dim=0, + hidden_dim=64, n_message_layers=2, n_update_layers=2, activation=torch.nn.SiLU, @@ -47,7 +48,11 @@ def __init__( Initialization of the :class:`InteractionNetworkBlock` class. :param int node_feature_dim: The dimension of the node features. + :param int edge_feature_dim: The dimension of the edge features. + If edge_attr is not provided, it is assumed to be 0. + Default is 0. :param int hidden_dim: The dimension of the hidden features. + Default is 64. :param int n_message_layers: The number of layers in the message network. Default is 2. :param int n_update_layers: The number of layers in the update network. @@ -66,83 +71,55 @@ def __init__( flow means that messages are sent from the target node to the source node. See :class:`torch_geometric.nn.MessagePassing` for more details. Default is "source_to_target". - :raises ValueError: If `node_feature_dim` is not a positive integer. - :raises ValueError: If `hidden_dim` is not a positive integer. - :raises ValueError: If `n_message_layers` is not a positive integer. - :raises ValueError: If `n_update_layers` is not a positive integer. + :raises AssertionError: If `node_feature_dim` is not a positive integer. + :raises AssertionError: If `hidden_dim` is not a positive integer. + :raises AssertionError: If `n_message_layers` is not a positive integer. + :raises AssertionError: If `n_update_layers` is not a positive integer. + :raises AssertionError: If `edge_feature_dim` is not a non-negative + integer. """ super().__init__(aggr=aggr, node_dim=node_dim, flow=flow) - # Check consistency - check_consistency(node_feature_dim, int) - check_consistency(hidden_dim, int) - check_consistency(n_message_layers, int) - check_consistency(n_update_layers, int) - # Check values - if node_feature_dim <= 0: - raise ValueError( - "`node_feature_dim` must be a positive integer," - f" got {node_feature_dim}." - ) - - if hidden_dim <= 0: - raise ValueError( - "`hidden_dim` must be a positive integer," f" got {hidden_dim}." - ) - - if n_message_layers <= 0: - raise ValueError( - "`n_message_layers` must be a positive integer," - f" got {n_message_layers}." - ) - - if n_update_layers <= 0: - raise ValueError( - "`n_update_layers` must be a positive integer," - f" got {n_update_layers}." - ) - - # Initialize parameters - self.node_feature_dim = node_feature_dim - self.hidden_dim = hidden_dim - self.activation = activation + check_positive_integer(node_feature_dim, strict=True) + check_positive_integer(hidden_dim, strict=True) + check_positive_integer(n_message_layers, strict=True) + check_positive_integer(n_update_layers, strict=True) + check_positive_integer(edge_feature_dim, strict=False) # Message network self.message_net = FeedForward( - input_dimensions=2 * self.node_feature_dim, - output_dimensions=self.hidden_dim, - inner_size=self.hidden_dim, + input_dimensions=2 * node_feature_dim + edge_feature_dim, + output_dimensions=hidden_dim, + inner_size=hidden_dim, n_layers=n_message_layers, - func=self.activation, + func=activation, ) # Update network self.update_net = FeedForward( - input_dimensions=self.node_feature_dim + self.hidden_dim, - output_dimensions=self.hidden_dim, - inner_size=self.node_feature_dim, + input_dimensions=node_feature_dim + hidden_dim, + output_dimensions=node_feature_dim, + inner_size=hidden_dim, n_layers=n_update_layers, - func=self.activation, + func=activation, ) - def forward(self, x, edge_index, edge_attr): + def forward(self, x, edge_index, edge_attr=None): """ Forward pass of the block, triggering the message-passing routine. :param x: The node features. :type x: torch.Tensor | LabelTensor :param torch.Tensor edge_index: The edge indeces. - :param edge_attr: The edge attributes. + :param edge_attr: The edge attributes. Default is None. :type edge_attr: torch.Tensor | LabelTensor :return: The updated node features. :rtype: torch.Tensor """ - - # TODO: edge_attr is not used in the message function return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr) - def message(self, x_i, x_j): + def message(self, x_i, x_j, edge_attr): """ Compute the message to be passed between nodes and edges. @@ -153,7 +130,11 @@ def message(self, x_i, x_j): :return: The message to be passed. :rtype: torch.Tensor """ - return self.message_net(torch.cat((x_i, x_j), dim=-1)) + if edge_attr is None: + input_ = torch.cat((x_i, x_j), dim=-1) + else: + input_ = torch.cat((x_i, x_j, edge_attr), dim=-1) + return self.message_net(input_) def update(self, message, x): """ diff --git a/pina/model/block/message_passing/radial_field_network_block.py b/pina/model/block/message_passing/radial_field_network_block.py index 0d3257d48..ef621b10e 100644 --- a/pina/model/block/message_passing/radial_field_network_block.py +++ b/pina/model/block/message_passing/radial_field_network_block.py @@ -1,9 +1,10 @@ """Module for the Radial Field Network block.""" import torch -from ....model import FeedForward from torch_geometric.nn import MessagePassing -from ....utils import check_consistency +from torch_geometric.utils import remove_self_loops +from ....utils import check_positive_integer +from ....model import FeedForward class RadialFieldNetworkBlock(MessagePassing): @@ -11,33 +12,35 @@ class RadialFieldNetworkBlock(MessagePassing): Implementation of the Radial Field Network block. This block is used to perform message-passing between nodes and edges in a - graph neural network, following the scheme proposed by Köhler et al. (2020). - It serves as an inner block in a larger graph neural network architecture. + graph neural network, following the scheme proposed by Köhler et al. in + 2020. It serves as an inner block in a larger graph neural network + architecture. The message between two nodes connected by an edge is computed by applying a - linear transformation to the sender node features and the edge features, - followed by a non-linear activation function. Messages are then aggregated - using an aggregation scheme (e.g., sum, mean, min, max, or product). + linear transformation to the norm of the difference between the sender and + recipient node features, together with the radial distance between the + sender and recipient node features, followed by a non-linear activation + function. Messages are then aggregated using an aggregation scheme + (e.g., sum, mean, min, max, or product). The update step is performed by a simple addition of the incoming messages to the node features. .. seealso:: - **Original reference** Köhler, J., Klein, L., & Noé, F. (2020, November). - Equivariant flows: exact likelihood generative learning for symmetric densities. - In International conference on machine learning (pp. 5361-5370). PMLR. + **Original reference** Köhler, J., Klein, L., Noé, F. (2020). + *Equivariant Flows: Exact Likelihood Generative Learning for Symmetric + Densities*. + In International Conference on Machine Learning. + DOI: ``_. """ - - def __init__( self, node_feature_dim, - hidden_dim, - radial_hidden_dim=16, - n_radial_layers=2, - activation=torch.nn.ReLU, + hidden_dim=64, + n_layers=2, + activation=torch.nn.Tanh, aggr="add", node_dim=-2, flow="source_to_target", @@ -46,7 +49,9 @@ def __init__( Initialization of the :class:`RadialFieldNetworkBlock` class. :param int node_feature_dim: The dimension of the node features. - :param int edge_feature_dim: The dimension of the edge features. + :param int hidden_dim: The dimension of the hidden features. + Default is 64. + :param int n_layers: The number of layers in the network. Default is 2. :param torch.nn.Module activation: The activation function. Default is :class:`torch.nn.Tanh`. :param str aggr: The aggregation scheme to use for message passing. @@ -61,64 +66,52 @@ def __init__( flow means that messages are sent from the target node to the source node. See :class:`torch_geometric.nn.MessagePassing` for more details. Default is "source_to_target". - :raises ValueError: If `node_feature_dim` is not a positive integer. + :raises AssertionError: If `node_feature_dim` is not a positive integer. + :raises AssertionError: If `hidden_dim` is not a positive integer. + :raises AssertionError: If `n_layers` is not a positive integer. """ super().__init__(aggr=aggr, node_dim=node_dim, flow=flow) - # Check consistency - check_consistency(node_feature_dim, int) - # Check values - if node_feature_dim <= 0: - raise ValueError( - "`node_feature_dim` must be a positive integer," - f" got {node_feature_dim}." - ) - - # Initialize parameters - self.node_feature_dim = node_feature_dim - self.hidden_dim = hidden_dim - self.activation = activation + check_positive_integer(node_feature_dim, strict=True) + check_positive_integer(hidden_dim, strict=True) + check_positive_integer(n_layers, strict=True) # Layer for processing node features - self.radial_field = FeedForward( + self.radial_net = FeedForward( input_dimensions=1, output_dimensions=1, - inner_size=radial_hidden_dim, - n_layers=n_radial_layers, - func=self.activation, + inner_size=hidden_dim, + n_layers=n_layers, + func=activation, ) - def forward(self, x, edge_index): """ Forward pass of the block, triggering the message-passing routine. :param x: The node features. :type x: torch.Tensor | LabelTensor - :param torch.Tensor edge_index: The edge indices. In the original formulation, - the messages are aggregated from all nodes, not only from the neighbours. + :param torch.Tensor edge_index: The edge indices. :return: The updated node features. :rtype: torch.Tensor """ + edge_index, _ = remove_self_loops(edge_index) return self.propagate(edge_index=edge_index, x=x) - def message(self, x_j, x_i): + def message(self, x_i, x_j): """ Compute the message to be passed between nodes and edges. - :param x_j: Node features of the sender nodes. + :param x_i: The node features of the recipient nodes. + :type x_i: torch.Tensor | LabelTensor + :param x_j: The node features of the sender nodes. :type x_j: torch.Tensor | LabelTensor - :param edge_attr: The edge attributes. - :type edge_attr: torch.Tensor | LabelTensor :return: The message to be passed. :rtype: torch.Tensor """ - r = torch.norm(x_i-x_j) - - - return self.radial_field(r)*(x_i-x_j) - + r = x_i - x_j + return self.radial_net(torch.norm(r, dim=1, keepdim=True)) * r def update(self, message, x): """ diff --git a/pina/model/block/message_passing/schnet_block.py b/pina/model/block/message_passing/schnet_block.py index 7ee2b129c..94fe06364 100644 --- a/pina/model/block/message_passing/schnet_block.py +++ b/pina/model/block/message_passing/schnet_block.py @@ -1,9 +1,10 @@ """Module for the Schnet block.""" import torch -from ....model import FeedForward from torch_geometric.nn import MessagePassing -from ....utils import check_consistency +from torch_geometric.utils import remove_self_loops +from ....utils import check_positive_integer +from ....model import FeedForward class SchnetBlock(MessagePassing): @@ -11,36 +12,37 @@ class SchnetBlock(MessagePassing): Implementation of the Schnet block. This block is used to perform message-passing between nodes and edges in a - graph neural network, following the scheme proposed by Schütt et al. (2017). - It serves as an inner block in a larger graph neural network architecture. + graph neural network, following the scheme proposed by Schütt et al. in + 2017. It serves as an inner block in a larger graph neural network + architecture. - The message between two nodes connected by an edge is computed by applying a - linear transformation to the sender node features and the edge features, - followed by a non-linear activation function. Messages are then aggregated - using an aggregation scheme (e.g., sum, mean, min, max, or product). + The message between two nodes connected by an edge is computed as the + product of the output of a MLP applied to the norm of the distance of the + node positions, and of another MLP applied to the node features. Messages + are then aggregated using an aggregation scheme (e.g., sum, mean, min, max, + or product). - The update step is performed by a simple addition of the incoming messages - to the node features. + The update step is performed by applying another MLP to the concatenation of + the incoming messages and the node features. .. seealso:: - **Original reference** Schütt, K., Kindermans, P. J., Sauceda Felix, H. E., Chmiela, S., Tkatchenko, A., & Müller, K. R. (2017). - Schnet: A continuous-filter convolutional neural network for modeling quantum interactions. - Advances in neural information processing systems, 30. + **Original reference** Schütt, K., Kindermans, P. J., Sauceda Felix, + H. E., Chmiela, S., Tkatchenko, A., Müller, K. R. (2017). + *Schnet: A continuous-filter convolutional neural network for modeling + quantum interactions.* + Advances in Neural Information Processing Systems, 30. + DOI: ``_. """ - - def __init__( self, node_feature_dim, - node_pos_dim, - hidden_dim, - radial_hidden_dim=16, + hidden_dim=64, n_message_layers=2, n_update_layers=2, n_radial_layers=2, - activation=torch.nn.ReLU, + activation=torch.nn.SiLU, aggr="add", node_dim=-2, flow="source_to_target", @@ -49,9 +51,16 @@ def __init__( Initialization of the :class:`SchnetBlock` class. :param int node_feature_dim: The dimension of the node features. - :param int edge_feature_dim: The dimension of the edge features. + :param int hidden_dim: The dimension of the hidden features. + Default is 64. + :param int n_message_layers: The number of layers in the message + network. Default is 2. + :param int n_update_layers: The number of layers in the update network. + Default is 2. + :param int n_radial_layers: The number of layers in the radial field + network. Default is 2. :param torch.nn.Module activation: The activation function. - Default is :class:`torch.nn.Tanh`. + Default is :class:`torch.nn.SiLU`. :param str aggr: The aggregation scheme to use for message passing. Available options are "add", "mean", "min", "max", "mul". See :class:`torch_geometric.nn.MessagePassing` for more details. @@ -64,53 +73,47 @@ def __init__( flow means that messages are sent from the target node to the source node. See :class:`torch_geometric.nn.MessagePassing` for more details. Default is "source_to_target". - :raises ValueError: If `node_feature_dim` is not a positive integer. - :raises ValueError: If `edge_feature_dim` is not a positive integer. + :raises AssertionError: If `node_feature_dim` is not a positive integer. + :raises AssertionError: If `hidden_dim` is not a positive integer. + :raises AssertionError: If `n_message_layers` is not a positive integer. + :raises AssertionError: If `n_update_layers` is not a positive integer. + :raises AssertionError: If `n_radial_layers` is not a positive integer. """ super().__init__(aggr=aggr, node_dim=node_dim, flow=flow) - # Check consistency - check_consistency(node_feature_dim, int) - # Check values - if node_feature_dim <= 0: - raise ValueError( - "`node_feature_dim` must be a positive integer," - f" got {node_feature_dim}." - ) - - - # Initialize parameters - self.node_feature_dim = node_feature_dim - self.node_pos_dim = node_pos_dim - self.hidden_dim = hidden_dim - self.activation = activation - - # Layer for processing node features - self.radial_field = FeedForward( + check_positive_integer(node_feature_dim, strict=True) + check_positive_integer(hidden_dim, strict=True) + check_positive_integer(n_message_layers, strict=True) + check_positive_integer(n_update_layers, strict=True) + check_positive_integer(n_radial_layers, strict=True) + + # Layer for processing node distances + self.radial_net = FeedForward( input_dimensions=1, output_dimensions=1, - inner_size=radial_hidden_dim, + inner_size=hidden_dim, n_layers=n_radial_layers, - func=self.activation, - ) - - self.update_net = FeedForward( - input_dimensions=self.node_pos_dim + self.hidden_dim, - output_dimensions=self.hidden_dim, - inner_size=self.hidden_dim, - n_layers=n_update_layers, - func=self.activation, + func=activation, ) + # Layer for computing the message self.message_net = FeedForward( - input_dimensions=self.node_feature_dim, - output_dimensions=self.node_pos_dim + self.hidden_dim, - inner_size=self.hidden_dim, + input_dimensions=node_feature_dim, + output_dimensions=node_feature_dim, + inner_size=hidden_dim, n_layers=n_message_layers, - func=self.activation, + func=activation, ) + # Layer for updating the node features + self.update_net = FeedForward( + input_dimensions=2 * node_feature_dim, + output_dimensions=node_feature_dim, + inner_size=hidden_dim, + n_layers=n_update_layers, + func=activation, + ) def forward(self, x, pos, edge_index): """ @@ -118,36 +121,38 @@ def forward(self, x, pos, edge_index): :param x: The node features. :type x: torch.Tensor | LabelTensor - :param torch.Tensor edge_index: The edge indices. In the original formulation, - the messages are aggregated from all nodes, not only from the neighbours. + :param torch.Tensor edge_index: The edge indices. :return: The updated node features. :rtype: torch.Tensor """ + edge_index, _ = remove_self_loops(edge_index) return self.propagate(edge_index=edge_index, x=x, pos=pos) - def message(self, x_i, pos_i ,pos_j): + def message(self, x_i, pos_i, pos_j): """ Compute the message to be passed between nodes and edges. - :param x_j: Node features of the sender nodes. - :type x_j: torch.Tensor | LabelTensor - :param edge_attr: The edge attributes. - :type edge_attr: torch.Tensor | LabelTensor + :param x_i: Node features of the sender nodes. + :type x_i: torch.Tensor | LabelTensor + :param pos_i: The node coordinates of the recipient nodes. + :type pos_i: torch.Tensor | LabelTensor + :param pos_j: The node coordinates of the sender nodes. + :type pos_j: torch.Tensor | LabelTensor :return: The message to be passed. :rtype: torch.Tensor - """ - - return self.radial_field(torch.norm(pos_i-pos_j))*self.message_net(x_i) - + """ + rad = self.radial_net(torch.norm(pos_i - pos_j, dim=-1, keepdim=True)) + msg = self.message_net(x_i) + return rad * msg - def update(self, message, pos): + def update(self, message, x): """ Update the node features with the received messages. :param torch.Tensor message: The message to be passed. :param x: The node features. :type x: torch.Tensor | LabelTensor - :return: The concatenation of the update position features and the updated node features. + :return: The updated node features. :rtype: torch.Tensor """ - return self.update_net(torch.cat((pos, message), dim=-1)) + return self.update_net(torch.cat((x, message), dim=-1)) diff --git a/pina/utils.py b/pina/utils.py index e3126de45..569ba632c 100644 --- a/pina/utils.py +++ b/pina/utils.py @@ -193,3 +193,22 @@ def chebyshev_roots(n): k = torch.arange(n) nodes = torch.sort(torch.cos(pi * (k + 0.5) / n))[0] return nodes + + +def check_positive_integer(value, strict=True): + """ + Check if the value is a positive integer. + + :param int value: The value to check. + :param bool strict: If True, the value must be strictly positive. + Default is True. + :raises AssertionError: If the value is not a positive integer. + """ + if strict: + assert ( + isinstance(value, int) and value > 0 + ), f"Expected a strictly positive integer, got {value}." + else: + assert ( + isinstance(value, int) and value >= 0 + ), f"Expected a non-negative integer, got {value}." diff --git a/tests/test_messagepassing/test_deep_tensor_network_block.py b/tests/test_messagepassing/test_deep_tensor_network_block.py new file mode 100644 index 000000000..aa295d2db --- /dev/null +++ b/tests/test_messagepassing/test_deep_tensor_network_block.py @@ -0,0 +1,59 @@ +import pytest +import torch +from pina.model.block.message_passing import DeepTensorNetworkBlock + +# Data for testing +x = torch.rand(10, 3) +edge_index = torch.randint(0, 10, (2, 20)) +edge_attr = torch.randn(20, 2) + + +@pytest.mark.parametrize("node_feature_dim", [1, 3]) +@pytest.mark.parametrize("edge_feature_dim", [3, 5]) +def test_constructor(node_feature_dim, edge_feature_dim): + + DeepTensorNetworkBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + ) + + # Should fail if node_feature_dim is negative + with pytest.raises(AssertionError): + DeepTensorNetworkBlock( + node_feature_dim=-1, edge_feature_dim=edge_feature_dim + ) + + # Should fail if edge_feature_dim is negative + with pytest.raises(AssertionError): + DeepTensorNetworkBlock( + node_feature_dim=node_feature_dim, edge_feature_dim=-1 + ) + + +def test_forward(): + + model = DeepTensorNetworkBlock( + node_feature_dim=x.shape[1], + edge_feature_dim=edge_attr.shape[1], + ) + + output_ = model(edge_index=edge_index, x=x, edge_attr=edge_attr) + assert output_.shape == x.shape + + +def test_backward(): + + model = DeepTensorNetworkBlock( + node_feature_dim=x.shape[1], + edge_feature_dim=edge_attr.shape[1], + ) + + output_ = model( + edge_index=edge_index, + x=x.requires_grad_(), + edge_attr=edge_attr.requires_grad_(), + ) + + loss = torch.mean(output_) + loss.backward() + assert x.grad.shape == x.shape diff --git a/tests/test_messagepassing/test_equivariant_network_block.py b/tests/test_messagepassing/test_equivariant_network_block.py new file mode 100644 index 000000000..c69d3a0ed --- /dev/null +++ b/tests/test_messagepassing/test_equivariant_network_block.py @@ -0,0 +1,130 @@ +import pytest +import torch +from pina.model.block.message_passing import EnEquivariantNetworkBlock + +# Data for testing +x = torch.rand(10, 4) +pos = torch.rand(10, 3) +edge_index = torch.randint(0, 10, (2, 20)) +edge_attr = torch.randn(20, 2) + + +@pytest.mark.parametrize("node_feature_dim", [1, 3]) +@pytest.mark.parametrize("edge_feature_dim", [0, 2]) +@pytest.mark.parametrize("pos_dim", [2, 3]) +def test_constructor(node_feature_dim, edge_feature_dim, pos_dim): + + EnEquivariantNetworkBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + pos_dim=pos_dim, + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + ) + + # Should fail if node_feature_dim is negative + with pytest.raises(AssertionError): + EnEquivariantNetworkBlock( + node_feature_dim=-1, + edge_feature_dim=edge_feature_dim, + pos_dim=pos_dim, + ) + + # Should fail if edge_feature_dim is negative + with pytest.raises(AssertionError): + EnEquivariantNetworkBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=-1, + pos_dim=pos_dim, + ) + + # Should fail if pos_dim is negative + with pytest.raises(AssertionError): + EnEquivariantNetworkBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + pos_dim=-1, + ) + + # Should fail if hidden_dim is negative + with pytest.raises(AssertionError): + EnEquivariantNetworkBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + pos_dim=pos_dim, + hidden_dim=-1, + ) + + # Should fail if n_message_layers is negative + with pytest.raises(AssertionError): + EnEquivariantNetworkBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + pos_dim=pos_dim, + n_message_layers=-1, + ) + + # Should fail if n_update_layers is negative + with pytest.raises(AssertionError): + EnEquivariantNetworkBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + pos_dim=pos_dim, + n_update_layers=-1, + ) + + +@pytest.mark.parametrize("edge_feature_dim", [0, 2]) +def test_forward(edge_feature_dim): + + model = EnEquivariantNetworkBlock( + node_feature_dim=x.shape[1], + edge_feature_dim=edge_feature_dim, + pos_dim=pos.shape[1], + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + ) + + if edge_feature_dim == 0: + output_ = model(edge_index=edge_index, x=x, pos=pos) + else: + output_ = model( + edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr + ) + + assert output_[0].shape == x.shape + assert output_[1].shape == pos.shape + + +@pytest.mark.parametrize("edge_feature_dim", [0, 2]) +def test_backward(edge_feature_dim): + + model = EnEquivariantNetworkBlock( + node_feature_dim=x.shape[1], + edge_feature_dim=edge_feature_dim, + pos_dim=pos.shape[1], + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + ) + + if edge_feature_dim == 0: + output_ = model( + edge_index=edge_index, + x=x.requires_grad_(), + pos=pos.requires_grad_(), + ) + else: + output_ = model( + edge_index=edge_index, + x=x.requires_grad_(), + pos=pos.requires_grad_(), + edge_attr=edge_attr.requires_grad_(), + ) + + loss = torch.mean(output_[0]) + loss.backward() + assert x.grad.shape == x.shape + assert pos.grad.shape == pos.shape diff --git a/tests/test_messagepassing/test_interaction_network_block.py b/tests/test_messagepassing/test_interaction_network_block.py new file mode 100644 index 000000000..d121fb173 --- /dev/null +++ b/tests/test_messagepassing/test_interaction_network_block.py @@ -0,0 +1,84 @@ +import pytest +import torch +from pina.model.block.message_passing import InteractionNetworkBlock + +# Data for testing +x = torch.rand(10, 3) +edge_index = torch.randint(0, 10, (2, 20)) +edge_attr = torch.randn(20, 2) + + +@pytest.mark.parametrize("node_feature_dim", [1, 3]) +@pytest.mark.parametrize("edge_feature_dim", [0, 2]) +def test_constructor(node_feature_dim, edge_feature_dim): + + InteractionNetworkBlock( + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + ) + + # Should fail if node_feature_dim is negative + with pytest.raises(AssertionError): + InteractionNetworkBlock(node_feature_dim=-1) + + # Should fail if edge_feature_dim is negative + with pytest.raises(AssertionError): + InteractionNetworkBlock(node_feature_dim=3, edge_feature_dim=-1) + + # Should fail if hidden_dim is negative + with pytest.raises(AssertionError): + InteractionNetworkBlock(node_feature_dim=3, hidden_dim=-1) + + # Should fail if n_message_layers is negative + with pytest.raises(AssertionError): + InteractionNetworkBlock(node_feature_dim=3, n_message_layers=-1) + + # Should fail if n_update_layers is negative + with pytest.raises(AssertionError): + InteractionNetworkBlock(node_feature_dim=3, n_update_layers=-1) + + +@pytest.mark.parametrize("edge_feature_dim", [0, 2]) +def test_forward(edge_feature_dim): + + model = InteractionNetworkBlock( + node_feature_dim=x.shape[1], + edge_feature_dim=edge_feature_dim, + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + ) + + if edge_feature_dim == 0: + output_ = model(edge_index=edge_index, x=x) + else: + output_ = model(edge_index=edge_index, x=x, edge_attr=edge_attr) + assert output_.shape == x.shape + + +@pytest.mark.parametrize("edge_feature_dim", [0, 2]) +def test_backward(edge_feature_dim): + + model = InteractionNetworkBlock( + node_feature_dim=x.shape[1], + edge_feature_dim=edge_feature_dim, + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + ) + + if edge_feature_dim == 0: + output_ = model(edge_index=edge_index, x=x.requires_grad_()) + else: + output_ = model( + edge_index=edge_index, + x=x.requires_grad_(), + edge_attr=edge_attr.requires_grad_(), + ) + + loss = torch.mean(output_) + loss.backward() + assert x.grad.shape == x.shape diff --git a/tests/test_messagepassing/test_radial_field_network_block.py b/tests/test_messagepassing/test_radial_field_network_block.py new file mode 100644 index 000000000..97c6cb797 --- /dev/null +++ b/tests/test_messagepassing/test_radial_field_network_block.py @@ -0,0 +1,67 @@ +import pytest +import torch +from pina.model.block.message_passing import RadialFieldNetworkBlock + +# Data for testing +x = torch.rand(10, 3) +edge_index = torch.randint(0, 10, (2, 20)) + + +@pytest.mark.parametrize("node_feature_dim", [1, 3]) +def test_constructor(node_feature_dim): + + RadialFieldNetworkBlock( + node_feature_dim=node_feature_dim, + hidden_dim=64, + n_layers=2, + ) + + # Should fail if node_feature_dim is negative + with pytest.raises(AssertionError): + RadialFieldNetworkBlock( + node_feature_dim=-1, + hidden_dim=64, + n_layers=2, + ) + + # Should fail if hidden_dim is negative + with pytest.raises(AssertionError): + RadialFieldNetworkBlock( + node_feature_dim=node_feature_dim, + hidden_dim=-1, + n_layers=2, + ) + + # Should fail if n_layers is negative + with pytest.raises(AssertionError): + RadialFieldNetworkBlock( + node_feature_dim=node_feature_dim, + hidden_dim=64, + n_layers=-1, + ) + + +def test_forward(): + + model = RadialFieldNetworkBlock( + node_feature_dim=x.shape[1], + hidden_dim=64, + n_layers=2, + ) + + output_ = model(edge_index=edge_index, x=x) + assert output_.shape == x.shape + + +def test_backward(): + + model = RadialFieldNetworkBlock( + node_feature_dim=x.shape[1], + hidden_dim=64, + n_layers=2, + ) + + output_ = model(edge_index=edge_index, x=x.requires_grad_()) + loss = torch.mean(output_) + loss.backward() + assert x.grad.shape == x.shape diff --git a/tests/test_messagepassing/test_schnet_block.py b/tests/test_messagepassing/test_schnet_block.py new file mode 100644 index 000000000..ddd84f97f --- /dev/null +++ b/tests/test_messagepassing/test_schnet_block.py @@ -0,0 +1,73 @@ +import pytest +import torch +from pina.model.block.message_passing import SchnetBlock + +# Data for testing +x = torch.rand(10, 4) +pos = torch.rand(10, 3) +edge_index = torch.randint(0, 10, (2, 20)) + + +@pytest.mark.parametrize("node_feature_dim", [1, 3]) +def test_constructor(node_feature_dim): + + SchnetBlock( + node_feature_dim=node_feature_dim, + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + n_radial_layers=2, + ) + + # Should fail if node_feature_dim is negative + with pytest.raises(AssertionError): + SchnetBlock(node_feature_dim=-1) + + # Should fail if hidden_dim is negative + with pytest.raises(AssertionError): + SchnetBlock(node_feature_dim=node_feature_dim, hidden_dim=-1) + + # Should fail if n_message_layers is negative + with pytest.raises(AssertionError): + SchnetBlock(node_feature_dim=node_feature_dim, n_message_layers=-1) + + # Should fail if n_update_layers is negative + with pytest.raises(AssertionError): + SchnetBlock(node_feature_dim=node_feature_dim, n_update_layers=-1) + + # Should fail if n_radial_layers is negative + with pytest.raises(AssertionError): + SchnetBlock(node_feature_dim=node_feature_dim, n_radial_layers=-1) + + +def test_forward(): + + model = SchnetBlock( + node_feature_dim=x.shape[1], + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + n_radial_layers=2, + ) + + output_ = model(edge_index=edge_index, x=x, pos=pos) + assert output_.shape == x.shape + + +def test_backward(): + + model = SchnetBlock( + node_feature_dim=x.shape[1], + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + n_radial_layers=2, + ) + + output_ = model( + edge_index=edge_index, x=x.requires_grad_(), pos=pos.requires_grad_() + ) + + loss = torch.mean(output_) + loss.backward() + assert x.grad.shape == x.shape diff --git a/tests/test_utils.py b/tests/test_utils.py index a641c3838..7e8518995 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,12 +1,9 @@ import torch +import pytest -from pina.utils import merge_tensors -from pina.label_tensor import LabelTensor from pina import LabelTensor -from pina.domain import EllipsoidDomain, CartesianDomain -from pina.utils import check_consistency -import pytest -from pina.domain import DomainInterface +from pina.utils import merge_tensors, check_consistency, check_positive_integer +from pina.domain import EllipsoidDomain, CartesianDomain, DomainInterface def test_merge_tensors(): @@ -50,3 +47,24 @@ def test_check_consistency_incorrect(): check_consistency(torch.Tensor, DomainInterface, subclass=True) with pytest.raises(ValueError): check_consistency(ellipsoid1, torch.Tensor) + + +@pytest.mark.parametrize("value", [0, 1, 2, 3, 10]) +@pytest.mark.parametrize("strict", [True, False]) +def test_check_positive_integer(value, strict): + if value != 0: + check_positive_integer(value, strict=strict) + else: + check_positive_integer(value, strict=False) + + # Should fail if value is negative + with pytest.raises(AssertionError): + check_positive_integer(-1, strict=strict) + + # Should fail if value is not an integer + with pytest.raises(AssertionError): + check_positive_integer(1.5, strict=strict) + + # Should fail if value is not a number + with pytest.raises(AssertionError): + check_positive_integer("string", strict=strict) From fbc0382cc50d211dd3c1bcaa9406b0fd49794902 Mon Sep 17 00:00:00 2001 From: giovanni Date: Sun, 1 Jun 2025 10:26:05 +0200 Subject: [PATCH 8/8] fix egnn + equivariance/invariance tests Co-authored-by: Dario Coscia --- .../en_equivariant_network_block.py | 77 ++++++++++++++++--- .../test_equivariant_network_block.py | 35 +++++++++ .../test_radial_field_network_block.py | 25 ++++++ .../test_messagepassing/test_schnet_block.py | 22 ++++++ 4 files changed, 147 insertions(+), 12 deletions(-) diff --git a/pina/model/block/message_passing/en_equivariant_network_block.py b/pina/model/block/message_passing/en_equivariant_network_block.py index fa256e9d5..904c1c6c9 100644 --- a/pina/model/block/message_passing/en_equivariant_network_block.py +++ b/pina/model/block/message_passing/en_equivariant_network_block.py @@ -10,7 +10,6 @@ class EnEquivariantNetworkBlock(MessagePassing): """ Implementation of the E(n) Equivariant Graph Neural Network block. - This block is used to perform message-passing between nodes and edges in a graph neural network, following the scheme proposed by Satorras et al. in 2021. It serves as an inner block in a larger graph neural network @@ -102,7 +101,7 @@ def __init__( ) # Layer for updating the node features - self.update_net = FeedForward( + self.update_feat_net = FeedForward( input_dimensions=node_feature_dim + pos_dim, output_dimensions=node_feature_dim, inner_size=hidden_dim, @@ -110,6 +109,16 @@ def __init__( func=activation, ) + # Layer for updating the node positions + # The output dimension is set to 1 for equivariant updates + self.update_pos_net = FeedForward( + input_dimensions=pos_dim, + output_dimensions=1, + inner_size=hidden_dim, + n_layers=n_update_layers, + func=activation, + ) + def forward(self, x, pos, edge_index, edge_attr=None): """ Forward pass of the block, triggering the message-passing routine. @@ -143,22 +152,62 @@ def message(self, x_i, x_j, pos_i, pos_j, edge_attr): :param edge_attr: The edge attributes. :type edge_attr: torch.Tensor | LabelTensor :return: The message to be passed. - :rtype: torch.Tensor + :rtype: tuple(torch.Tensor, torch.Tensor) """ - dist = torch.norm(pos_i - pos_j, dim=-1, keepdim=True) ** 2 + # Compute the euclidean distance between the sender and recipient nodes + diff = pos_i - pos_j + dist = torch.norm(diff, dim=-1, keepdim=True) ** 2 + + # Compute the message input if edge_attr is None: input_ = torch.cat((x_i, x_j, dist), dim=-1) else: input_ = torch.cat((x_i, x_j, dist, edge_attr), dim=-1) - return self.message_net(input_) + # Compute the messages and their equivariant counterpart + m_ij = self.message_net(input_) + message = diff * self.update_pos_net(m_ij) + + return message, m_ij - def update(self, message, x, pos, edge_index): + def aggregate(self, inputs, index, ptr=None, dim_size=None): + """ + Aggregate the messages at the nodes during message passing. + + This method receives a tuple of tensors corresponding to the messages + to be aggregated. Both messages are aggregated separately according to + the specified aggregation scheme. + + :param tuple(torch.Tensor) inputs: Tuple containing two messages to + aggregate. + :param index: The indices of target nodes for each message. This tensor + specifies which node each message is aggregated into. + :type index: torch.Tensor | LabelTensor + :param ptr: Optional tensor to specify the slices of messages for each + node (used in some aggregation strategies). Default is None. + :type ptr: torch.Tensor | LabelTensor + :param int dim_size: Optional size of the output dimension, i.e., + number of nodes. Default is None. + :return: Tuple of aggregated tensors corresponding to (aggregated + messages for position updates, aggregated messages for feature + updates). + :rtype: tuple(torch.Tensor, torch.Tensor) + """ + # Unpack the messages from the inputs + message, m_ij = inputs + + # Aggregate messages as usual using self.aggr method + agg_message = super().aggregate(message, index, ptr, dim_size) + agg_m_ij = super().aggregate(m_ij, index, ptr, dim_size) + + return agg_message, agg_m_ij + + def update(self, aggregated_inputs, x, pos, edge_index): """ Update the node features and the node coordinates with the received messages. - :param torch.Tensor message: The message to be passed. + :param tuple(torch.Tensor) aggregated_inputs: The messages to be passed. :param x: The node features. :type x: torch.Tensor | LabelTensor :param pos: The euclidean coordinates of the nodes. @@ -167,10 +216,14 @@ def update(self, message, x, pos, edge_index): :return: The updated node features and node positions. :rtype: tuple(torch.Tensor, torch.Tensor) """ - # Update the node features - x = self.update_net(torch.cat((x, message), dim=-1)) + # aggregated_inputs is tuple (agg_message, agg_m_ij) + agg_message, agg_m_ij = aggregated_inputs + + # Update node features with aggregated m_ij + x = self.update_feat_net(torch.cat((x, agg_m_ij), dim=-1)) + + # Degree for normalization of position updates + c = degree(edge_index[1], pos.shape[0]).unsqueeze(-1).clamp(min=1) + pos = pos + agg_message / c - # Update the node positions - c = degree(edge_index[0], pos.shape[0]).unsqueeze(-1) - pos = pos + message / c return x, pos diff --git a/tests/test_messagepassing/test_equivariant_network_block.py b/tests/test_messagepassing/test_equivariant_network_block.py index c69d3a0ed..eea000a0e 100644 --- a/tests/test_messagepassing/test_equivariant_network_block.py +++ b/tests/test_messagepassing/test_equivariant_network_block.py @@ -128,3 +128,38 @@ def test_backward(edge_feature_dim): loss.backward() assert x.grad.shape == x.shape assert pos.grad.shape == pos.shape + + +def test_equivariance(): + + # Graph to be fully connected and undirected + edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T + edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1) + + # Random rotation (det(rotation) should be 1) + rotation = torch.linalg.qr(torch.rand(pos.shape[-1], pos.shape[-1])).Q + if torch.det(rotation) < 0: + rotation[:, 0] *= -1 + + # Random translation + translation = torch.rand(1, pos.shape[-1]) + + model = EnEquivariantNetworkBlock( + node_feature_dim=x.shape[1], + edge_feature_dim=0, + pos_dim=pos.shape[1], + hidden_dim=64, + n_message_layers=2, + n_update_layers=2, + ).eval() + + h1, pos1 = model(edge_index=edge_index, x=x, pos=pos) + h2, pos2 = model( + edge_index=edge_index, x=x, pos=pos @ rotation.T + translation + ) + + # Transform model output + pos1_transformed = (pos1 @ rotation.T) + translation + + assert torch.allclose(pos2, pos1_transformed, atol=1e-5) + assert torch.allclose(h1, h2, atol=1e-5) diff --git a/tests/test_messagepassing/test_radial_field_network_block.py b/tests/test_messagepassing/test_radial_field_network_block.py index 97c6cb797..4632ebfc9 100644 --- a/tests/test_messagepassing/test_radial_field_network_block.py +++ b/tests/test_messagepassing/test_radial_field_network_block.py @@ -65,3 +65,28 @@ def test_backward(): loss = torch.mean(output_) loss.backward() assert x.grad.shape == x.shape + + +def test_equivariance(): + + # Graph to be fully connected and undirected + edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T + edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1) + + # Random rotation (det(rotation) should be 1) + rotation = torch.linalg.qr(torch.rand(x.shape[-1], x.shape[-1])).Q + if torch.det(rotation) < 0: + rotation[:, 0] *= -1 + + # Random translation + translation = torch.rand(1, x.shape[-1]) + + model = RadialFieldNetworkBlock(node_feature_dim=x.shape[1]).eval() + + pos1 = model(edge_index=edge_index, x=x) + pos2 = model(edge_index=edge_index, x=x @ rotation.T + translation) + + # Transform model output + pos1_transformed = (pos1 @ rotation.T) + translation + + assert torch.allclose(pos2, pos1_transformed, atol=1e-5) diff --git a/tests/test_messagepassing/test_schnet_block.py b/tests/test_messagepassing/test_schnet_block.py index ddd84f97f..51073b0f3 100644 --- a/tests/test_messagepassing/test_schnet_block.py +++ b/tests/test_messagepassing/test_schnet_block.py @@ -71,3 +71,25 @@ def test_backward(): loss = torch.mean(output_) loss.backward() assert x.grad.shape == x.shape + + +def test_invariance(): + + # Graph to be fully connected and undirected + edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T + edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1) + + # Random rotation (det(rotation) should be 1) + rotation = torch.linalg.qr(torch.rand(pos.shape[-1], pos.shape[-1])).Q + if torch.det(rotation) < 0: + rotation[:, 0] *= -1 + + # Random translation + translation = torch.rand(1, pos.shape[-1]) + + model = SchnetBlock(node_feature_dim=x.shape[1]).eval() + + out1 = model(edge_index=edge_index, x=x, pos=pos) + out2 = model(edge_index=edge_index, x=x, pos=pos @ rotation.T + translation) + + assert torch.allclose(out1, out2, atol=1e-5)