@@ -104,8 +104,8 @@ def forward(self, x):
104104save_markdown_log ("test_output.md" , test_log )
105105# Mirror test log into docs/
106106save_markdown_log (os .path .join ("docs" , "test_output.md" ), test_log )
107- batch_correct = (outputs .argmax (1 ) == labels ).sum ().item ()
108- correct += batch_correct
107+ batch_correct = (outputs .argmax (1 ) == labels ).sum ().item ()
108+ correct += batch_correct
109109test_log += predict_and_log_samples (model , test_loader , device )
110110sample_imgs = save_test_samples_images (model , test_loader , device , out_dir = "images" , num_samples = 5 )
111111for i , (idx , label , pred , img_path ) in enumerate (sample_imgs , start = 1 ):
@@ -126,8 +126,8 @@ def forward(self, x):
126126 inputs , labels = inputs .to (device ), labels .to (device )
127127 outputs = model (inputs )
128128 test_loss += criterion (outputs , labels ).item ()
129- batch_correct = (outputs .argmax (1 ) == labels ).sum ().item ()
130- correct += batch_correct
129+ batch_correct = (outputs .argmax (1 ) == labels ).sum ().item ()
130+ correct += batch_correct
131131 batch_acc = batch_correct / len (inputs )
132132 test_accuracies .append (batch_acc )
133133
0 commit comments