-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_specifications.py
More file actions
45 lines (33 loc) · 1.3 KB
/
test_specifications.py
File metadata and controls
45 lines (33 loc) · 1.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from wavebase.generators import Sinusoid
from wavebase.specifications import SpectralLSTM
from wavebase.lstm import LSTMLayer, LSTMCell
from .helper import TestCase
class TestSpectralLSTM(TestCase):
    """Tests for the LSTM whose weights are supplied by ``SpectralLSTM``."""

    def test_spectral_lstm(self):
        """Check that the spectral segment of the final LSTM state matches ``ps``.

        A ``SpectralLSTM`` with one input and ``hidden_size`` units provides
        keyword parameters (via ``lstm_params()``) for a custom ``LSTMLayer``.
        That layer is run over a single-beat ``Sinusoid`` signal. The trailing
        segment ``[p.spec_:]`` of both the final hidden state and the final
        cell state is then compared with ``ps``, the second value returned by
        ``dataset.signal()``. ``ps`` is presumably the reference power
        spectrum; confirm against ``Sinusoid``.
        """
        freq_max = 2
        hidden_size = 14

        p = SpectralLSTM(input_size=1, hidden_size=hidden_size)
        custom_lstm = LSTMLayer(LSTMCell, **p.lstm_params())

        dataset = Sinusoid(
            n_beats=1,
            n_samples_per_beat=16,
            freq_max=freq_max,
            rps_max=.1,
        )
        x, ps = dataset.signal()

        # Run the custom LSTM. Only the final (hx, cx) state is checked
        # below, so the per-step outputs are discarded.
        _, (hx, cx) = custom_lstm(x, return_outputs=True)

        # The SpectralLSTM offsets split the state vector into consecutive
        # segments: [gen_, sum_), [sum_, prod_), [prod_, spec_) and
        # [spec_, end). Only the last (spectral) segment is asserted on.
        # reshape(hidden_size) requires each final state to hold exactly
        # hidden_size elements (presumably batch size 1, one layer; TODO confirm).
        hx_spec = hx.reshape(hidden_size)[p.spec_:]
        cx_spec = cx.reshape(hidden_size)[p.spec_:]

        # NOTE(review): both assertions below pass atol explicitly, so this
        # attribute looks unused here. It is kept in case helper.TestCase
        # reads it elsewhere; confirm and remove if dead.
        self.atol = 1e-3
        self.assertTensorsClose(ps, cx_spec, atol=1e-2, msg="This does not compute the power spectrum")
        self.assertTensorsClose(ps, hx_spec, atol=1e-2, msg="This does not compute the power spectrum")