@@ -67,19 +67,20 @@ def save_test_samples_images(model, dataloader, device, out_dir=".", num_samples
6767
6868 return results
6969
70- def save_test_accuracy_image (test_accuracy : float , out_dir : str = "." ) -> str :
71- """Create/overwrite an image summarizing test accuracy as text."""
70+ def save_test_accuracy_image (test_accuracies , out_dir : str = "." ) -> str :
71+ """
72+ Create/overwrite an image summarizing test accuracy as a graph, similar to train_accuracy.png.
73+ Accepts a list of test accuracies per batch or sample.
74+ """
7275 os .makedirs (out_dir , exist_ok = True )
73- plt .figure (figsize = (4 , 3 ))
74- plt .axis ('off' )
75- plt .text (0.5 , 0.6 , 'Test Accuracy' , ha = 'center' , va = 'center' , fontsize = 16 , weight = 'bold' )
76- plt .text (0.5 , 0.35 , f"{ test_accuracy * 100 :.2f} %" , ha = 'center' , va = 'center' , fontsize = 22 )
77- # Optional simple bar
78- acc = max (0.0 , min (1.0 , float (test_accuracy )))
79- plt .hlines (0.1 , 0.1 , 0.9 , colors = '#e0e0e0' , linewidth = 12 )
80- plt .hlines (0.1 , 0.1 , 0.1 + 0.8 * acc , colors = '#4caf50' , linewidth = 12 )
76+ plt .figure ()
77+ plt .plot (range (1 , len (test_accuracies )+ 1 ), test_accuracies , marker = 'o' , color = '#4caf50' )
78+ plt .title ('Test Accuracy' )
79+ plt .xlabel ('Batch' )
80+ plt .ylabel ('Accuracy' )
81+ plt .grid (True )
8182 out_path = os .path .join (out_dir , 'test_accuracy.png' )
82- plt .savefig (out_path , bbox_inches = 'tight' )
83+ plt .savefig (out_path )
8384 plt .close ()
8485 return out_path
8586def predict_and_log_samples (model , dataloader , device , num_samples = 5 ):
0 commit comments