|
1 | 1 | import torch |
2 | | -import torch.nn as nn |
3 | 2 | from torch_geometric.nn import MessagePassing |
4 | 3 | from torch_geometric.utils import degree |
5 | | -from ....utils import check_consistency |
6 | 4 |
|
7 | 5 |
|
8 | 6 | class EnEquivariantGraphBlock(MessagePassing): |
9 | | - def __init__(self, |
10 | | - channels_h, |
11 | | - channels_m, |
12 | | - channels_a, |
13 | | - aggr: str = 'add', |
14 | | - hidden_channels: int = 64, |
15 | | - **kwargs): |
| 7 | + def __init__( |
| 8 | + self, |
| 9 | + channels_h, |
| 10 | + channels_m, |
| 11 | + channels_a, |
| 12 | + aggr: str = "add", |
| 13 | + hidden_channels: int = 64, |
| 14 | + **kwargs, |
| 15 | + ): |
16 | 16 | super().__init__(aggr=aggr, **kwargs) |
17 | 17 |
|
18 | | - self.phi_e = nn.Sequential( |
19 | | - nn.Linear(2 * channels_h + 1 + channels_a, hidden_channels), |
20 | | - nn.LayerNorm(hidden_channels), |
21 | | - nn.SiLU(), |
22 | | - nn.Linear(hidden_channels, channels_m), |
23 | | - nn.LayerNorm(channels_m), |
24 | | - nn.SiLU() |
| 18 | + self.phi_e = torch.nn.Sequential( |
| 19 | + torch.nn.Linear(2 * channels_h + 1 + channels_a, hidden_channels), |
| 20 | + torch.nn.LayerNorm(hidden_channels), |
| 21 | + torch.nn.SiLU(), |
| 22 | + torch.nn.Linear(hidden_channels, channels_m), |
| 23 | + torch.nn.LayerNorm(channels_m), |
| 24 | + torch.nn.SiLU(), |
25 | 25 | ) |
26 | | - self.phi_x = nn.Sequential( |
27 | | - nn.Linear(channels_m, hidden_channels), |
28 | | - nn.LayerNorm(hidden_channels), |
29 | | - nn.SiLU(), |
30 | | - nn.Linear(hidden_channels, 1), |
| 26 | + self.phi_x = torch.nn.Sequential( |
| 27 | + torch.nn.Linear(channels_m, hidden_channels), |
| 28 | + torch.nn.LayerNorm(hidden_channels), |
| 29 | + torch.nn.SiLU(), |
| 30 | + torch.nn.Linear(hidden_channels, 1), |
| 31 | + ) |
| 32 | + self.phi_h = torch.nn.Sequential( |
| 33 | + torch.nn.Linear(channels_h + channels_m, hidden_channels), |
| 34 | + torch.nn.LayerNorm(hidden_channels), |
| 35 | + torch.nn.SiLU(), |
| 36 | + torch.nn.Linear(hidden_channels, channels_h), |
31 | 37 | ) |
32 | | - self.phi_h = nn.Sequential( |
33 | | - nn.Linear(channels_h + channels_m, hidden_channels), |
34 | | - nn.LayerNorm(hidden_channels), |
35 | | - nn.SiLU(), |
36 | | - nn.Linear(hidden_channels, channels_h), |
37 | | - ) |
38 | 38 |
|
39 | 39 | def forward(self, x, h, edge_attr, edge_index, c=None): |
40 | 40 | if c is None: |
41 | 41 | c = degree(edge_index[0], x.shape[0]).unsqueeze(-1) |
42 | | - return self.propagate(edge_index=edge_index, x=x, h=h, edge_attr=edge_attr, c=c) |
| 42 | + return self.propagate( |
| 43 | + edge_index=edge_index, x=x, h=h, edge_attr=edge_attr, c=c |
| 44 | + ) |
43 | 45 |
|
44 | 46 | def message(self, x_i, x_j, h_i, h_j, edge_attr): |
45 | | - mh_ij = self.phi_e(torch.cat([h_i, h_j, torch.norm(x_i - x_j, dim=-1, keepdim=True)**2, edge_attr], dim=-1)) |
| 47 | + mh_ij = self.phi_e( |
| 48 | + torch.cat( |
| 49 | + [ |
| 50 | + h_i, |
| 51 | + h_j, |
| 52 | + torch.norm(x_i - x_j, dim=-1, keepdim=True) ** 2, |
| 53 | + edge_attr, |
| 54 | + ], |
| 55 | + dim=-1, |
| 56 | + ) |
| 57 | + ) |
46 | 58 | mx_ij = (x_i - x_j) * self.phi_x(mh_ij) |
47 | 59 | return torch.cat((mx_ij, mh_ij), dim=-1) |
48 | 60 |
|
49 | 61 | def update(self, aggr_out, x, h, edge_attr, c): |
50 | | - m_x, m_h = aggr_out[:, :self.m_len], aggr_out[:, self.m_len:] |
| 62 | + m_x, m_h = aggr_out[:, : self.m_len], aggr_out[:, self.m_len :] |
51 | 63 | h_l1 = self.phi_h(torch.cat([h, m_h], dim=-1)) |
52 | 64 | x_l1 = x + (m_x / c) |
53 | 65 | return x_l1, h_l1 |
|
0 commit comments