1+ # An example of link prediction using negative and positive samples on CiteSeer.
2+ # Ported from link_prediction_pubmed.jl — same pipeline, CiteSeer dataset.
3+ # See https://arxiv.org/pdf/2102.12557.pdf for a comparison of methods.
4+
5+ using Flux
6+ using Flux. Losses: logitbinarycrossentropy
7+ using GraphNeuralNetworks
8+ using MLDatasets: CiteSeer
9+ using Statistics, Random, LinearAlgebra
10+ # using CUDA
11+ # CUDA.allowscalar(false)
12+
13+ # Hyperparameters
14+ Base. @kwdef mutable struct Args
15+ η = 1.0f-3 # learning rate
16+ epochs = 200 # total training epochs
17+ seed = 17 # RNG seed
18+ # usecuda = true # use GPU when available
19+ nhidden = 64 # GCN hidden / output embedding dimension
20+ infotime = 10 # log every `infotime` epochs
21+ end
22+
23+ # Edge decoder
24+ # We define our own edge prediction layer but could also
25+ # use GraphNeuralNetworks.DotDecoder instead.
26+ struct DotDecoder end
27+
28+ function (:: DotDecoder )(g, h)
29+ z = apply_edges ((xi, xj, e) -> sum (xi .* xj, dims = 1 ), g, xi = h, xj = h)
30+ return vec (z)
31+ end
32+
33+ # Loss + accuracy helper
34+ function loss_and_acc (model, pred, h, pos_g, neg_g)
35+ pos_score = pred (pos_g, h)
36+ neg_score = pred (neg_g, h)
37+ scores = [pos_score; neg_score]
38+ labels = [ones (Float32, length (pos_score)); zeros (Float32, length (neg_score))]
39+ l = logitbinarycrossentropy (scores, labels)
40+ acc = 0.5f0 * mean (pos_score .>= 0 ) + 0.5f0 * mean (neg_score .< 0 )
41+ return l, acc
42+ end
43+
44+ # Main training function
45+ function train (; kws... )
46+ args = Args (; kws... )
47+ args. seed > 0 && Random. seed! (args. seed)
48+
49+ # Device selection
50+ device = cpu
51+ @info " Training on CPU"
52+
53+
54+ # Load dataset
55+ g = mldataset2gnngraph (CiteSeer ()) |> device
56+ X = g. ndata. features
57+
58+ display (g)
59+ @show is_bidirected (g)
60+ @show has_self_loops (g)
61+ @show has_multi_edges (g)
62+ @show mean (degree (g))
63+
64+ isbidir = is_bidirected (g)
65+
66+ # ### TRAIN/TEST splits
67+ # with bidirected graph, we make sure that an edge and its reverse
68+ # are in the same split
69+ train_pos_g, test_pos_g = rand_edge_split (g, 0.9 , bidirected = isbidir)
70+
71+ test_neg_g = negative_sample (g,
72+ num_neg_edges = test_pos_g. num_edges,
73+ bidirected = isbidir)
74+
75+ # Model
76+
77+ nin, nhidden = size (X, 1 ), args. nhidden
78+
79+ model = WithGraph (
80+ GNNChain (
81+ GCNConv (nin => nhidden, relu),
82+ GCNConv (nhidden => nhidden),
83+ ),
84+ train_pos_g,
85+ ) |> device
86+
87+ pred = DotDecoder ()
88+ opt = Flux. setup (Adam (args. η), model)
89+
90+ # Logging
91+ function report (epoch)
92+ h = model (X)
93+ train_neg_g = negative_sample (train_pos_g, bidirected = isbidir)
94+ train_l, train_acc = loss_and_acc (model, pred, h, train_pos_g, train_neg_g)
95+ test_l, test_acc = loss_and_acc (model, pred, h, test_pos_g, test_neg_g)
96+ @info " " epoch (;train_l, train_acc) (;test_l, test_acc)
97+ end
98+
99+ # Training loop
100+
101+ report (0 )
102+ for epoch in 1 : (args. epochs)
103+ grads = Flux. gradient (model) do model
104+ h = model (X)
105+ neg_g = negative_sample (train_pos_g, bidirected = isbidir)
106+ l, _ = loss_and_acc (model, pred, h, train_pos_g, neg_g)
107+ l
108+ end
109+ Flux. update! (opt, model, grads[1 ])
110+ epoch % args. infotime == 0 && report (epoch)
111+ end
112+ end
113+
114+ train ()
0 commit comments