@@ -88,15 +88,34 @@ def forward(self, x):
8888)
8989
9090# Test Loop
91+ test_accuracy = correct / len (test_loader .dataset )
92+ test_loss /= len (test_loader )
93+ test_log += f"- **Test Loss**: { test_loss :.4f} \n "
94+ test_log += f"- **Test Accuracy**: { test_accuracy :.4f} \n \n "
95+ test_log += "## Sample Predictions\n \n "
96+ test_log += "| Image Index | True Label | Predicted Label |\n "
97+ test_log += "|-------------|------------|------------------|\n "
98+ test_log += predict_and_log_samples (model , test_loader , device )
99+ sample_imgs = save_test_samples_images (model , test_loader , device , out_dir = "images" , num_samples = 5 )
100+ test_log += "\n ### Sample Images\n "
101+ for i , (idx , label , pred , img_path ) in enumerate (sample_imgs , start = 1 ):
102+ test_log += "\n "
103+ save_markdown_log ("test_output.md" , test_log )
104+ # Mirror test log into docs/
105+ save_markdown_log (os .path .join ("docs" , "test_output.md" ), test_log )
91106model .eval ()
92107test_loss = 0
93108correct = 0
109+ test_accuracies = []
94110with torch .no_grad ():
95111 for inputs , labels in test_loader :
96112 inputs , labels = inputs .to (device ), labels .to (device )
97113 outputs = model (inputs )
98114 test_loss += criterion (outputs , labels ).item ()
99- correct += (outputs .argmax (1 ) == labels ).sum ().item ()
115+ batch_correct = (outputs .argmax (1 ) == labels ).sum ().item ()
116+ correct += batch_correct
117+ batch_acc = batch_correct / len (inputs )
118+ test_accuracies .append (batch_acc )
100119
101120test_accuracy = correct / len (test_loader .dataset )
102121test_loss /= len (test_loader )
@@ -115,7 +134,7 @@ def forward(self, x):
115134 test_log += f" True: { label } , Pred: { pred } \n "
116135
117136# Add a test accuracy image that always overwrites
118- save_test_accuracy_image (test_accuracy , out_dir = "images" )
137+ save_test_accuracy_image (test_accuracies , out_dir = "images" )
119138test_log += "\n ## Test Accuracy Visual\n "
120139test_log += "\n "
121140save_markdown_log ("test_output.md" , test_log )
0 commit comments