|
7 | 7 |
|
8 | 8 |
|
9 | 9 | class Reservoir(Network): |
10 | | - def __init__(self, in_size, res_size, hyper_params, |
11 | | - w_in_res, w_res_res, device='cpu'): |
| 10 | + def __init__(self, in_size, exc_size, inh_size, hyper_params, |
| 11 | + w_in_exc, w_in_inh, w_exc_exc, w_exc_inh, w_inh_exc, w_inh_inh, |
| 12 | + device='cpu'): |
12 | 13 | super().__init__() |
13 | 14 |
|
14 | 15 | ## Layers ## |
15 | 16 | input = Input(n=in_size) |
16 | | - res = AdaptiveLIFNodes( |
17 | | - n=res_size, |
18 | | - thresh=hyper_params['thresh'], |
19 | | - theta_plus=hyper_params['theta_plus'], |
20 | | - refrac=hyper_params['refrac'], |
21 | | - reset=hyper_params['reset'], |
22 | | - tc_theta_decay=hyper_params['tc_theta_decay'], |
23 | | - tc_decay=hyper_params['tc_decay'], |
| 17 | + res_exc = AdaptiveLIFNodes( |
| 18 | + n=exc_size, |
| 19 | + thresh=hyper_params['thresh_exc'], |
| 20 | + theta_plus=hyper_params['theta_plus_exc'], |
| 21 | + refrac=hyper_params['refrac_exc'], |
| 22 | + reset=hyper_params['reset_exc'], |
| 23 | + tc_theta_decay=hyper_params['tc_theta_decay_exc'], |
| 24 | + tc_decay=hyper_params['tc_decay_exc'], |
24 | 25 | traces=True, |
25 | 26 | ) |
26 | | - res_monitor = Monitor(res, ["s"], device=device) |
27 | | - self.add_monitor(res_monitor, name='res_monitor') |
28 | | - self.res_monitor = res_monitor |
| 27 | + exc_monitor = Monitor(res_exc, ["s"], device=device) |
| 28 | + self.add_monitor(exc_monitor, name='res_monitor_exc') |
| 29 | + self.exc_monitor = exc_monitor |
| 30 | + res_inh = AdaptiveLIFNodes( |
| 31 | + n=inh_size, |
| 32 | + thresh=hyper_params['thresh_inh'], |
| 33 | + theta_plus=hyper_params['theta_plus_inh'], |
| 34 | + refrac=hyper_params['refrac_inh'], |
| 35 | + reset=hyper_params['reset_inh'], |
| 36 | + tc_theta_decay=hyper_params['tc_theta_decay_inh'], |
| 37 | + tc_decay=hyper_params['tc_decay_inh'], |
| 38 | + traces=True, |
| 39 | + ) |
| 40 | + inh_monitor = Monitor(res_inh, ["s"], device=device) |
| 41 | + self.add_monitor(inh_monitor, name='res_monitor_inh') |
| 42 | + self.inh_monitor = inh_monitor |
29 | 43 | self.add_layer(input, name='input') |
30 | | - self.add_layer(res, name='res') |
| 44 | + self.add_layer(res_exc, name='res_exc') |
| 45 | + self.add_layer(res_inh, name='res_inh') |
31 | 46 |
|
32 | 47 | ## Connections ## |
33 | | - in_res_wfeat = Weight(name='in_res_weight_feature', value=w_in_res,) |
34 | | - in_res_conn = MulticompartmentConnection( |
35 | | - source=input, target=res, |
36 | | - device=device, pipeline=[in_res_wfeat], |
| 48 | + in_exc_wfeat = Weight(name='in_exc_weight_feature', value=w_in_exc,) |
| 49 | + in_exc_conn = MulticompartmentConnection( |
| 50 | + source=input, target=res_exc, |
| 51 | + device=device, pipeline=[in_exc_wfeat], |
| 52 | + ) |
| 53 | + in_inh_wfeat = Weight(name='in_inh_weight_feature', value=w_in_inh,) |
| 54 | + in_inh_conn = MulticompartmentConnection( |
| 55 | + source=input, target=res_inh, |
| 56 | + device=device, pipeline=[in_inh_wfeat], |
| 57 | + ) |
| 58 | + |
| 59 | + exc_exc_wfeat = Weight(name='exc_exc_weight_feature', value=w_exc_exc,) |
| 60 | + # learning_rule=MSTDP, |
| 61 | + # nu=hyper_params['nu_exc_exc'], range=hyper_params['range_exc_exc'], decay=hyper_params['decay_exc_exc']) |
| 62 | + exc_exc_conn = MulticompartmentConnection( |
| 63 | + source=res_exc, target=res_exc, |
| 64 | + device=device, pipeline=[exc_exc_wfeat], |
| 65 | + ) |
| 66 | + exc_inh_wfeat = Weight(name='exc_inh_weight_feature', value=w_exc_inh,) |
| 67 | + # learning_rule=MSTDP, |
| 68 | + # nu=hyper_params['nu_exc_inh'], range=hyper_params['range_exc_inh'], decay=hyper_params['decay_exc_inh']) |
| 69 | + exc_inh_conn = MulticompartmentConnection( |
| 70 | + source=res_exc, target=res_inh, |
| 71 | + device=device, pipeline=[exc_inh_wfeat], |
| 72 | + ) |
| 73 | + inh_exc_wfeat = Weight(name='inh_exc_weight_feature', value=w_inh_exc,) |
| 74 | + # learning_rule=MSTDP, |
| 75 | + # nu=hyper_params['nu_inh_exc'], range=hyper_params['range_inh_exc'], decay=hyper_params['decay_inh_exc']) |
| 76 | + inh_exc_conn = MulticompartmentConnection( |
| 77 | + source=res_inh, target=res_exc, |
| 78 | + device=device, pipeline=[inh_exc_wfeat], |
37 | 79 | ) |
38 | | - res_res_wfeat = Weight(name='res_res_weight_feature', value=w_res_res, |
| 80 | + inh_inh_wfeat = Weight(name='inh_inh_weight_feature', value=w_inh_inh,) |
39 | 81 | # learning_rule=MSTDP, |
40 | | - nu=hyper_params['nu'], range=hyper_params['range'], decay=hyper_params['decay']) |
41 | | - res_res_conn = MulticompartmentConnection( |
42 | | - source=res, target=res, |
43 | | - device=device, pipeline=[res_res_wfeat], |
| 82 | + # nu=hyper_params['nu_inh_inh'], range=hyper_params['range_inh_inh'], decay=hyper_params['decay_inh_inh']) |
| 83 | + inh_inh_conn = MulticompartmentConnection( |
| 84 | + source=res_inh, target=res_inh, |
| 85 | + device=device, pipeline=[inh_inh_wfeat], |
44 | 86 | ) |
45 | | - self.add_connection(in_res_conn, source='input', target='res') |
46 | | - self.add_connection(res_res_conn, source='res', target='res') |
47 | | - self.res_res_conn = res_res_conn |
| 87 | + self.add_connection(in_exc_conn, source='input', target='res_exc') |
| 88 | + self.add_connection(in_inh_conn, source='input', target='res_inh') |
| 89 | + self.add_connection(exc_exc_conn, source='res_exc', target='res_exc') |
| 90 | + self.add_connection(exc_inh_conn, source='res_exc', target='res_inh') |
| 91 | + self.add_connection(inh_exc_conn, source='res_inh', target='res_exc') |
| 92 | + self.add_connection(inh_inh_conn, source='res_inh', target='res_inh') |
48 | 93 |
|
49 | 94 | ## Migrate ## |
50 | 95 | self.to(device) |
51 | 96 |
|
52 | 97 | def store(self, spike_train, sim_time): |
53 | 98 | self.learning = True |
54 | 99 | self.run(inputs={'input': spike_train}, time=sim_time, reward=1) |
55 | | - res_spikes = self.res_monitor.get('s') |
| 100 | + exc_spikes = self.exc_monitor.get('s') |
| 101 | + inh_spikes = self.inh_monitor.get('s') |
56 | 102 | self.learning = False |
57 | | - return res_spikes |
| 103 | + return exc_spikes, inh_spikes |
58 | 104 |
|
59 | 105 | def recall(self, spike_train, sim_time): |
60 | 106 | self.learning = False |
61 | 107 | self.run(inputs={'input': spike_train}, time=sim_time,) |
62 | | - res_spikes = self.res_monitor.get('s') |
63 | | - return res_spikes |
| 108 | + exc_spikes = self.exc_monitor.get('s') |
| 109 | + inh_spikes = self.inh_monitor.get('s') |
| 110 | + return exc_spikes, inh_spikes |
0 commit comments