Skip to content

Commit f328888

Browse files
committed
Add CiteSeer link prediction example
1 parent 48c1528 commit f328888

1 file changed

Lines changed: 114 additions & 0 deletions

File tree

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# An example of link prediction using negative and positive samples on CiteSeer.
2+
# Ported from link_prediction_pubmed.jl — same pipeline, CiteSeer dataset.
3+
# See https://arxiv.org/pdf/2102.12557.pdf for a comparison of methods.
4+
5+
using Flux
6+
using Flux.Losses: logitbinarycrossentropy
7+
using GraphNeuralNetworks
8+
using MLDatasets: CiteSeer
9+
using Statistics, Random, LinearAlgebra
10+
#using CUDA
11+
#CUDA.allowscalar(false)
12+
13+
# Hyperparameters
14+
Base.@kwdef mutable struct Args
15+
η = 1.0f-3 # learning rate
16+
epochs = 200 # total training epochs
17+
seed = 17 # RNG seed
18+
#usecuda = true # use GPU when available
19+
nhidden = 64 # GCN hidden / output embedding dimension
20+
infotime = 10 # log every `infotime` epochs
21+
end
22+
23+
# Edge decoder
24+
# We define our own edge prediction layer but could also
25+
# use GraphNeuralNetworks.DotDecoder instead.
26+
struct DotDecoder end
27+
28+
function (::DotDecoder)(g, h)
29+
z = apply_edges((xi, xj, e) -> sum(xi .* xj, dims = 1), g, xi = h, xj = h)
30+
return vec(z)
31+
end
32+
33+
# Loss + accuracy helper
34+
function loss_and_acc(model, pred, h, pos_g, neg_g)
35+
pos_score = pred(pos_g, h)
36+
neg_score = pred(neg_g, h)
37+
scores = [pos_score; neg_score]
38+
labels = [ones(Float32, length(pos_score)); zeros(Float32, length(neg_score))]
39+
l = logitbinarycrossentropy(scores, labels)
40+
acc = 0.5f0 * mean(pos_score .>= 0) + 0.5f0 * mean(neg_score .< 0)
41+
return l, acc
42+
end
43+
44+
# Main training function
45+
function train(; kws...)
46+
args = Args(; kws...)
47+
args.seed > 0 && Random.seed!(args.seed)
48+
49+
# Device selection
50+
device = cpu
51+
@info "Training on CPU"
52+
53+
54+
# Load dataset
55+
g = mldataset2gnngraph(CiteSeer()) |> device
56+
X = g.ndata.features
57+
58+
display(g)
59+
@show is_bidirected(g)
60+
@show has_self_loops(g)
61+
@show has_multi_edges(g)
62+
@show mean(degree(g))
63+
64+
isbidir = is_bidirected(g)
65+
66+
#### TRAIN/TEST splits
67+
# with bidirected graph, we make sure that an edge and its reverse
68+
# are in the same split
69+
train_pos_g, test_pos_g = rand_edge_split(g, 0.9, bidirected = isbidir)
70+
71+
test_neg_g = negative_sample(g,
72+
num_neg_edges = test_pos_g.num_edges,
73+
bidirected = isbidir)
74+
75+
#Model
76+
77+
nin, nhidden = size(X, 1), args.nhidden
78+
79+
model = WithGraph(
80+
GNNChain(
81+
GCNConv(nin => nhidden, relu),
82+
GCNConv(nhidden => nhidden),
83+
),
84+
train_pos_g,
85+
) |> device
86+
87+
pred = DotDecoder()
88+
opt = Flux.setup(Adam(args.η), model)
89+
90+
#Logging
91+
function report(epoch)
92+
h = model(X)
93+
train_neg_g = negative_sample(train_pos_g, bidirected = isbidir)
94+
train_l, train_acc = loss_and_acc(model, pred, h, train_pos_g, train_neg_g)
95+
test_l, test_acc = loss_and_acc(model, pred, h, test_pos_g, test_neg_g)
96+
@info "" epoch (;train_l, train_acc) (;test_l, test_acc)
97+
end
98+
99+
#Training loop
100+
101+
report(0)
102+
for epoch in 1:(args.epochs)
103+
grads = Flux.gradient(model) do model
104+
h = model(X)
105+
neg_g = negative_sample(train_pos_g, bidirected = isbidir)
106+
l, _ = loss_and_acc(model, pred, h, train_pos_g, neg_g)
107+
l
108+
end
109+
Flux.update!(opt, model, grads[1])
110+
epoch % args.infotime == 0 && report(epoch)
111+
end
112+
end
113+
114+
train()

0 commit comments

Comments
 (0)