|
7 | 7 |
|
8 | 8 | class EnEquivariantGraphBlock(MessagePassing): |
9 | 9 | """ |
10 | | - TODO |
| 10 | + Implementation of the E(n) Equivariant Graph Neural Network block. |
| 11 | +
|
| 12 | + This block is used to perform message-passing between nodes and edges in a |
| 13 | + graph neural network, following the scheme proposed by Satorras et al. (2021). |
| 14 | + It serves as an inner block in a larger graph neural network architecture. |
| 15 | +
|
| 16 | + The message between two nodes connected by an edge is computed by applying a |
| 17 | + linear transformation to the sender node features and the edge features, |
| 18 | + followed by a non-linear activation function. Messages are then aggregated |
| 19 | + using an aggregation scheme (e.g., sum, mean, min, max, or product). |
| 20 | +
|
| 21 | + The update step is performed by a simple addition of the incoming messages |
| 22 | + to the node features. |
| 23 | +
|
| 24 | + .. seealso:: |
| 25 | +
|
| 26 | + **Original reference** Satorras, V. G., Hoogeboom, E., & Welling, M. (2021, July). |
| 27 | + E (n) equivariant graph neural networks. |
| 28 | + In International conference on machine learning (pp. 9323-9332). PMLR. |
11 | 29 | """ |
12 | 30 |
|
13 | 31 | def __init__( |
14 | 32 | self, |
15 | | - channels_h, |
| 33 | + channels_x, |
16 | 34 | channels_m, |
17 | 35 | channels_a, |
18 | 36 | aggr: str = "add", |
19 | 37 | hidden_channels: int = 64, |
20 | 38 | **kwargs, |
21 | 39 | ): |
22 | 40 | """ |
23 | | - TODO |
| 41 | + Initialization of the :class:`EnEquivariantGraphBlock` class. |
| 42 | +
|
| 43 | + :param int channels_x: The dimension of the node features. |
| 44 | + :param int channels_m: The dimension of the Euclidean coordinates (should be =3). |
| 45 | + :param int channels_a: The dimension of the edge features. |
| 46 | + :param str aggr: The aggregation scheme to use for message passing. |
| 47 | + Available options are "add", "mean", "min", "max", "mul". |
| 48 | + See :class:`torch_geometric.nn.MessagePassing` for more details. |
| 49 | + Default is "add". |
| 50 | + :param int hidden_channels_dim: The hidden dimension in each MLPs initialized in the block. |
24 | 51 | """ |
25 | 52 | super().__init__(aggr=aggr, **kwargs) |
26 | 53 |
|
27 | 54 | self.phi_e = torch.nn.Sequential( |
28 | | - torch.nn.Linear(2 * channels_h + 1 + channels_a, hidden_channels), |
| 55 | + torch.nn.Linear(2 * channels_x + 1 + channels_a, hidden_channels), |
29 | 56 | torch.nn.LayerNorm(hidden_channels), |
30 | 57 | torch.nn.SiLU(), |
31 | 58 | torch.nn.Linear(hidden_channels, channels_m), |
32 | 59 | torch.nn.LayerNorm(channels_m), |
33 | 60 | torch.nn.SiLU(), |
34 | 61 | ) |
35 | | - self.phi_x = torch.nn.Sequential( |
| 62 | + self.phi_pos = torch.nn.Sequential( |
36 | 63 | torch.nn.Linear(channels_m, hidden_channels), |
37 | 64 | torch.nn.LayerNorm(hidden_channels), |
38 | 65 | torch.nn.SiLU(), |
39 | 66 | torch.nn.Linear(hidden_channels, 1), |
40 | 67 | ) |
41 | | - self.phi_h = torch.nn.Sequential( |
42 | | - torch.nn.Linear(channels_h + channels_m, hidden_channels), |
| 68 | + self.phi_x = torch.nn.Sequential( |
| 69 | + torch.nn.Linear(channels_x + channels_m, hidden_channels), |
43 | 70 | torch.nn.LayerNorm(hidden_channels), |
44 | 71 | torch.nn.SiLU(), |
45 | | - torch.nn.Linear(hidden_channels, channels_h), |
| 72 | + torch.nn.Linear(hidden_channels, channels_x), |
46 | 73 | ) |
47 | 74 |
|
48 | | - def forward(self, x, h, edge_attr, edge_index, c=None): |
| 75 | + def forward(self, x, pos, edge_attr, edge_index, c=None): |
49 | 76 | """ |
50 | | - TODO |
| 77 | + Forward pass of the block, triggering the message-passing routine. |
| 78 | +
|
| 79 | + :param x: The node features. |
| 80 | + :type x: torch.Tensor | LabelTensor |
| 81 | + :param pos_i: 3D Euclidean coordinates. |
| 82 | + :type pos_i: torch.Tensor | LabelTensor |
| 83 | + :param torch.Tensor edge_index: The edge indices. In the original formulation, |
| 84 | + the messages are aggregated from all nodes, not only from the neighbours. |
| 85 | + :return: The updated node features. |
| 86 | + :rtype: torch.Tensor |
51 | 87 | """ |
52 | 88 | if c is None: |
53 | | - c = degree(edge_index[0], x.shape[0]).unsqueeze(-1) |
| 89 | + c = degree(edge_index[0], pos.shape[0]).unsqueeze(-1) |
54 | 90 | return self.propagate( |
55 | | - edge_index=edge_index, x=x, h=h, edge_attr=edge_attr, c=c |
| 91 | + edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr, c=c |
56 | 92 | ) |
57 | 93 |
|
58 | | - def message(self, x_i, x_j, h_i, h_j, edge_attr): |
| 94 | + def message(self, x_i, x_j, pos_i, pos_j, edge_attr): |
59 | 95 | """ |
60 | | - TODO |
| 96 | + Compute the message to be passed between nodes and edges. |
| 97 | +
|
| 98 | + :param x_i: Node features of the sender nodes. |
| 99 | + :type x_i: torch.Tensor | LabelTensor |
| 100 | + :param pos_i: 3D Euclidean coordinates of the sender nodes. |
| 101 | + :type pos_i: torch.Tensor | LabelTensor |
| 102 | + :param edge_attr: The edge attributes. |
| 103 | + :type edge_attr: torch.Tensor | LabelTensor |
| 104 | + :return: The message to be passed. |
| 105 | + :rtype: torch.Tensor |
61 | 106 | """ |
62 | | - mh_ij = self.phi_e( |
| 107 | + mpos_ij = self.phi_e( |
63 | 108 | torch.cat( |
64 | 109 | [ |
65 | | - h_i, |
66 | | - h_j, |
67 | | - torch.norm(x_i - x_j, dim=-1, keepdim=True) ** 2, |
| 110 | + x_i, |
| 111 | + x_j, |
| 112 | + torch.norm(pos_i - pos_j, dim=-1, keepdim=True) ** 2, |
68 | 113 | edge_attr, |
69 | 114 | ], |
70 | 115 | dim=-1, |
71 | 116 | ) |
72 | 117 | ) |
73 | | - mx_ij = (x_i - x_j) * self.phi_x(mh_ij) |
74 | | - return torch.cat((mx_ij, mh_ij), dim=-1) |
| 118 | + mpos_ij = (pos_i - pos_j) * self.phi_pos(mpos_ij) |
| 119 | + return mpos_ij |
75 | 120 |
|
76 | | - def update(self, aggr_out, x, h, edge_attr, c): |
| 121 | + def update(self, message, x, pos, c): |
77 | 122 | """ |
78 | | - TODO |
79 | | - """ |
80 | | - m_x, m_h = aggr_out[:, : self.m_len], aggr_out[:, self.m_len :] |
81 | | - h_l1 = self.phi_h(torch.cat([h, m_h], dim=-1)) |
82 | | - x_l1 = x + (m_x / c) |
83 | | - return x_l1, h_l1 |
| 123 | + Update the node features with the received messages. |
84 | 124 |
|
85 | | - @property |
86 | | - def edge_function(self): |
87 | | - """ |
88 | | - TODO |
89 | | - """ |
90 | | - return self._edge_function |
91 | | - |
92 | | - @property |
93 | | - def attribute_function(self): |
94 | | - """ |
95 | | - TODO |
| 125 | + :param torch.Tensor message: The message to be passed. |
| 126 | + :param x: The node features. |
| 127 | + :type x: torch.Tensor | LabelTensor |
| 128 | + :param pos: The 3D Euclidean coordinates of the nodes. |
| 129 | + :type pos: torch.Tensor | LabelTensor |
| 130 | + :param c: the constant that divides the aggregated message (it should be (M-1), where M is the number of nodes) |
| 131 | + :type pos: torch.Tensor |
| 132 | + :return: The concatenation of the update position features and the updated node features. |
| 133 | + :rtype: torch.Tensor |
96 | 134 | """ |
97 | | - return self._attribute_function |
| 135 | + x = self.phi_x(torch.cat([x, message], dim=-1)) |
| 136 | + pos = pos + (message / c) |
| 137 | + return pos, x |
0 commit comments