# An example of semi-supervised node classification on CiteSeer.
# Ported from the Cora example — same GCN architecture, different dataset.
# See node_classification_cora.jl for a detailed walk-through.
using Flux
using Flux: onecold, onehotbatch
using Flux.Losses: logitcrossentropy
using GraphNeuralNetworks
using MLDatasets: CiteSeer
using Statistics, Random
#using CUDA
#CUDA.allowscalar(false)

# Computes cross-entropy loss and accuracy on a boolean node mask.
function eval_loss_accuracy(X, y, mask, model, g)
    ŷ = model(g, X)
    l = logitcrossentropy(ŷ[:, mask], y[:, mask])
    acc = mean(onecold(ŷ[:, mask]) .== onecold(y[:, mask]))
    return (loss = round(l, digits = 4),
            acc = round(acc * 100, digits = 2))
end
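
# The masking semantics above, as a tiny commented sketch (hypothetical values):
#   ŷ = [0.9 0.1 0.3; 0.1 0.9 0.7]   # 2 classes × 3 nodes
#   mask = [true, false, true]       # select nodes 1 and 3
#   ŷ[:, mask]                       # 2×2 matrix holding columns 1 and 3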

# Hyperparameters
Base.@kwdef mutable struct Args
    η = 1.0f-3        # learning rate
    epochs = 200      # total training epochs
    seed = 17         # RNG seed (set > 0 for reproducibility)
    # usecuda = true  # use GPU when available
    nhidden = 128     # hidden-layer width
    infotime = 10     # log every `infotime` epochs
end
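
# Any field can be overridden per run through `train`'s keyword arguments,
# e.g. (hypothetical values): train(η = 5.0f-4, nhidden = 64, epochs = 300)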

# Main training function
function train(; kws...)
    args = Args(; kws...)
    args.seed > 0 && Random.seed!(args.seed)

    # Device selection (CPU only here; see the GPU note at the end of the file)
    device = cpu
    @info "Training on CPU"

    # Load dataset
    dataset = CiteSeer()
    classes = dataset.metadata["classes"]
    g = mldataset2gnngraph(dataset) |> device
    X = g.ndata.features
    y = onehotbatch(g.ndata.targets |> cpu, classes) |> device
    display(g)
    @show length(classes)
    @show is_bidirected(g)
    @show has_self_loops(g)

    # Model: two GCN layers with dropout, followed by a linear classifier
    nin, nhidden, nout = size(X, 1), args.nhidden, length(classes)
    model = GNNChain(GCNConv(nin => nhidden, relu),
                     Dropout(0.5),
                     GCNConv(nhidden => nhidden, relu),
                     Dense(nhidden, nout)) |> device
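
    # Background note, not taken from this file: GCNConv follows the
    # Kipf & Welling (2017) propagation rule, roughly
    # X' = D̂^(-1/2) Â D̂^(-1/2) X W with Â = A + I, and the layer typically
    # adds self-loops itself, which is why the has_self_loops check above
    # is worth printing.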
    opt = Flux.setup(Adam(args.η), model)

    # Training loop: full-batch gradient descent on the training mask
    ytrain = y[:, g.ndata.train_mask]

    function report(epoch)
        train_m = eval_loss_accuracy(X, y, g.ndata.train_mask, model, g)
        val_m = eval_loss_accuracy(X, y, g.ndata.val_mask, model, g)
        test_m = eval_loss_accuracy(X, y, g.ndata.test_mask, model, g)
        @info "" epoch train_m val_m test_m
    end

    report(0)
    for epoch in 1:args.epochs
        # The loss is computed on training nodes only, but gradients flow
        # through convolutions over the whole graph (semi-supervised setting).
        grads = Flux.gradient(model) do model
            ŷ = model(g, X)
            logitcrossentropy(ŷ[:, g.ndata.train_mask], ytrain)
        end
        Flux.update!(opt, model, grads[1])
        epoch % args.infotime == 0 && report(epoch)
    end
end

train()
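
# To run on GPU instead, a possible sketch (assumes CUDA.jl is installed
# and a working device is present; untested here):
#   using CUDA
#   CUDA.allowscalar(false)
#   device = CUDA.functional() ? Flux.gpu : Flux.cpu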