Skip to content

Commit b2693e4

Browse files
committed
agentic-fix(try-2): test_accurac issue
1 parent ae394aa commit b2693e4

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

train_mnist.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,8 @@ def forward(self, x):
104104
save_markdown_log("test_output.md", test_log)
105105
# Mirror test log into docs/
106106
save_markdown_log(os.path.join("docs", "test_output.md"), test_log)
107-
batch_correct = (outputs.argmax(1) == labels).sum().item()
108-
correct += batch_correct
107+
batch_correct = (outputs.argmax(1) == labels).sum().item()
108+
correct += batch_correct
109109
test_log += predict_and_log_samples(model, test_loader, device)
110110
sample_imgs = save_test_samples_images(model, test_loader, device, out_dir="images", num_samples=5)
111111
for i, (idx, label, pred, img_path) in enumerate(sample_imgs, start=1):
@@ -126,8 +126,8 @@ def forward(self, x):
126126
inputs, labels = inputs.to(device), labels.to(device)
127127
outputs = model(inputs)
128128
test_loss += criterion(outputs, labels).item()
129-
batch_correct = (outputs.argmax(1) == labels).sum().item()
130-
correct += batch_correct
129+
batch_correct = (outputs.argmax(1) == labels).sum().item()
130+
correct += batch_correct
131131
batch_acc = batch_correct / len(inputs)
132132
test_accuracies.append(batch_acc)
133133

0 commit comments

Comments
 (0)