Skip to content

Commit 59c732b

Browse files
committed
Working classification for grid cell memories
1 parent e06080f commit 59c732b

5 files changed

Lines changed: 108 additions & 90 deletions

File tree

scripts/Chris/DQN/Eval.ipynb

Lines changed: 74 additions & 74 deletions
Large diffs are not rendered by default.

scripts/Chris/DQN/classify_recalls.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __getitem__(self, idx):
1616
# Compress spike train into windows for dimension reduction
1717
return self.samples[idx].flatten(), self.labels[idx]
1818

19-
def classify_recalls(out_dim, train_ratio, batch_size):
19+
def classify_recalls(out_dim, train_ratio, batch_size, epochs):
2020
print("Classifying recalled memories...")
2121

2222
## Load recalled memory samples ##
@@ -38,7 +38,7 @@ def classify_recalls(out_dim, train_ratio, batch_size):
3838
## Training ##
3939
loss_log = []
4040
accuracy_log = []
41-
for epoch in range(20):
41+
for epoch in range(epochs):
4242
total_loss = 0
4343
correct = 0
4444
for memory_batch, positions in train_loader:
@@ -68,13 +68,29 @@ def classify_recalls(out_dim, train_ratio, batch_size):
6868
## Testing ##
6969
total = 0
7070
correct = 0
71+
confusion_matrix = torch.zeros(25, 25)
72+
out_of_bounds = 0
7173
with torch.no_grad():
7274
for memories, labels in test_loader:
7375
outputs = model(memories)
7476
loss = criterion(outputs, labels)
7577
total += len(labels)
7678
correct += torch.all(outputs.round() == labels.round(),
7779
dim=1).sum().item() # Check if prediction for both x and y are correct
80+
for t, p in zip(labels, outputs):
81+
label_ind = int(t[0].round() * 5 + t[1].round())
82+
pred_ind = int(p[0].round() * 5 + p[1].round())
83+
if label_ind < 0 or label_ind >= 25 or pred_ind < 0 or pred_ind >= 25:
84+
out_of_bounds += 1
85+
else:
86+
confusion_matrix[label_ind, pred_ind] += 1
87+
88+
plt.imshow(confusion_matrix)
89+
plt.title('Confusion Matrix')
90+
plt.xlabel('Predicted')
91+
plt.ylabel('True Label')
92+
plt.colorbar()
93+
plt.show()
7894

7995
print(f'Accuracy: {round(correct / total, 3)*100}%')
8096

scripts/Chris/DQN/pipeline_executor.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
## Constants ##
1212
WIDTH = 5
1313
HEIGHT = 5
14-
SAMPLES_PER_POS = 10
14+
SAMPLES_PER_POS = 1000
1515
NOISE = 0.1 # Noise in sampling
1616
WINDOW_FREQ = 10
1717
WINDOW_SIZE = 10
@@ -29,6 +29,7 @@
2929
OUT_DIM = 2
3030
TRAIN_RATIO = 0.8
3131
BATCH_SIZE = 10
32+
TRAIN_EPOCHS = 15
3233
PLOT = True
3334
exc_hyper_params = {
3435
'thresh_exc': -55,
@@ -63,15 +64,15 @@
6364
#
6465
# # Spike Train Generation ##
6566
# spike_trains, labels, sorted_spike_trains = spike_train_generator(samples, labels, SIM_TIME, GC_MULTIPLES, MAX_SPIKE_FREQ)
66-
67-
# ## Association (Store) ##
68-
store_reservoir(EXC_SIZE, INH_SIZE, STORE_SAMPLES, NUM_CELLS, GC_MULTIPLES, SIM_TIME, hyper_params, PLOT)
69-
70-
# ## Association (Recall) ##
71-
recall_reservoir(EXC_SIZE, INH_SIZE, SIM_TIME, PLOT)
72-
67+
#
68+
# # ## Association (Store) ##
69+
# store_reservoir(EXC_SIZE, INH_SIZE, STORE_SAMPLES, NUM_CELLS, GC_MULTIPLES, SIM_TIME, hyper_params, PLOT)
70+
#
71+
# # ## Association (Recall) ##
72+
# recall_reservoir(EXC_SIZE, INH_SIZE, SIM_TIME, PLOT)
73+
#
7374
# # Preprocess Recalls ##
7475
# recalled_mem_preprocessing(WINDOW_FREQ, WINDOW_SIZE, PLOT)
7576

76-
## Train ANN ##
77-
# classify_recalls(OUT_DIM, TRAIN_RATIO, BATCH_SIZE)
77+
# Train ANN ##
78+
classify_recalls(OUT_DIM, TRAIN_RATIO, BATCH_SIZE, TRAIN_EPOCHS)

scripts/Chris/DQN/recalled_mem_preprocessing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def recalled_mem_preprocessing(window_freq, window_size, plot):
5656
# plt.tight_layout()
5757
# plt.show()
5858

59+
positions = np.array([key for key in new_samples_sorted.keys()])
5960
fig = plt.figure(figsize=(10, 10))
6061
gs = fig.add_gridspec(nrows=5, ncols=5)
6162
for i, pos in enumerate(positions):

scripts/Chris/DQN/store_reservoir.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ def store_reservoir(exc_size, inh_size, num_samples, num_grid_cells, gc_multiple
1616
w_exc_exc = torch.rand(exc_size, exc_size)
1717
w_exc_inh = torch.rand(exc_size, inh_size)
1818
w_inh_exc = -torch.rand(inh_size, exc_size)
19-
w_inh_inh = torch.rand(inh_size, inh_size)
19+
w_inh_inh = -torch.rand(inh_size, inh_size)
2020
w_in_exc = sparsify(w_in_exc, 0.85) # Zero out x% of weights
2121
w_in_inh = sparsify(w_in_inh, 0.85)
22-
w_exc_exc = sparsify(w_exc_exc, 0.85)
23-
w_exc_inh = sparsify(w_exc_inh, 0.85)
24-
w_inh_exc = sparsify(w_inh_exc, 0.85)
22+
w_exc_exc = sparsify(w_exc_exc, 0.8)
23+
w_exc_inh = sparsify(w_exc_inh, 0.5)
24+
w_inh_exc = sparsify(w_inh_exc, 0.7)
2525
w_inh_inh = sparsify(w_inh_inh, 0.85)
2626
res = Reservoir(in_size, exc_size, inh_size, hyper_params,
2727
w_in_exc, w_in_inh, w_exc_exc, w_exc_inh, w_inh_exc, w_inh_inh)

0 commit comments

Comments
 (0)