import math
import torch
from torch.nn import Parameter
import torch.nn.functional as F
from typing import Any
from .utils import mask_adjs, mask_x

def glorot(tensor):
    # Glorot (Xavier) uniform initialization.
    if tensor is not None:
        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        tensor.data.uniform_(-stdv, stdv)

def zeros(tensor):
    if tensor is not None:
        tensor.data.fill_(0)

def reset(value: Any):
    # Recursively re-initialize every submodule that defines `reset_parameters`.
    if hasattr(value, 'reset_parameters'):
        value.reset_parameters()
    else:
        for child in value.children() if hasattr(value, 'children') else []:
            reset(child)

# -------- GCN layer --------
class DenseGCNConv(torch.nn.Module):
    r"""See :class:`torch_geometric.nn.conv.GCNConv`.
    """
    def __init__(self, in_channels, out_channels, improved=False, bias=True):
        super(DenseGCNConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved

        self.weight = Parameter(torch.Tensor(self.in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, x, adj, mask=None, add_loop=True):
        r"""
        Args:
            x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B
                \times N \times F}`, with batch-size :math:`B`, (maximum)
                number of nodes :math:`N` for each graph, and feature
                dimension :math:`F`.
            adj (Tensor): Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B
                \times N \times N}`. The adjacency tensor is broadcastable in
                the batch dimension, resulting in a shared adjacency matrix for
                the complete batch.
            mask (BoolTensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)
            add_loop (bool, optional): If set to :obj:`False`, the layer will
                not automatically add self-loops to the adjacency matrices.
                (default: :obj:`True`)
        """
        x = x.unsqueeze(0) if x.dim() == 2 else x
        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
        B, N, _ = adj.size()

        if add_loop:
            # Add self-loops (weight 2 in the "improved" variant).
            adj = adj.clone()
            idx = torch.arange(N, dtype=torch.long, device=adj.device)
            adj[:, idx, idx] = 1 if not self.improved else 2

        out = torch.matmul(x, self.weight)

        # Symmetric normalization D^{-1/2} A D^{-1/2}; degrees are clamped to
        # at least 1 to avoid division by zero for isolated nodes.
        deg_inv_sqrt = adj.sum(dim=-1).clamp(min=1).pow(-0.5)
        adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)
        out = torch.matmul(adj, out)

        if self.bias is not None:
            out = out + self.bias

        if mask is not None:
            out = out * mask.view(B, N, 1).to(x.dtype)

        return out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)

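# A minimal usage sketch for DenseGCNConv (illustrative shapes only, not part
# of the original module):
#
#   conv = DenseGCNConv(in_channels=16, out_channels=32)
#   x = torch.randn(8, 10, 16)                    # B x N x F node features
#   adj = torch.rand(8, 10, 10)                   # B x N x N dense adjacency
#   mask = torch.ones(8, 10, dtype=torch.bool)    # valid-node mask
#   out = conv(x, adj, mask)                      # -> 8 x 10 x 32
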
# -------- MLP layer --------
class MLP(torch.nn.Module):
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False, activate_func=F.relu):
        """
        num_layers: number of layers in the neural network (EXCLUDING the
            input layer). If num_layers=1, this reduces to a linear model.
        input_dim: dimensionality of input features
        hidden_dim: dimensionality of hidden units at ALL layers
        output_dim: number of classes for prediction
        """

        super(MLP, self).__init__()

        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        self.use_bn = use_bn
        self.activate_func = activate_func

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")
        elif num_layers == 1:
            # Linear model
            self.linear = torch.nn.Linear(input_dim, output_dim)
        else:
            # Multi-layer model
            self.linear_or_not = False
            self.linears = torch.nn.ModuleList()

            self.linears.append(torch.nn.Linear(input_dim, hidden_dim))
            for layer in range(num_layers - 2):
                self.linears.append(torch.nn.Linear(hidden_dim, hidden_dim))
            self.linears.append(torch.nn.Linear(hidden_dim, output_dim))

            if self.use_bn:
                self.batch_norms = torch.nn.ModuleList()
                for layer in range(num_layers - 1):
                    self.batch_norms.append(torch.nn.BatchNorm1d(hidden_dim))


    def forward(self, x):
        """
        :param x: [batch_size, N, F_i], batch of node features
        """
        if self.linear_or_not:
            # If linear model
            return self.linear(x)
        else:
            # If MLP
            h = x
            for layer in range(self.num_layers - 1):
                h = self.linears[layer](h)
                if self.use_bn:
                    # Note: BatchNorm1d expects a 2D [B, C] input here, so
                    # callers should flatten higher-rank inputs when use_bn=True.
                    h = self.batch_norms[layer](h)
                h = self.activate_func(h)
            return self.linears[self.num_layers - 1](h)

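# A minimal usage sketch for MLP (illustrative values only, not part of the
# original module):
#
#   mlp = MLP(num_layers=3, input_dim=16, hidden_dim=64, output_dim=8)
#   x = torch.randn(32, 16)    # batch of 32 feature vectors
#   y = mlp(x)                 # -> 32 x 8
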

# -------- Graph Multi-Head Attention (GMH) --------
# -------- From Baek et al. (2021) --------
class Attention(torch.nn.Module):
    def __init__(self, in_dim, attn_dim, out_dim, num_heads=4, conv='GCN'):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.attn_dim = attn_dim
        self.out_dim = out_dim
        self.conv = conv

        self.gnn_q, self.gnn_k, self.gnn_v = self.get_gnn(in_dim, attn_dim, out_dim, conv)
        self.activation = torch.tanh
        self.softmax_dim = 2

    def forward(self, x, adj, flags, attention_mask=None):
        if self.conv == 'GCN':
            Q = self.gnn_q(x, adj)
            K = self.gnn_k(x, adj)
        else:
            Q = self.gnn_q(x)
            K = self.gnn_k(x)

        V = self.gnn_v(x, adj)

        # Split queries and keys into heads along the feature dimension and
        # stack the heads along the batch dimension:
        # B x N x attn_dim -> (B x num_heads) x N x (attn_dim / num_heads)
        dim_split = self.attn_dim // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)

        if attention_mask is not None:
            attention_mask = torch.cat([attention_mask for _ in range(self.num_heads)], 0)
            attention_score = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.out_dim)
            A = self.activation(attention_mask + attention_score)
        else:
            A = self.activation(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.out_dim))  # (B x num_heads) x N x N

        # -------- (B x num_heads) x N x N --------
        # Average the attention maps over heads and symmetrize the result.
        A = A.view(-1, *adj.shape)
        A = A.mean(dim=0)
        A = (A + A.transpose(-1, -2)) / 2

        return V, A

    def get_gnn(self, in_dim, attn_dim, out_dim, conv='GCN'):

        if conv == 'GCN':
            gnn_q = DenseGCNConv(in_dim, attn_dim)
            gnn_k = DenseGCNConv(in_dim, attn_dim)
            gnn_v = DenseGCNConv(in_dim, out_dim)

            return gnn_q, gnn_k, gnn_v

        elif conv == 'MLP':
            num_layers = 2
            gnn_q = MLP(num_layers, in_dim, 2 * attn_dim, attn_dim, activate_func=torch.tanh)
            gnn_k = MLP(num_layers, in_dim, 2 * attn_dim, attn_dim, activate_func=torch.tanh)
            gnn_v = DenseGCNConv(in_dim, out_dim)

            return gnn_q, gnn_k, gnn_v

        else:
            raise NotImplementedError(f'{conv} not implemented.')

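# A minimal usage sketch for Attention (illustrative shapes only, not part of
# the original module; the flags tensor shape assumes the B x N valid-node
# convention of mask_x/mask_adjs):
#
#   attn = Attention(in_dim=16, attn_dim=32, out_dim=32, num_heads=4)
#   x = torch.randn(8, 10, 16)    # B x N x F
#   adj = torch.rand(8, 10, 10)   # B x N x N
#   flags = torch.ones(8, 10)     # valid-node flags
#   V, A = attn(x, adj, flags)    # V: 8 x 10 x 32, A: 8 x 10 x 10
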
# -------- Layer of ScoreNetworkA --------
class AttentionLayer(torch.nn.Module):
    def __init__(self, num_linears, conv_input_dim, attn_dim, conv_output_dim, input_dim, output_dim,
                 num_heads=4, conv='GCN'):
        super(AttentionLayer, self).__init__()
        # One attention block per input adjacency channel.
        self.attn_dim = attn_dim
        self.attn = torch.nn.ModuleList()
        for _ in range(input_dim):
            self.attn.append(Attention(conv_input_dim, self.attn_dim, conv_output_dim,
                                       num_heads=num_heads, conv=conv))

        self.hidden_dim = 2 * max(input_dim, output_dim)
        self.mlp = MLP(num_linears, 2 * input_dim, self.hidden_dim, output_dim, use_bn=False, activate_func=F.elu)
        self.multi_channel = MLP(2, input_dim * conv_output_dim, self.hidden_dim, conv_output_dim,
                                 use_bn=False, activate_func=F.elu)

    def forward(self, x, adj, flags):
        """
        :param x: B x N x F_i
        :param adj: B x C_i x N x N
        :return: x_out: B x N x F_o, adj_out: B x C_o x N x N
        """
        mask_list = []
        x_list = []
        for k in range(len(self.attn)):
            # Run attention on the k-th adjacency channel.
            _x, mask = self.attn[k](x, adj[:, k, :, :], flags)
            mask_list.append(mask.unsqueeze(-1))
            x_list.append(_x)
        x_out = mask_x(self.multi_channel(torch.cat(x_list, dim=-1)), flags)
        x_out = torch.tanh(x_out)

        # Concatenate the attention maps with the input adjacency channels and
        # mix them entry-wise with an MLP, then symmetrize and mask the result.
        mlp_in = torch.cat([torch.cat(mask_list, dim=-1), adj.permute(0, 2, 3, 1)], dim=-1)
        shape = mlp_in.shape
        mlp_out = self.mlp(mlp_in.view(-1, shape[-1]))
        _adj = mlp_out.view(shape[0], shape[1], shape[2], -1).permute(0, 3, 1, 2)
        _adj = _adj + _adj.transpose(-1, -2)
        adj_out = mask_adjs(_adj, flags)

        return x_out, adj_out
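
# A minimal usage sketch for AttentionLayer (illustrative shapes only, not
# part of the original module; flags assumed B x N per mask_x/mask_adjs):
#
#   layer = AttentionLayer(num_linears=2, conv_input_dim=16, attn_dim=32,
#                          conv_output_dim=32, input_dim=2, output_dim=4)
#   x = torch.randn(8, 10, 16)       # B x N x F_i
#   adj = torch.rand(8, 2, 10, 10)   # B x C_i x N x N
#   flags = torch.ones(8, 10)        # valid-node flags
#   x_out, adj_out = layer(x, adj, flags)
#   # x_out: 8 x 10 x 32, adj_out: 8 x 4 x 10 x 10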