-
Notifications
You must be signed in to change notification settings - Fork 78
Expand file tree
/
Copy pathgrace.py
More file actions
361 lines (311 loc) · 12.3 KB
/
grace.py
File metadata and controls
361 lines (311 loc) · 12.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
GRACE (Graph Contrastive Learning)
References
----------
Paper: https://arxiv.org/abs/2006.04131
Author's code: https://github.com/CRIPAC-DIG/GRACE
DGL code: https://github.com/dmlc/dgl/tree/master/examples/pytorch/grace
"""
import dgl
import numpy as np
import torch
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv
from torch import nn
class GRACE(nn.Module):
    """
    GRACE model for graph representation learning via contrastive learning.

    Two augmented views of the input graph are produced by random edge
    removal and feature masking, encoded with one shared GCN encoder,
    projected with an MLP head and compared with a symmetric contrastive
    loss.

    Parameters
    ----------
    n_in_feats : int
        Number of input features per node.
    n_hidden : int
        Output dimension of the GCN encoder (and input dimension of the
        projection head).
    n_out_feats : int
        Dimension of the projected features used by the contrastive loss.
    n_layers : int
        Number of GNN layers.
    act_fn : callable or None
        Activation function forwarded to every GCN layer. ``None`` is
        forwarded unchanged.
    temp : float
        Temperature parameter for contrastive loss, controls the sharpness of
        the similarity distribution.
    edges_removing_rate_1 : float
        Proportion of edges to remove when generating the first view of the graph.
    edges_removing_rate_2 : float
        Proportion of edges to remove when generating the second view of the graph.
    feats_masking_rate_1 : float
        Proportion of node features to mask when generating the first view of the graph.
    feats_masking_rate_2 : float
        Proportion of node features to mask when generating the second view of the graph.
    """

    def __init__(
        self,
        n_in_feats,
        n_hidden=128,
        n_out_feats=128,
        n_layers=2,
        act_fn=None,
        temp=0.4,
        edges_removing_rate_1=0.2,
        edges_removing_rate_2=0.4,
        feats_masking_rate_1=0.3,
        feats_masking_rate_2=0.4,
    ):
        super().__init__()
        # Encoder shared by both views.
        self.encoder = GCN(n_in_feats, n_hidden, act_fn, n_layers)
        # Projection head mapping encoder outputs into the contrastive space.
        self.proj = MLP(n_hidden, n_out_feats)
        self.temp = temp
        self.edges_removing_rate_1 = edges_removing_rate_1
        self.edges_removing_rate_2 = edges_removing_rate_2
        self.feats_masking_rate_1 = feats_masking_rate_1
        self.feats_masking_rate_2 = feats_masking_rate_2

    @staticmethod
    def sim(z1, z2):
        """
        Compute the pairwise cosine similarity between two sets of embeddings.

        Parameters
        ----------
        z1 : torch.Tensor
            Node embeddings from the first view.
        z2 : torch.Tensor
            Node embeddings from the second view.

        Returns
        -------
        torch.Tensor
            Matrix whose entry ``(i, j)`` is the cosine similarity between
            row ``i`` of ``z1`` and row ``j`` of ``z2``.
        """
        z1 = F.normalize(z1)
        z2 = F.normalize(z2)
        # Rows are unit-length after normalization, so the dot product is
        # the cosine similarity.
        return torch.mm(z1, z2.t())

    def sim_loss(self, z1, z2):
        """
        Compute the per-node contrastive loss of ``z1`` against ``z2``.

        Parameters
        ----------
        z1 : torch.Tensor
            Node embeddings from the first view (the anchors).
        z2 : torch.Tensor
            Node embeddings from the second view.

        Returns
        -------
        torch.Tensor
            One loss value per node (not reduced).
        """
        # Temperature-scaled, exponentiated similarities inside the anchor
        # view and across the two views.
        refl_sim = torch.exp(self.sim(z1, z1) / self.temp)
        between_sim = torch.exp(self.sim(z1, z2) / self.temp)
        # Denominator: every intra-view and inter-view pair, minus the
        # similarity of each node with itself inside the anchor view.
        x1 = refl_sim.sum(1) + between_sim.sum(1) - refl_sim.diag()
        # Numerator: the same node seen in the other view (positive pair).
        loss = -torch.log(between_sim.diag() / x1)
        return loss

    def loss(self, z1, z2):
        """
        Compute the symmetric contrastive loss for both views.

        Parameters
        ----------
        z1 : torch.Tensor
            Node embeddings from the first view.
        z2 : torch.Tensor
            Node embeddings from the second view.

        Returns
        -------
        torch.Tensor
            Scalar: mean over nodes of the average of both loss directions.
        """
        l1 = self.sim_loss(z1=z1, z2=z2)
        l2 = self.sim_loss(z1=z2, z2=z1)
        return (l1 + l2).mean() * 0.5

    def get_embedding(self, graph, feats):
        """
        Get the node embeddings from the encoder without computing gradients.

        Only the encoder is applied; the projection head is not used here.

        Parameters
        ----------
        graph : dgl.DGLGraph
            The input graph.
        feats : torch.Tensor
            Node features.

        Returns
        -------
        torch.Tensor
            Node embeddings, detached from the autograd graph.
        """
        # Run the encoder under no_grad so no autograd graph is built at all,
        # instead of building one and discarding it with detach() afterwards.
        with torch.no_grad():
            h = self.encoder(graph, feats)
        return h.detach()

    def forward(self, graph, feats):
        """
        Perform the forward pass and compute the contrastive loss.

        Parameters
        ----------
        graph : dgl.DGLGraph
            The input graph.
        feats : torch.Tensor
            Node features.

        Returns
        -------
        torch.Tensor
            Contrastive loss between two views of the graph.
        """
        # Each view uses its own edge-removal and feature-masking rates.
        graph1, feats1 = _generating_views(
            graph, feats, self.edges_removing_rate_1, self.feats_masking_rate_1
        )
        graph2, feats2 = _generating_views(
            graph, feats, self.edges_removing_rate_2, self.feats_masking_rate_2
        )
        # Encode and project both views with the same encoder and projector.
        z1 = self.proj(self.encoder(graph1, feats1))
        z2 = self.proj(self.encoder(graph2, feats2))
        loss = self.loss(z1, z2)
        return loss
class GCN(nn.Module):
    """
    Graph Convolutional Network (GCN) for node feature transformation.

    The network is an input layer, ``n_layers - 2`` hidden layers and an
    output layer. Input and hidden layers are ``2 * n_out_feats`` wide.

    Parameters
    ----------
    n_in_feats : int
        Number of input features per node.
    n_out_feats : int
        Number of output features per node.
    act_fn : callable or None
        Activation passed to every ``GraphConv`` layer, including the
        output layer.
    n_layers : int
        Number of GCN layers; must be at least 2.

    Raises
    ------
    ValueError
        If ``n_layers`` is smaller than 2.
    """

    def __init__(self, n_in_feats, n_out_feats, act_fn, n_layers=2):
        super().__init__()
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would silently build a 2-layer network.
        if n_layers < 2:
            raise ValueError("Number of layers should be at least 2.")
        self.n_layers = n_layers
        # Hidden width is twice the output width.
        self.n_hidden = n_out_feats * 2
        self.input_layer = GraphConv(n_in_feats, self.n_hidden, activation=act_fn)
        self.hidden_layers = nn.ModuleList(
            [
                GraphConv(self.n_hidden, self.n_hidden, activation=act_fn)
                for _ in range(n_layers - 2)
            ]
        )
        self.output_layer = GraphConv(self.n_hidden, n_out_feats, activation=act_fn)

    def forward(self, graph, feat):
        """
        Forward pass through the GCN.

        Parameters
        ----------
        graph : dgl.DGLGraph
            The input graph.
        feat : torch.Tensor
            Node features.

        Returns
        -------
        torch.Tensor
            Transformed node features after passing through the GCN layers.
        """
        feat = self.input_layer(graph, feat)
        for hidden_layer in self.hidden_layers:
            feat = hidden_layer(graph, feat)
        return self.output_layer(graph, feat)
class MLP(nn.Module):
    """
    Two-layer perceptron used to project node embeddings into a new space.

    Parameters
    ----------
    n_in_feats : int
        Number of input features.
    n_out_feats : int
        Number of output features; also the width of the intermediate layer.
    """

    def __init__(self, n_in_feats, n_out_feats):
        super().__init__()
        # Both linear layers emit n_out_feats features.
        self.fc1 = nn.Linear(n_in_feats, n_out_feats)
        self.fc2 = nn.Linear(n_out_feats, n_out_feats)

    def forward(self, x):
        """
        Project the input embeddings.

        Parameters
        ----------
        x : torch.Tensor
            Input node embeddings.

        Returns
        -------
        torch.Tensor
            Projected node embeddings.
        """
        hidden = self.fc1(x)
        # ELU sits between the two linear layers; the output stays linear.
        hidden = F.elu(hidden)
        return self.fc2(hidden)
def _generating_views(graph, feats, edges_removing_rate, feats_masking_rate):
    """
    Generate one augmented view of the graph.

    A random subset of the edges is selected to build a new graph on the
    same node set, self-loops are added to it, and the node features are
    randomly masked.

    Parameters
    ----------
    graph : dgl.DGLGraph
        The input graph.
    feats : torch.Tensor
        Node features.
    edges_removing_rate : float
        Proportion of edges to remove.
    feats_masking_rate : float
        Proportion of node features to mask.

    Returns
    -------
    new_graph : dgl.DGLGraph
        Graph built from the selected edges, with self-loops added.
    masked_feats : torch.Tensor
        Node features with some values masked.
    """
    # Removing edges (RE): the sampled indices select the edges that are
    # carried over into the new graph.
    kept_idx = _get_removing_edges_idx(graph, edges_removing_rate)
    all_src, all_dst = graph.edges()
    view_graph = dgl.graph(
        (all_src[kept_idx], all_dst[kept_idx]),
        num_nodes=graph.num_nodes(),
        device=graph.device,
    )
    view_graph = dgl.add_self_loop(view_graph)

    # Masking node features (MF)
    view_feats = _masking_node_feats(feats, feats_masking_rate)
    return view_graph, view_feats
def _masking_node_feats(feats, masking_rate):
    """
    Zero out randomly chosen feature dimensions for all nodes.

    One random number is drawn per feature dimension (``feats.size(1)``);
    every dimension whose draw is below ``masking_rate`` is set to zero
    across the whole first axis. The input tensor is left untouched.

    Parameters
    ----------
    feats : torch.Tensor
        Node features.
    masking_rate : float
        Probability with which each feature dimension is masked.

    Returns
    -------
    torch.Tensor
        A copy of ``feats`` with the selected dimensions zeroed.
    """
    dim_draws = torch.rand(feats.shape[1], dtype=torch.float32, device=feats.device)
    # Work on a copy so the caller's tensor is never modified in place.
    masked = feats.clone()
    masked[:, dim_draws < masking_rate] = 0
    return masked
def _get_removing_edges_idx(graph, edges_removing_rate):
    """
    Sample the indices of the edges that survive random edge removal.

    Despite the function's name, the returned indices identify the edges to
    *keep*: each edge is kept independently with probability
    ``1 - edges_removing_rate``.

    Parameters
    ----------
    graph : dgl.DGLGraph
        The input graph. Only ``num_edges()`` and ``device`` are used.
    edges_removing_rate : float
        Probability of dropping each edge; must lie in ``[0, 1]``.

    Returns
    -------
    torch.Tensor
        1-D integer tensor, on the graph's device, holding the indices of
        the kept edges.

    Raises
    ------
    ValueError
        If ``edges_removing_rate`` is outside ``[0, 1]``.
    """
    # Validate up front so a bad rate gives a clear error rather than a
    # failure inside torch.bernoulli.
    if not 0.0 <= edges_removing_rate <= 1.0:
        raise ValueError(
            f"edges_removing_rate must be in [0, 1], got {edges_removing_rate}"
        )
    # Build the keep-probabilities directly on the graph's device, avoiding
    # the NumPy round trip and a CPU-resident index tensor.
    keep_probs = torch.full(
        (graph.num_edges(),),
        1.0 - edges_removing_rate,
        dtype=torch.float32,
        device=graph.device,
    )
    keep_mask = torch.bernoulli(keep_probs)  # 1.0 marks an edge that is kept
    return keep_mask.nonzero().squeeze(1)