-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathSAGEConv.py
More file actions
29 lines (22 loc) · 1022 Bytes
/
SAGEConv.py
File metadata and controls
29 lines (22 loc) · 1022 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn.functional as F
from torch.nn import Parameter
# from torch_scatter import scatter_mean
from torch_geometric.nn.conv import MessagePassing
# from torch_geometric.utils import remove_self_loops, add_self_loops
# from ..inits import uniform
class SAGEConv(MessagePassing):
    r"""Parameter-free message passing layer with per-call message weights.

    Every message is the gathered neighbour feature tensor ``x_j`` multiplied
    element-wise by ``weight_vector`` (supplied to :meth:`forward`), and the
    messages are combined with the aggregation scheme ``aggr`` (``'mean'`` by
    default). The aggregated result is returned unchanged: this layer has no
    learnable parameters and applies no projection, bias, or normalisation.

    Args:
        in_channels: Stored for :meth:`__repr__` only.
        out_channels: Stored for :meth:`__repr__` only. It is not enforced;
            the output feature size is whatever ``x_j * weight_vector``
            broadcasts to.
        normalize: Accepted but ignored -- no normalisation is performed.
        bias: Accepted but ignored -- the layer holds no parameters.
        aggr: Aggregation scheme forwarded to ``MessagePassing``.
        **kwargs: Extra keyword arguments forwarded to ``MessagePassing``.

    NOTE(review): ``normalize``/``bias`` are presumably kept so the signature
    mirrors ``torch_geometric.nn.SAGEConv``; confirm whether callers rely on
    them before removing or implementing them.
    """

    def __init__(self, in_channels, out_channels, normalize=True, bias=True,
                 aggr='mean', **kwargs):
        super().__init__(aggr=aggr, **kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x, edge_index, weight_vector, size=None):
        """Propagate weighted messages along ``edge_index``.

        Args:
            x: Node feature tensor.
            edge_index: Graph connectivity passed to ``propagate``.
            weight_vector: Tensor multiplied element-wise with ``x_j``; it
                must broadcast against the per-edge gathered features.
                NOTE(review): presumably shaped ``(num_edges, 1)`` for
                per-edge weights or ``(num_features,)`` for per-feature
                weights -- confirm against callers.
            size: Optional size forwarded unchanged to ``propagate``.

        Returns:
            The aggregated messages produced by ``propagate``.
        """
        # Hand the weights to propagate() instead of stashing them on
        # ``self``: PyG forwards any keyword whose name matches a message()
        # parameter, so the tensor is not retained on the module (together
        # with its autograd graph) after the call, and the layer stays
        # stateless between calls.
        return self.propagate(edge_index, size=size, x=x,
                              weight_vector=weight_vector)

    def message(self, x_j, weight_vector):
        """Scale the gathered source-side features ``x_j`` by the weights."""
        return x_j * weight_vector

    def update(self, aggr_out):
        """Return the aggregated messages unchanged (identity update)."""
        return aggr_out

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'({self.in_channels}, {self.out_channels})')