Skip to content

Commit c274dde

Browse files
committed
Add CiteSeer node classification example
1 parent 48c1528 commit c274dde

1 file changed

Lines changed: 96 additions & 0 deletions

File tree

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# An example of semi-supervised node classification on CiteSeer.
2+
# Ported from the Cora example — same GCN architecture, different dataset.
3+
# See node_classification_cora.jl for a detailed walk-through.
4+
5+
using Flux
6+
using Flux: onecold, onehotbatch
7+
using Flux.Losses: logitcrossentropy
8+
using GraphNeuralNetworks
9+
using MLDatasets: CiteSeer
10+
using Statistics, Random
11+
#using CUDA
12+
#CUDA.allowscalar(false)
13+
14+
15+
# Computes cross-entropy loss and accuracy on a boolean node mask.
16+
17+
function eval_loss_accuracy(X, y, mask, model, g)
18+
ŷ = model(g, X)
19+
l = logitcrossentropy(ŷ[:, mask], y[:, mask])
20+
acc = mean(onecold(ŷ[:, mask]) .== onecold(y[:, mask]))
21+
return (loss = round(l, digits = 4),
22+
acc = round(acc * 100, digits = 2))
23+
end
24+
25+
26+
# Hyperparameters
27+
28+
Base.@kwdef mutable struct Args
29+
η = 1.0f-3 # learning rate
30+
epochs = 200 # total training epochs
31+
seed = 17 # RNG seed (set > 0 for reproducibility)
32+
#usecuda = true # use GPU when available
33+
nhidden = 128 # hidden-layer width
34+
infotime = 10 # log every `infotime` epochs
35+
end
36+
37+
38+
# Main training function
39+
40+
function train(; kws...)
41+
args = Args(; kws...)
42+
args.seed > 0 && Random.seed!(args.seed)
43+
44+
#Device selection
45+
46+
device = cpu
47+
@info "Training on CPU"
48+
49+
50+
# Load dataset
51+
dataset = CiteSeer()
52+
classes = dataset.metadata["classes"]
53+
g = mldataset2gnngraph(dataset) |> device
54+
X = g.ndata.features
55+
56+
y = onehotbatch(g.ndata.targets |> cpu, classes) |> device
57+
58+
display(g)
59+
@show length(classes)
60+
@show is_bidirected(g)
61+
@show has_self_loops(g)
62+
63+
#Model
64+
nin, nhidden, nout = size(X, 1), args.nhidden, length(classes)
65+
66+
model = GNNChain(
67+
GCNConv(nin => nhidden, relu),
68+
Dropout(0.5),
69+
GCNConv(nhidden => nhidden, relu),
70+
Dense(nhidden, nout),
71+
) |> device
72+
73+
opt = Flux.setup(Adam(args.η), model)
74+
75+
#Training loop
76+
ytrain = y[:, g.ndata.train_mask]
77+
78+
function report(epoch)
79+
train_m = eval_loss_accuracy(X, y, g.ndata.train_mask, model, g)
80+
val_m = eval_loss_accuracy(X, y, g.ndata.val_mask, model, g)
81+
test_m = eval_loss_accuracy(X, y, g.ndata.test_mask, model, g)
82+
@info "" epoch train_m val_m test_m
83+
end
84+
85+
report(0)
86+
for epoch in 1:(args.epochs)
87+
grads = Flux.gradient(model) do model
88+
ŷ = model(g, X)
89+
logitcrossentropy(ŷ[:, g.ndata.train_mask], ytrain)
90+
end
91+
Flux.update!(opt, model, grads[1])
92+
epoch % args.infotime == 0 && report(epoch)
93+
end
94+
end
95+
96+
train()

0 commit comments

Comments
 (0)