-
Notifications
You must be signed in to change notification settings - Fork 55
Expand file tree
/
Copy pathGCN.cs
More file actions
137 lines (112 loc) · 5.71 KB
/
GCN.cs
File metadata and controls
137 lines (112 loc) · 5.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
using System.Diagnostics;
using System.Linq;
using TorchSharp;
using TorchSharp.Examples;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
using static TorchSharp.torch.nn.functional;
namespace CSharpExamples
{
/// <summary>
/// Graph Convolutional Network (GCN) for node classification
///
/// Based on: https://github.com/pytorch/examples/tree/main/gcn
///
/// Implements a 2-layer GCN for semi-supervised node classification.
/// Uses synthetic graph data for demonstration since the Cora dataset
/// requires external download infrastructure.
/// </summary>
public class GCN
{
    /// <summary>
    /// Trains and evaluates a 2-layer GCN on a synthetic random graph.
    /// </summary>
    /// <param name="epochs">Maximum number of training epochs.</param>
    /// <param name="timeout">Wall-clock budget in seconds; training stops early once exceeded.</param>
    /// <param name="logdir">Log directory (currently unused by this example; kept for a uniform Run signature).</param>
    internal static void Run(int epochs, int timeout, string logdir)
    {
        var device =
            torch.cuda.is_available() ? torch.CUDA :
            torch.mps_is_available() ? torch.MPS :
            torch.CPU;

        Console.WriteLine();
        Console.WriteLine($"\tRunning GCN on {device.type} for {epochs} epochs, terminating after {TimeSpan.FromSeconds(timeout)}.");
        Console.WriteLine();

        torch.random.manual_seed(42);

        // Create synthetic graph data for demonstration.
        // In practice, you would load a real graph dataset like Cora;
        // the sizes below mirror Cora's dimensions.
        int numNodes = 2708;
        int numFeatures = 1433;
        int numClasses = 7;
        int hiddenDim = 16;

        Console.WriteLine($"\tGenerating synthetic graph data...");
        Console.WriteLine($"\t  Nodes: {numNodes}, Features: {numFeatures}, Classes: {numClasses}");

        // Random node features and integer class labels.
        var features = torch.randn(numNodes, numFeatures, device: device);
        var labels = torch.randint(numClasses, numNodes, device: device);

        // Random edge endpoints simulating graph structure.
        int numEdges = 10556;
        var edgeIdx1 = torch.randint(numNodes, numEdges, device: device);
        var edgeIdx2 = torch.randint(numNodes, numEdges, device: device);

        // A~ = A + I: start from the identity (self-loops), then scatter the
        // random edges into the dense adjacency matrix. Previously the edge
        // indices were generated but never applied, leaving a self-loop-only graph.
        var adjMat = torch.eye(numNodes, device: device);
        adjMat.index_put_(torch.tensor(1.0f, device: device), edgeIdx1, edgeIdx2);
        // Treat the graph as undirected: make the adjacency symmetric.
        adjMat = torch.maximum(adjMat, adjMat.t());

        // Renormalization trick from Kipf & Welling: D^(-1/2) (A + I) D^(-1/2).
        var degree = adjMat.sum(dim: 1);
        var degreeInvSqrt = torch.sqrt(1.0f / degree);
        // Self-loops guarantee degree >= 1, but guard against inf defensively.
        degreeInvSqrt = torch.where(degreeInvSqrt.isinf(), torch.zeros_like(degreeInvSqrt), degreeInvSqrt);
        var degreeMatrix = torch.diag(degreeInvSqrt);
        adjMat = torch.mm(torch.mm(degreeMatrix, adjMat), degreeMatrix);

        // Split node indices into train (last 1208) / val (500) / test (first 1000).
        var idx = torch.randperm(numNodes, device: device);
        var idxTrain = idx.slice(0, 1500, numNodes, 1);
        var idxVal = idx.slice(0, 1000, 1500, 1);
        var idxTest = idx.slice(0, 0, 1000, 1);

        Console.WriteLine($"\tCreating GCN model...");
        var model = new GCNModel("gcn", numFeatures, hiddenDim, numClasses,
            useBias: true, dropoutP: 0.5, device: device);

        // NLLLoss pairs with log-probabilities — assumes GCNModel ends in
        // log_softmax (NOTE: confirm against GCNModel's forward).
        var optimizer = optim.Adam(model.parameters(), lr: 0.01, weight_decay: 5e-4);
        var criterion = NLLLoss();

        Console.WriteLine($"\tTraining...");

        Stopwatch totalTime = new Stopwatch();
        totalTime.Start();

        for (int epoch = 1; epoch <= epochs; epoch++) {
            // Per-epoch dispose scope keeps intermediate tensors from accumulating.
            using (var d = torch.NewDisposeScope()) {
                // Training step: full-batch gradient descent over the whole graph.
                model.train();
                optimizer.zero_grad();

                var output = model.forward(features, adjMat);
                var loss = criterion.forward(output.index(idxTrain), labels.index(idxTrain));
                loss.backward();
                optimizer.step();

                if (epoch % 20 == 0 || epoch == 1) {
                    // Periodic evaluation on train/val splits without gradients.
                    model.eval();
                    using (torch.no_grad()) {
                        var evalOutput = model.forward(features, adjMat);
                        var trainAcc = evalOutput.index(idxTrain).argmax(1)
                            .eq(labels.index(idxTrain)).to_type(ScalarType.Float32).mean().item<float>();
                        var valAcc = evalOutput.index(idxVal).argmax(1)
                            .eq(labels.index(idxVal)).to_type(ScalarType.Float32).mean().item<float>();

                        Console.WriteLine($"\tEpoch {epoch:D4} | Loss: {loss.item<float>():F4} | Train Acc: {trainAcc:F4} | Val Acc: {valAcc:F4}");
                    }
                }
            }

            // Respect the wall-clock budget.
            if (totalTime.Elapsed.TotalSeconds > timeout) break;
        }

        // Final test evaluation on the held-out split.
        model.eval();
        using (torch.no_grad()) {
            var testOutput = model.forward(features, adjMat);
            var testAcc = testOutput.index(idxTest).argmax(1)
                .eq(labels.index(idxTest)).to_type(ScalarType.Float32).mean().item<float>();
            Console.WriteLine($"\tTest accuracy: {testAcc:F4}");
        }

        totalTime.Stop();
        Console.WriteLine($"Elapsed time: {totalTime.Elapsed.TotalSeconds:F1} s.");
    }
}
}