-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtrain_for_classification.py
More file actions
160 lines (131 loc) · 5.51 KB
/
train_for_classification.py
File metadata and controls
160 lines (131 loc) · 5.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import torch
import time
from sklearn.metrics import (
precision_score,
recall_score,
f1_score
)
def initialize_metrics_storage():
    """Create an empty per-epoch metric history.

    Returns:
        dict: Maps each history name ('losses', 'accuracies', 'precisions',
        'recalls', 'f1_scores', in that order) to its own fresh empty list.
    """
    history_names = ('losses', 'accuracies', 'precisions', 'recalls', 'f1_scores')
    # A comprehension guarantees every key gets a distinct list object.
    return {name: [] for name in history_names}
def train_step(model, optimizer, criterion, data):
    """Run one full-batch optimisation step on the training nodes.

    Args:
        model: Module called as ``model(data.x, data.edge_index)``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss callable applied to (logits, targets).
        data: Object exposing ``x``, ``edge_index``, ``y`` and ``train_mask``.

    Returns:
        float: The training loss computed before the parameter update.
    """
    model.train()
    optimizer.zero_grad()
    selected = data.train_mask
    logits = model(data.x, data.edge_index)
    step_loss = criterion(logits[selected], data.y[selected])
    step_loss.backward()
    optimizer.step()
    return step_loss.item()
def validate_step(model, data):
    """Evaluate ``model`` on the validation split.

    Thin wrapper that delegates to ``calculate_metrics`` using ``data.val_mask``.

    Returns:
        dict: The metric dict produced by ``calculate_metrics``.
    """
    return calculate_metrics(model, data, mask_type='val')
def train(num_epochs, data, model, optimizer, criterion, log_every=100):
    """Train ``model`` for ``num_epochs`` full-batch steps, tracking metrics.

    Each epoch performs one optimisation step. It then evaluates the model on
    the training and validation nodes and appends the results to per-split
    histories.

    Args:
        num_epochs (int): Number of epochs (optimisation steps) to run.
        data: Graph data object exposing ``x``, ``edge_index``, ``y``,
            ``train_mask`` and ``val_mask``.
        model: Module called as ``model(data.x, data.edge_index)``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss callable applied to (logits, targets).
        log_every (int | None): Print a progress line every ``log_every``
            epochs. The default of 100 matches the previous hard-coded
            interval. Pass 0 or None to disable logging.

    Returns:
        dict: ``{'train': history, 'val': history}``. Each history is the
        dict from ``initialize_metrics_storage`` with one entry per epoch.
        ``val['losses']`` stays empty because no validation loss is computed.
    """
    train_metrics = initialize_metrics_storage()
    val_metrics = initialize_metrics_storage()
    for epoch in range(1, num_epochs + 1):
        # Training step; the returned loss is measured before the update.
        train_loss = train_step(model, optimizer, criterion, data)
        # Training metrics are measured after the update, in eval mode.
        train_metrics_epoch = calculate_metrics(model, data, 'train')
        update_metrics(train_metrics, train_metrics_epoch, train_loss)
        # Validation step (no loss is recorded for this split).
        val_metrics_epoch = validate_step(model, data)
        update_metrics(val_metrics, val_metrics_epoch)
        # Periodic logging; a falsy log_every disables it.
        if log_every and epoch % log_every == 0:
            log_epoch(epoch, train_loss, train_metrics_epoch, val_metrics_epoch)
    return {
        'train': train_metrics,
        'val': val_metrics
    }
def calculate_metrics(model, data, mask_type='train'):
    """Compute classification metrics on the nodes selected by a split mask.

    The mask is read from ``data.<mask_type>_mask``. It may be a boolean mask
    or an index tensor; any selector that indexes ``data.y`` works.

    Note: the model is switched to eval mode and left there.

    Args:
        model: Module called as ``model(data.x, data.edge_index)``.
        data: Graph data object exposing ``x``, ``edge_index``, ``y`` and the
            requested ``<mask_type>_mask`` attribute.
        mask_type (str): Split name, e.g. 'train', 'val' or 'test'.

    Returns:
        dict: 'accuracy', 'precision', 'recall' and 'f1_score'. Precision,
        recall and F1 use sklearn's ``average='weighted'`` with
        ``zero_division=0``.

    Raises:
        ValueError: If the mask selects no nodes.
    """
    mask = getattr(data, f"{mask_type}_mask")
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        pred = out[mask].argmax(dim=1)
        target = data.y[mask]
        # Count the selected targets rather than summing the mask, so index
        # tensors (not just boolean masks) give the right denominator.
        num_selected = target.numel()
        if num_selected == 0:
            raise ValueError(
                f"{mask_type}_mask selects no nodes; cannot compute metrics"
            )
        correct = (pred == target).sum()
        accuracy = int(correct) / num_selected
    y_true = target.cpu().numpy()
    y_pred = pred.cpu().numpy()
    precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1
    }
def update_metrics(metrics, metrics_epoch, loss=None):
    """Append one epoch's results to a metric history in place.

    Args:
        metrics (dict): History as built by ``initialize_metrics_storage``.
        metrics_epoch (dict): Per-epoch values keyed 'accuracy', 'precision',
            'recall' and 'f1_score'.
        loss (float | None): Appended to ``metrics['losses']`` unless None.
    """
    if loss is not None:
        metrics['losses'].append(loss)
    # (history list name, per-epoch key) pairs, in the original append order.
    field_map = (
        ('accuracies', 'accuracy'),
        ('precisions', 'precision'),
        ('recalls', 'recall'),
        ('f1_scores', 'f1_score'),
    )
    for history_name, epoch_key in field_map:
        metrics[history_name].append(metrics_epoch[epoch_key])
def log_epoch(epoch, train_loss, train_metrics_epoch, val_metrics_epoch):
    """Print a two-line progress summary for one epoch.

    The first line holds the epoch number, training loss and training metrics.
    The second holds the validation metrics. Values are printed to four
    decimals.
    """
    def _summary(values):
        # "Acc: x - Prec: x - Rec: x - F1: x" with four decimals each.
        parts = (
            f'Acc: {values["accuracy"]:.4f}',
            f'Prec: {values["precision"]:.4f}',
            f'Rec: {values["recall"]:.4f}',
            f'F1: {values["f1_score"]:.4f}',
        )
        return ' - '.join(parts)

    # Format both lines before printing anything.
    train_line = f'Epoch {epoch:03d}, Loss: {train_loss:.4f}, Train - {_summary(train_metrics_epoch)}'
    val_line = f'Val - {_summary(val_metrics_epoch)}'
    print(train_line)
    print(val_line)
def train_multi_models(classifier,
                       models,
                       data,
                       hidden_dim,
                       num_classes,
                       num_epochs=100,
                       lr=0.01,
                       weight_decay=0.0005,
                       device='cuda'):
    """
    Train and evaluate several encoder architectures for node classification.

    For each entry in ``models``, an encoder is instantiated, wrapped by
    ``classifier``, moved to ``device`` and trained with Adam and
    cross-entropy via ``train``.

    Args:
        classifier (callable): Wrapper called as ``classifier(encoder)``. It
            must return a ``torch.nn.Module`` whose forward takes
            ``(x, edge_index)`` and returns class logits.
        models (dict): Maps model names to encoder classes (uninstantiated).
            Each class is called with the keyword arguments ``input_dim``,
            ``hidden_dim`` and ``out_dim``.
        data (torch_geometric.data.Data): Graph data object. It is moved to
            ``device`` and must expose ``num_features``.
        hidden_dim (int): Hidden dimension for the encoder.
        num_classes (int): Number of target classes (passed as ``out_dim``).
        num_epochs (int): Number of training epochs. Default is 100.
        lr (float): Learning rate. Default is 0.01.
        weight_decay (float): Weight decay for Adam. Default is 0.0005.
        device (str): Device to run the models on ('cuda' or 'cpu').

    Returns:
        tuple: ``(metrics, trained_models)``, where
            metrics (dict): Maps each model name to the train/val history
                dict returned by ``train``.
            trained_models (dict): Maps each model name to its trained model
                instance.
    """
    # Prepare data and loss criterion
    data = data.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    metrics = {}
    trained_models = {}  # To store the trained model instances
    for model_name, model_class in models.items():
        print(f"\n### Training {model_name}...")
        # Instantiate the encoder, wrap it with the classifier, move to device
        model = classifier(model_class(input_dim=data.num_features,
                                       hidden_dim=hidden_dim,
                                       out_dim=num_classes)).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        # perf_counter is monotonic, so it suits interval timing.
        start_time = time.perf_counter()
        train_val_metrics = train(num_epochs, data, model, optimizer, criterion)
        elapsed_time = time.perf_counter() - start_time
        metrics[model_name] = train_val_metrics
        trained_models[model_name] = model
        print(f"{model_name} training completed in {elapsed_time:.2f} seconds.")
    return metrics, trained_models