@@ -16,7 +16,7 @@ def __getitem__(self, idx):
         # Compress spike train into windows for dimension reduction
         return self.samples[idx].flatten(), self.labels[idx]

-def classify_recalls(out_dim, train_ratio, batch_size):
+def classify_recalls(out_dim, train_ratio, batch_size, epochs):
     print("Classifying recalled memories...")

     ## Load recalled memory samples ##
@@ -38,7 +38,7 @@ def classify_recalls(out_dim, train_ratio, batch_size):
     ## Training ##
     loss_log = []
     accuracy_log = []
-    for epoch in range(20):
+    for epoch in range(epochs):
         total_loss = 0
         correct = 0
         for memory_batch, positions in train_loader:
@@ -68,13 +68,29 @@ def classify_recalls(out_dim, train_ratio, batch_size):
     ## Testing ##
     total = 0
     correct = 0
+    confusion_matrix = torch.zeros(25, 25)
+    out_of_bounds = 0
     with torch.no_grad():
         for memories, labels in test_loader:
             outputs = model(memories)
             loss = criterion(outputs, labels)
             total += len(labels)
             correct += torch.all(outputs.round() == labels.round(),
                                  dim=1).sum().item()  # Check if prediction for both x and y are correct
+            for t, p in zip(labels, outputs):
+                label_ind = int(t[0].round() * 5 + t[1].round())
+                pred_ind = int(p[0].round() * 5 + p[1].round())
+                if label_ind < 0 or label_ind >= 25 or pred_ind < 0 or pred_ind >= 25:
+                    out_of_bounds += 1
+                else:
+                    confusion_matrix[label_ind, pred_ind] += 1
+
+    plt.imshow(confusion_matrix)
+    plt.title('Confusion Matrix')
+    plt.xlabel('Predicted')
+    plt.ylabel('True Label')
+    plt.colorbar()
+    plt.show()

     print(f'Accuracy: {round(correct / total, 3) * 100}%')
0 commit comments