1+ # An example of semi-supervised node classification on CiteSeer.
2+ # Ported from the Cora example — same GCN architecture, different dataset.
3+ # See node_classification_cora.jl for a detailed walk-through.
4+
5+ using Flux
6+ using Flux: onecold, onehotbatch
7+ using Flux. Losses: logitcrossentropy
8+ using GraphNeuralNetworks
9+ using MLDatasets: CiteSeer
10+ using Statistics, Random
11+ # using CUDA
12+ # CUDA.allowscalar(false)
13+
14+
15+ # Computes cross-entropy loss and accuracy on a boolean node mask.
16+
17+ function eval_loss_accuracy (X, y, mask, model, g)
18+ ŷ = model (g, X)
19+ l = logitcrossentropy (ŷ[:, mask], y[:, mask])
20+ acc = mean (onecold (ŷ[:, mask]) .== onecold (y[:, mask]))
21+ return (loss = round (l, digits = 4 ),
22+ acc = round (acc * 100 , digits = 2 ))
23+ end
24+
25+
26+ # Hyperparameters
27+
28+ Base. @kwdef mutable struct Args
29+ η = 1.0f-3 # learning rate
30+ epochs = 200 # total training epochs
31+ seed = 17 # RNG seed (set > 0 for reproducibility)
32+ # usecuda = true # use GPU when available
33+ nhidden = 128 # hidden-layer width
34+ infotime = 10 # log every `infotime` epochs
35+ end
36+
37+
38+ # Main training function
39+
40+ function train (; kws... )
41+ args = Args (; kws... )
42+ args. seed > 0 && Random. seed! (args. seed)
43+
44+ # Device selection
45+
46+ device = cpu
47+ @info " Training on CPU"
48+
49+
50+ # Load dataset
51+ dataset = CiteSeer ()
52+ classes = dataset. metadata[" classes" ]
53+ g = mldataset2gnngraph (dataset) |> device
54+ X = g. ndata. features
55+
56+ y = onehotbatch (g. ndata. targets |> cpu, classes) |> device
57+
58+ display (g)
59+ @show length (classes)
60+ @show is_bidirected (g)
61+ @show has_self_loops (g)
62+
63+ # Model
64+ nin, nhidden, nout = size (X, 1 ), args. nhidden, length (classes)
65+
66+ model = GNNChain (
67+ GCNConv (nin => nhidden, relu),
68+ Dropout (0.5 ),
69+ GCNConv (nhidden => nhidden, relu),
70+ Dense (nhidden, nout),
71+ ) |> device
72+
73+ opt = Flux. setup (Adam (args. η), model)
74+
75+ # Training loop
76+ ytrain = y[:, g. ndata. train_mask]
77+
78+ function report (epoch)
79+ train_m = eval_loss_accuracy (X, y, g. ndata. train_mask, model, g)
80+ val_m = eval_loss_accuracy (X, y, g. ndata. val_mask, model, g)
81+ test_m = eval_loss_accuracy (X, y, g. ndata. test_mask, model, g)
82+ @info " " epoch train_m val_m test_m
83+ end
84+
85+ report (0 )
86+ for epoch in 1 : (args. epochs)
87+ grads = Flux. gradient (model) do model
88+ ŷ = model (g, X)
89+ logitcrossentropy (ŷ[:, g. ndata. train_mask], ytrain)
90+ end
91+ Flux. update! (opt, model, grads[1 ])
92+ epoch % args. infotime == 0 && report (epoch)
93+ end
94+ end
95+
96+ train ()
0 commit comments