Skip to content

Commit e06080f

Browse files
committed
Split inhib/exc populations
1 parent 8e1fa07 commit e06080f

5 files changed

Lines changed: 226 additions & 102 deletions

File tree

scripts/Chris/DQN/Eval.ipynb

Lines changed: 66 additions & 10 deletions
Large diffs are not rendered by default.

scripts/Chris/DQN/Reservoir.py

Lines changed: 77 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,57 +7,104 @@
77

88

99
class Reservoir(Network):
10-
def __init__(self, in_size, res_size, hyper_params,
11-
w_in_res, w_res_res, device='cpu'):
10+
def __init__(self, in_size, exc_size, inh_size, hyper_params,
11+
w_in_exc, w_in_inh, w_exc_exc, w_exc_inh, w_inh_exc, w_inh_inh,
12+
device='cpu'):
1213
super().__init__()
1314

1415
## Layers ##
1516
input = Input(n=in_size)
16-
res = AdaptiveLIFNodes(
17-
n=res_size,
18-
thresh=hyper_params['thresh'],
19-
theta_plus=hyper_params['theta_plus'],
20-
refrac=hyper_params['refrac'],
21-
reset=hyper_params['reset'],
22-
tc_theta_decay=hyper_params['tc_theta_decay'],
23-
tc_decay=hyper_params['tc_decay'],
17+
res_exc = AdaptiveLIFNodes(
18+
n=exc_size,
19+
thresh=hyper_params['thresh_exc'],
20+
theta_plus=hyper_params['theta_plus_exc'],
21+
refrac=hyper_params['refrac_exc'],
22+
reset=hyper_params['reset_exc'],
23+
tc_theta_decay=hyper_params['tc_theta_decay_exc'],
24+
tc_decay=hyper_params['tc_decay_exc'],
2425
traces=True,
2526
)
26-
res_monitor = Monitor(res, ["s"], device=device)
27-
self.add_monitor(res_monitor, name='res_monitor')
28-
self.res_monitor = res_monitor
27+
exc_monitor = Monitor(res_exc, ["s"], device=device)
28+
self.add_monitor(exc_monitor, name='res_monitor_exc')
29+
self.exc_monitor = exc_monitor
30+
res_inh = AdaptiveLIFNodes(
31+
n=inh_size,
32+
thresh=hyper_params['thresh_inh'],
33+
theta_plus=hyper_params['theta_plus_inh'],
34+
refrac=hyper_params['refrac_inh'],
35+
reset=hyper_params['reset_inh'],
36+
tc_theta_decay=hyper_params['tc_theta_decay_inh'],
37+
tc_decay=hyper_params['tc_decay_inh'],
38+
traces=True,
39+
)
40+
inh_monitor = Monitor(res_inh, ["s"], device=device)
41+
self.add_monitor(inh_monitor, name='res_monitor_inh')
42+
self.inh_monitor = inh_monitor
2943
self.add_layer(input, name='input')
30-
self.add_layer(res, name='res')
44+
self.add_layer(res_exc, name='res_exc')
45+
self.add_layer(res_inh, name='res_inh')
3146

3247
## Connections ##
33-
in_res_wfeat = Weight(name='in_res_weight_feature', value=w_in_res,)
34-
in_res_conn = MulticompartmentConnection(
35-
source=input, target=res,
36-
device=device, pipeline=[in_res_wfeat],
48+
in_exc_wfeat = Weight(name='in_exc_weight_feature', value=w_in_exc,)
49+
in_exc_conn = MulticompartmentConnection(
50+
source=input, target=res_exc,
51+
device=device, pipeline=[in_exc_wfeat],
52+
)
53+
in_inh_wfeat = Weight(name='in_inh_weight_feature', value=w_in_inh,)
54+
in_inh_conn = MulticompartmentConnection(
55+
source=input, target=res_inh,
56+
device=device, pipeline=[in_inh_wfeat],
57+
)
58+
59+
exc_exc_wfeat = Weight(name='exc_exc_weight_feature', value=w_exc_exc,)
60+
# learning_rule=MSTDP,
61+
# nu=hyper_params['nu_exc_exc'], range=hyper_params['range_exc_exc'], decay=hyper_params['decay_exc_exc'])
62+
exc_exc_conn = MulticompartmentConnection(
63+
source=res_exc, target=res_exc,
64+
device=device, pipeline=[exc_exc_wfeat],
65+
)
66+
exc_inh_wfeat = Weight(name='exc_inh_weight_feature', value=w_exc_inh,)
67+
# learning_rule=MSTDP,
68+
# nu=hyper_params['nu_exc_inh'], range=hyper_params['range_exc_inh'], decay=hyper_params['decay_exc_inh'])
69+
exc_inh_conn = MulticompartmentConnection(
70+
source=res_exc, target=res_inh,
71+
device=device, pipeline=[exc_inh_wfeat],
72+
)
73+
inh_exc_wfeat = Weight(name='inh_exc_weight_feature', value=w_inh_exc,)
74+
# learning_rule=MSTDP,
75+
# nu=hyper_params['nu_inh_exc'], range=hyper_params['range_inh_exc'], decay=hyper_params['decay_inh_exc'])
76+
inh_exc_conn = MulticompartmentConnection(
77+
source=res_inh, target=res_exc,
78+
device=device, pipeline=[inh_exc_wfeat],
3779
)
38-
res_res_wfeat = Weight(name='res_res_weight_feature', value=w_res_res,
80+
inh_inh_wfeat = Weight(name='inh_inh_weight_feature', value=w_inh_inh,)
3981
# learning_rule=MSTDP,
40-
nu=hyper_params['nu'], range=hyper_params['range'], decay=hyper_params['decay'])
41-
res_res_conn = MulticompartmentConnection(
42-
source=res, target=res,
43-
device=device, pipeline=[res_res_wfeat],
82+
# nu=hyper_params['nu_inh_inh'], range=hyper_params['range_inh_inh'], decay=hyper_params['decay_inh_inh'])
83+
inh_inh_conn = MulticompartmentConnection(
84+
source=res_inh, target=res_inh,
85+
device=device, pipeline=[inh_inh_wfeat],
4486
)
45-
self.add_connection(in_res_conn, source='input', target='res')
46-
self.add_connection(res_res_conn, source='res', target='res')
47-
self.res_res_conn = res_res_conn
87+
self.add_connection(in_exc_conn, source='input', target='res_exc')
88+
self.add_connection(in_inh_conn, source='input', target='res_inh')
89+
self.add_connection(exc_exc_conn, source='res_exc', target='res_exc')
90+
self.add_connection(exc_inh_conn, source='res_exc', target='res_inh')
91+
self.add_connection(inh_exc_conn, source='res_inh', target='res_exc')
92+
self.add_connection(inh_inh_conn, source='res_inh', target='res_inh')
4893

4994
## Migrate ##
5095
self.to(device)
5196

5297
def store(self, spike_train, sim_time):
5398
self.learning = True
5499
self.run(inputs={'input': spike_train}, time=sim_time, reward=1)
55-
res_spikes = self.res_monitor.get('s')
100+
exc_spikes = self.exc_monitor.get('s')
101+
inh_spikes = self.inh_monitor.get('s')
56102
self.learning = False
57-
return res_spikes
103+
return exc_spikes, inh_spikes
58104

59105
def recall(self, spike_train, sim_time):
60106
self.learning = False
61107
self.run(inputs={'input': spike_train}, time=sim_time,)
62-
res_spikes = self.res_monitor.get('s')
63-
return res_spikes
108+
exc_spikes = self.exc_monitor.get('s')
109+
inh_spikes = self.inh_monitor.get('s')
110+
return exc_spikes, inh_spikes

scripts/Chris/DQN/pipeline_executor.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,53 +11,64 @@
1111
## Constants ##
1212
WIDTH = 5
1313
HEIGHT = 5
14-
SAMPLES_PER_POS = 1000
14+
SAMPLES_PER_POS = 10
1515
NOISE = 0.1 # Noise in sampling
1616
WINDOW_FREQ = 10
1717
WINDOW_SIZE = 10
1818
NUM_CELLS = 20
1919
X_RANGE = (0, 5)
2020
Y_RANGE = (0, 5)
2121
SIM_TIME = 50
22-
MAX_SPIKE_FREQ = 0.3
22+
MAX_SPIKE_FREQ = 0.8
2323
GC_MULTIPLES = 1
24-
RES_SIZE = 250
24+
EXC_SIZE = 250
25+
INH_SIZE = 50
2526
STORE_SAMPLES = 100
26-
hyper_params = {
27-
'thresh': -55,
28-
'theta_plus': 0,
29-
'refrac': 1,
30-
'reset': -65,
31-
'tc_theta_decay': 500,
32-
'tc_decay': 30,
33-
'nu': (0.01, -0.01),
34-
'range': [-1, 1],
35-
'decay': None,
36-
}
3727
WINDOW_FREQ = 10
3828
WINDOW_SIZE = 10
3929
OUT_DIM = 2
4030
TRAIN_RATIO = 0.8
4131
BATCH_SIZE = 10
4232
PLOT = True
33+
exc_hyper_params = {
34+
'thresh_exc': -55,
35+
'theta_plus_exc': 0,
36+
'refrac_exc': 1,
37+
'reset_exc': -65,
38+
'tc_theta_decay_exc': 500,
39+
'tc_decay_exc': 30,
40+
# 'nu': (0.01, -0.01),
41+
# 'range': [-1, 1],
42+
# 'decay': None,
43+
}
44+
inh_hyper_params = {
45+
'thresh_inh': -55,
46+
'theta_plus_inh': 0,
47+
'refrac_inh': 1,
48+
'reset_inh': -65,
49+
'tc_theta_decay_inh': 500,
50+
'tc_decay_inh': 30,
51+
}
52+
hyper_params = exc_hyper_params | inh_hyper_params
4353

4454
## Sample Generation ##
45-
x_offsets = np.random.uniform(-1, 1, NUM_CELLS)
46-
y_offsets = np.random.uniform(-1, 1, NUM_CELLS)
47-
offsets = list(zip(x_offsets, y_offsets)) # Grid Cell x & y offsets
48-
scales = [np.random.uniform(1.7, 5) for i in range(NUM_CELLS)] # Dist. between Grid Cell peaks
49-
vars = [.85] * NUM_CELLS # Variance of Grid Cell activity
55+
# x_offsets = np.random.uniform(-1, 1, NUM_CELLS)
56+
#
57+
# y_offsets = np.random.uniform(-1, 1, NUM_CELLS)
58+
# offsets = list(zip(x_offsets, y_offsets)) # Grid Cell x & y offsets
59+
# scales = [np.random.uniform(1.7, 5) for i in range(NUM_CELLS)] # Dist. between Grid Cell peaks
60+
# vars = [.85] * NUM_CELLS # Variance of Grid Cell activity
5061
# samples, labels, sorted_samples = sample_generator(scales, offsets, vars, X_RANGE, Y_RANGE, SAMPLES_PER_POS,
5162
# noise=NOISE, padding=1, plot=PLOT)
5263
#
53-
# ## Spike Train Generation ##
64+
# # Spike Train Generation ##
5465
# spike_trains, labels, sorted_spike_trains = spike_train_generator(samples, labels, SIM_TIME, GC_MULTIPLES, MAX_SPIKE_FREQ)
5566

56-
## Association (Store) ##
57-
store_reservoir(RES_SIZE, STORE_SAMPLES, NUM_CELLS, GC_MULTIPLES, SIM_TIME, hyper_params, PLOT)
67+
# ## Association (Store) ##
68+
store_reservoir(EXC_SIZE, INH_SIZE, STORE_SAMPLES, NUM_CELLS, GC_MULTIPLES, SIM_TIME, hyper_params, PLOT)
5869

59-
## Association (Recall) ##
60-
# recall_reservoir(RES_SIZE, SIM_TIME, PLOT)
70+
# ## Association (Recall) ##
71+
recall_reservoir(EXC_SIZE, INH_SIZE, SIM_TIME, PLOT)
6172

6273
# # Preprocess Recalls ##
6374
# recalled_mem_preprocessing(WINDOW_FREQ, WINDOW_SIZE, PLOT)

scripts/Chris/DQN/recall_reservoir.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
from matplotlib import pyplot as plt
55

6-
def recall_reservoir(res_size, sim_time, plot=False):
6+
def recall_reservoir(exc_size, inh_size, sim_time, plot=False):
77
print("Recalling memories...")
88

99
## Load memory module and memory keys ##
@@ -13,16 +13,18 @@ def recall_reservoir(res_size, sim_time, plot=False):
1313
memory_keys, labels = pkl.load(f)
1414

1515
## Recall memories ##
16-
recalled_memories = np.zeros((len(memory_keys), sim_time, res_size))
16+
# TODO: Plot output spikes according to inh exc populations
17+
recalled_memories = np.zeros((len(memory_keys), sim_time, exc_size + inh_size))
1718
recalled_memories_sorted = {}
1819
for i, (key, label) in enumerate(zip(memory_keys, labels)):
19-
res_spike_train = res_module.recall(torch.tensor(key.reshape(sim_time, -1)), sim_time=sim_time) # Recall the sample
20-
recalled_memories[i] = res_spike_train.squeeze() # Store the recalled memory
20+
exc_spikes, inh_spikes = res_module.recall(torch.tensor(key.reshape(sim_time, -1)), sim_time=sim_time) # Recall the sample
21+
all_spikes = torch.cat((exc_spikes, inh_spikes), dim=2).squeeze()
22+
recalled_memories[i] = all_spikes # Store the recalled memory
2123
label = tuple(label.round())
2224
if label not in recalled_memories_sorted:
23-
recalled_memories_sorted[label] = [res_spike_train.squeeze()]
25+
recalled_memories_sorted[label] = [all_spikes]
2426
else:
25-
recalled_memories_sorted[label].append(res_spike_train.squeeze())
27+
recalled_memories_sorted[label].append(all_spikes)
2628

2729
## Save recalled memories ##
2830
with open('Data/recalled_memories.pkl', 'wb') as f:

scripts/Chris/DQN/store_reservoir.py

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,45 @@
55
import numpy as np
66
from matplotlib import pyplot as plt
77

8-
def store_reservoir(res_size, num_samples, num_grid_cells, gc_multiples, sim_time,
8+
def store_reservoir(exc_size, inh_size, num_samples, num_grid_cells, gc_multiples, sim_time,
99
hyper_params, plot=False):
1010
print("Storing memories...")
1111

1212
## Create synaptic weights ##
1313
in_size = num_grid_cells * gc_multiples
14-
w_in_res = torch.rand(in_size, res_size)
15-
w_res_res = torch.rand(res_size, res_size)
16-
w_in_res = sparsify(w_in_res, 0.85)
17-
w_res_res = sparsify(w_res_res, 0.85)
18-
w_res_res = assign_inhibition(w_res_res, 0.2, 1)
19-
res = Reservoir(in_size, res_size, hyper_params, w_in_res, w_res_res)
14+
w_in_exc = torch.rand(in_size, exc_size) # Initialize weights
15+
w_in_inh = torch.rand(in_size, inh_size)
16+
w_exc_exc = torch.rand(exc_size, exc_size)
17+
w_exc_inh = torch.rand(exc_size, inh_size)
18+
w_inh_exc = -torch.rand(inh_size, exc_size)
19+
w_inh_inh = torch.rand(inh_size, inh_size)
20+
w_in_exc = sparsify(w_in_exc, 0.85) # 0 x% of weights
21+
w_in_inh = sparsify(w_in_inh, 0.85)
22+
w_exc_exc = sparsify(w_exc_exc, 0.85)
23+
w_exc_inh = sparsify(w_exc_inh, 0.85)
24+
w_inh_exc = sparsify(w_inh_exc, 0.85)
25+
w_inh_inh = sparsify(w_inh_inh, 0.85)
26+
res = Reservoir(in_size, exc_size, inh_size, hyper_params,
27+
w_in_exc, w_in_inh, w_exc_exc, w_exc_inh, w_inh_exc, w_inh_inh)
2028

2129
## Load grid cell spike-train samples ##
2230
with open('Data/grid_cell_spk_trains.pkl', 'rb') as f:
2331
grid_cell_data, labels = pkl.load(f) # (samples, time, num_cells)
2432

2533
## Store memories ##
2634
# -> STDP active
27-
if plot:
28-
fig, ax = plt.subplots(2, 2, figsize=(10, 5))
29-
im = ax[0, 0].imshow(w_in_res)
30-
ax[0, 0].set_title("Initial Input-to-Res")
31-
plt.colorbar(im, ax=ax[0, 0])
32-
ax[0, 0].set_xlabel("Res Neuron")
33-
ax[0, 0].set_ylabel("Input Neuron")
34-
im = ax[0, 1].imshow(w_res_res)
35-
ax[0, 1].set_title("Initial Res-to-Res")
36-
plt.colorbar(im, ax=ax[0, 1])
37-
ax[0, 1].set_xlabel("Res Neuron")
38-
ax[0, 1].set_ylabel("Res Neuron")
35+
# if plot:
36+
# fig, ax = plt.subplots(2, 2, figsize=(10, 5))
37+
# im = ax[0, 0].imshow(w_in_res)
38+
# ax[0, 0].set_title("Initial Input-to-Res")
39+
# plt.colorbar(im, ax=ax[0, 0])
40+
# ax[0, 0].set_xlabel("Res Neuron")
41+
# ax[0, 0].set_ylabel("Input Neuron")
42+
# im = ax[0, 1].imshow(w_res_res)
43+
# ax[0, 1].set_title("Initial Res-to-Res")
44+
# plt.colorbar(im, ax=ax[0, 1])
45+
# ax[0, 1].set_xlabel("Res Neuron")
46+
# ax[0, 1].set_ylabel("Res Neuron")
3947

4048
# Store samples
4149
sample_inds = np.random.choice(len(grid_cell_data), num_samples, replace=False)
@@ -46,19 +54,19 @@ def store_reservoir(res_size, num_samples, num_grid_cells, gc_multiples, sim_tim
4654
res.store(torch.tensor(s.reshape(sim_time, -1)), sim_time=sim_time)
4755
res.reset_state_variables()
4856

49-
if plot:
50-
im = ax[1, 0].imshow(w_in_res)
51-
ax[1, 0].set_title("Final Input-to-Res")
52-
plt.colorbar(im, ax=ax[1, 0])
53-
ax[1, 0].set_xlabel("Res Neuron")
54-
ax[1, 0].set_ylabel("Input Neuron")
55-
im = ax[1, 1].imshow(w_res_res)
56-
ax[1, 1].set_title("Final Res-to-Res")
57-
plt.colorbar(im, ax=ax[1, 1])
58-
ax[1, 1].set_xlabel("Res Neuron")
59-
ax[1, 1].set_ylabel("Res Neuron")
60-
plt.tight_layout()
61-
plt.show()
57+
# if plot:
58+
# im = ax[1, 0].imshow(w_in_res)
59+
# ax[1, 0].set_title("Final Input-to-Res")
60+
# plt.colorbar(im, ax=ax[1, 0])
61+
# ax[1, 0].set_xlabel("Res Neuron")
62+
# ax[1, 0].set_ylabel("Input Neuron")
63+
# im = ax[1, 1].imshow(w_res_res)
64+
# ax[1, 1].set_title("Final Res-to-Res")
65+
# plt.colorbar(im, ax=ax[1, 1])
66+
# ax[1, 1].set_xlabel("Res Neuron")
67+
# ax[1, 1].set_ylabel("Res Neuron")
68+
# plt.tight_layout()
69+
# plt.show()
6270

6371
## Save ##
6472
with open('Data/reservoir_module.pkl', 'wb') as f:

0 commit comments

Comments
 (0)