Skip to content

Commit b24e9f6

Browse files
committed
fix: overwrite: test_accuracy.png
1 parent b88c8ca commit b24e9f6

1 file changed

Lines changed: 12 additions & 11 deletions

File tree

utils.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,19 +67,20 @@ def save_test_samples_images(model, dataloader, device, out_dir=".", num_samples
6767

6868
return results
6969

70-
def save_test_accuracy_image(test_accuracy: float, out_dir: str = ".") -> str:
71-
"""Create/overwrite an image summarizing test accuracy as text."""
70+
def save_test_accuracy_image(test_accuracies, out_dir: str = ".") -> str:
71+
"""
72+
Create/overwrite an image summarizing test accuracy as a graph, similar to train_accuracy.png.
73+
Accepts a list of test accuracies per batch or sample.
74+
"""
7275
os.makedirs(out_dir, exist_ok=True)
73-
plt.figure(figsize=(4, 3))
74-
plt.axis('off')
75-
plt.text(0.5, 0.6, 'Test Accuracy', ha='center', va='center', fontsize=16, weight='bold')
76-
plt.text(0.5, 0.35, f"{test_accuracy * 100:.2f}%", ha='center', va='center', fontsize=22)
77-
# Optional simple bar
78-
acc = max(0.0, min(1.0, float(test_accuracy)))
79-
plt.hlines(0.1, 0.1, 0.9, colors='#e0e0e0', linewidth=12)
80-
plt.hlines(0.1, 0.1, 0.1 + 0.8 * acc, colors='#4caf50', linewidth=12)
76+
plt.figure()
77+
plt.plot(range(1, len(test_accuracies)+1), test_accuracies, marker='o', color='#4caf50')
78+
plt.title('Test Accuracy')
79+
plt.xlabel('Batch')
80+
plt.ylabel('Accuracy')
81+
plt.grid(True)
8182
out_path = os.path.join(out_dir, 'test_accuracy.png')
82-
plt.savefig(out_path, bbox_inches='tight')
83+
plt.savefig(out_path)
8384
plt.close()
8485
return out_path
8586
def predict_and_log_samples(model, dataloader, device, num_samples=5):

0 commit comments

Comments
 (0)