Skip to content

Commit 65bb60a

Browse files
committed
Recording loss and accuracy throughout the training. Fixing metrics printing.
1 parent 18c32dc commit 65bb60a

2 files changed

Lines changed: 11 additions & 6 deletions

File tree

main.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,9 @@ def reload_checkpoint(path, device=None):
9191
result = train(net, train_loader, loss=criterion, optimizer=optimizer, scheduler=scheduler, grad_clip=args.grad_clip,
9292
num_epochs=args.epochs, num_validation=args.validation_size, validation_loader=valid_loader,
9393
device=args.device, print_frequency=args.print_freq)
94-
print(f"Final train metrics: {result}")
94+
95+
last_epoch = {k: v[-1] for k, v in result.items() if len(v) > 0}
96+
print(f"Final train metrics: {last_epoch}")
9597

9698
result = test(net, test_loader, loss=criterion, num_tests=test_size, device=args.device)
9799
print(f"\nFinal test metrics: {result}")

nasbench_pytorch/trainer.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ def train(net, train_loader, loss=None, optimizer=None, scheduler=None, grad_cli
5151

5252
n_batches = len(train_loader)
5353
last_loss, acc, val_loss, val_acc = [torch.tensor(0.0) for _ in range(4)]
54-
metric_dict = {}
54+
metric_dict = {'train_loss': [], 'train_accuracy': [], 'val_loss': [], 'val_accuracy': []}
5555
for epoch in range(num_epochs):
5656
# checkpoint using a user defined function
5757
if checkpoint_every_k is not None and (epoch + 1) % checkpoint_every_k == 0:
@@ -91,12 +91,15 @@ def train(net, train_loader, loss=None, optimizer=None, scheduler=None, grad_cli
9191
last_loss = train_loss / (batch_idx + 1)
9292
acc = correct / total
9393

94+
# save metrics
95+
metric_dict['train_loss'].append(last_loss.item())
96+
metric_dict['train_accuracy'].append(acc.item())
97+
9498
if validation_loader is not None:
95-
val_loss, val_acc = test(net, validation_loader, loss, num_tests=num_validation, device=device)
99+
test_metrics = test(net, validation_loader, loss, num_tests=num_validation, device=device)
100+
metric_dict['val_loss'].append(test_metrics['test_loss'])
101+
metric_dict['val_accuracy'].append(test_metrics['test_accuracy'])
96102

97-
# save metrics
98-
metric_dict = {'train_loss': last_loss.item(), 'train_accuracy': acc.item(),
99-
'val_loss': val_loss.item(), 'val_accuracy': val_acc.item()}
100103
print('--------------------')
101104
scheduler.step()
102105

0 commit comments

Comments (0)