Skip to content

Commit c6644c2

Browse files
committed
fix: test_accuracy
1 parent b24e9f6 commit c6644c2

1 file changed

Lines changed: 21 additions & 2 deletions

File tree

train_mnist.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,34 @@ def forward(self, x):
8888
)
8989

9090
# Test Loop
91+
test_accuracy = correct / len(test_loader.dataset)
92+
test_loss /= len(test_loader)
93+
test_log += f"- **Test Loss**: {test_loss:.4f}\n"
94+
test_log += f"- **Test Accuracy**: {test_accuracy:.4f}\n\n"
95+
test_log += "## Sample Predictions\n\n"
96+
test_log += "| Image Index | True Label | Predicted Label |\n"
97+
test_log += "|-------------|------------|------------------|\n"
98+
test_log += predict_and_log_samples(model, test_loader, device)
99+
sample_imgs = save_test_samples_images(model, test_loader, device, out_dir="images", num_samples=5)
100+
test_log += "\n### Sample Images\n"
101+
for i, (idx, label, pred, img_path) in enumerate(sample_imgs, start=1):
102+
test_log += "![Test Accuracy](images/test_accuracy.png)\n"
103+
save_markdown_log("test_output.md", test_log)
104+
# Mirror test log into docs/
105+
save_markdown_log(os.path.join("docs", "test_output.md"), test_log)
91106
model.eval()
92107
test_loss = 0
93108
correct = 0
109+
test_accuracies = []
94110
with torch.no_grad():
95111
for inputs, labels in test_loader:
96112
inputs, labels = inputs.to(device), labels.to(device)
97113
outputs = model(inputs)
98114
test_loss += criterion(outputs, labels).item()
99-
correct += (outputs.argmax(1) == labels).sum().item()
115+
batch_correct = (outputs.argmax(1) == labels).sum().item()
116+
correct += batch_correct
117+
batch_acc = batch_correct / len(inputs)
118+
test_accuracies.append(batch_acc)
100119

101120
test_accuracy = correct / len(test_loader.dataset)
102121
test_loss /= len(test_loader)
@@ -115,7 +134,7 @@ def forward(self, x):
115134
test_log += f"![Sample {i:02d}](images/{img_name}) True: {label}, Pred: {pred}\n"
116135

117136
# Add a test accuracy image that always overwrites
118-
save_test_accuracy_image(test_accuracy, out_dir="images")
137+
save_test_accuracy_image(test_accuracies, out_dir="images")
119138
test_log += "\n## Test Accuracy Visual\n"
120139
test_log += "![Test Accuracy](images/test_accuracy.png)\n"
121140
save_markdown_log("test_output.md", test_log)

0 commit comments

Comments
 (0)