@@ -51,7 +51,7 @@ def train(net, train_loader, loss=None, optimizer=None, scheduler=None, grad_cli
5151
5252 n_batches = len (train_loader )
5353 last_loss , acc , val_loss , val_acc = [torch .tensor (0.0 ) for _ in range (4 )]
54- metric_dict = {}
54+ metric_dict = {'train_loss' : [], 'train_accuracy' : [], 'val_loss' : [], 'val_accuracy' : [] }
5555 for epoch in range (num_epochs ):
5656 # checkpoint using a user defined function
5757 if checkpoint_every_k is not None and (epoch + 1 ) % checkpoint_every_k == 0 :
@@ -91,12 +91,15 @@ def train(net, train_loader, loss=None, optimizer=None, scheduler=None, grad_cli
9191 last_loss = train_loss / (batch_idx + 1 )
9292 acc = correct / total
9393
94+ # save metrics
95+ metric_dict ['train_loss' ].append (last_loss .item ())
96+ metric_dict ['train_accuracy' ].append (acc .item ())
97+
9498 if validation_loader is not None :
95- val_loss , val_acc = test (net , validation_loader , loss , num_tests = num_validation , device = device )
99+ test_metrics = test (net , validation_loader , loss , num_tests = num_validation , device = device )
100+ metric_dict ['val_loss' ].append (test_metrics ['test_loss' ])
101+ metric_dict ['val_accuracy' ].append (test_metrics ['test_accuracy' ])
96102
97- # save metrics
98- metric_dict = {'train_loss' : last_loss .item (), 'train_accuracy' : acc .item (),
99- 'val_loss' : val_loss .item (), 'val_accuracy' : val_acc .item ()}
100103 print ('--------------------' )
101104 scheduler .step ()
102105
0 commit comments