Skip to content

Commit fb0bdf0

Browse files
buggy egnn - fix linter
1 parent 23b3560 commit fb0bdf0

1 file changed

Lines changed: 42 additions & 30 deletions

File tree

pina/model/block/message_passing/egnn_block.py

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,65 @@
11
import torch
2-
import torch.nn as nn
32
from torch_geometric.nn import MessagePassing
43
from torch_geometric.utils import degree
5-
from ....utils import check_consistency
64

75

86
class EnEquivariantGraphBlock(MessagePassing):
9-
def __init__(self,
10-
channels_h,
11-
channels_m,
12-
channels_a,
13-
aggr: str = 'add',
14-
hidden_channels: int = 64,
15-
**kwargs):
7+
def __init__(
8+
self,
9+
channels_h,
10+
channels_m,
11+
channels_a,
12+
aggr: str = "add",
13+
hidden_channels: int = 64,
14+
**kwargs,
15+
):
1616
super().__init__(aggr=aggr, **kwargs)
1717

18-
self.phi_e = nn.Sequential(
19-
nn.Linear(2 * channels_h + 1 + channels_a, hidden_channels),
20-
nn.LayerNorm(hidden_channels),
21-
nn.SiLU(),
22-
nn.Linear(hidden_channels, channels_m),
23-
nn.LayerNorm(channels_m),
24-
nn.SiLU()
18+
self.phi_e = torch.nn.Sequential(
19+
torch.nn.Linear(2 * channels_h + 1 + channels_a, hidden_channels),
20+
torch.nn.LayerNorm(hidden_channels),
21+
torch.nn.SiLU(),
22+
torch.nn.Linear(hidden_channels, channels_m),
23+
torch.nn.LayerNorm(channels_m),
24+
torch.nn.SiLU(),
2525
)
26-
self.phi_x = nn.Sequential(
27-
nn.Linear(channels_m, hidden_channels),
28-
nn.LayerNorm(hidden_channels),
29-
nn.SiLU(),
30-
nn.Linear(hidden_channels, 1),
26+
self.phi_x = torch.nn.Sequential(
27+
torch.nn.Linear(channels_m, hidden_channels),
28+
torch.nn.LayerNorm(hidden_channels),
29+
torch.nn.SiLU(),
30+
torch.nn.Linear(hidden_channels, 1),
31+
)
32+
self.phi_h = torch.nn.Sequential(
33+
torch.nn.Linear(channels_h + channels_m, hidden_channels),
34+
torch.nn.LayerNorm(hidden_channels),
35+
torch.nn.SiLU(),
36+
torch.nn.Linear(hidden_channels, channels_h),
3137
)
32-
self.phi_h = nn.Sequential(
33-
nn.Linear(channels_h + channels_m, hidden_channels),
34-
nn.LayerNorm(hidden_channels),
35-
nn.SiLU(),
36-
nn.Linear(hidden_channels, channels_h),
37-
)
3838

3939
def forward(self, x, h, edge_attr, edge_index, c=None):
4040
if c is None:
4141
c = degree(edge_index[0], x.shape[0]).unsqueeze(-1)
42-
return self.propagate(edge_index=edge_index, x=x, h=h, edge_attr=edge_attr, c=c)
42+
return self.propagate(
43+
edge_index=edge_index, x=x, h=h, edge_attr=edge_attr, c=c
44+
)
4345

4446
def message(self, x_i, x_j, h_i, h_j, edge_attr):
45-
mh_ij = self.phi_e(torch.cat([h_i, h_j, torch.norm(x_i - x_j, dim=-1, keepdim=True)**2, edge_attr], dim=-1))
47+
mh_ij = self.phi_e(
48+
torch.cat(
49+
[
50+
h_i,
51+
h_j,
52+
torch.norm(x_i - x_j, dim=-1, keepdim=True) ** 2,
53+
edge_attr,
54+
],
55+
dim=-1,
56+
)
57+
)
4658
mx_ij = (x_i - x_j) * self.phi_x(mh_ij)
4759
return torch.cat((mx_ij, mh_ij), dim=-1)
4860

4961
def update(self, aggr_out, x, h, edge_attr, c):
50-
m_x, m_h = aggr_out[:, :self.m_len], aggr_out[:, self.m_len:]
62+
m_x, m_h = aggr_out[:, : self.m_len], aggr_out[:, self.m_len :]
5163
h_l1 = self.phi_h(torch.cat([h, m_h], dim=-1))
5264
x_l1 = x + (m_x / c)
5365
return x_l1, h_l1

0 commit comments

Comments
 (0)