Skip to content

Commit 8996c80

Browse files
Update test_cwt_wavelets.py
1 parent 79acb20 commit 8996c80

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

pywt/tests/test_cwt_wavelets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def test_cwt_batch(axis, method):
405405
hop_size = 1
406406
batch_axis = 1 - axis
407407
sst1 = np.asarray(sst, dtype=dtype)
408-
sst = np.stack((sst1, ) * np.ceil(n_batch / hop_size).astype(int), axis=np.ceil(batch_axis / hop_size).astype(int))
408+
sst = np.stack((sst1, ) * n_batch, axis=batch_axis)
409409
dt = time[1] - time[0]
410410
wavelet = 'cmor1.5-1.0'
411411
scales = np.arange(1, 32)
@@ -424,11 +424,11 @@ def test_cwt_batch(axis, method):
424424

425425
# verify expected shape
426426
assert_equal(cfs.shape[0], len(scales))
427-
assert_equal(cfs.shape[1 + batch_axis], np.ceil(n_batch / hop_size).astype(int))
427+
assert_equal(cfs.shape[1 + batch_axis], n_batch)
428428
assert_equal(cfs.shape[1 + axis], sst.shape[axis])
429429

430430
# batch result on stacked input is the same as stacked 1d result
431-
assert_almost_equal(cfs, np.stack((cfs1,) * np.ceil(n_batch / hop_size).astype(int), axis=np.ceil(batch_axis / hop_size).astype(int) + 1),
431+
assert_almost_equal(cfs, np.stack((cfs1,) * n_batch, axis=batch_axis + 1),
432432
decimal=12)
433433

434434

0 commit comments

Comments
 (0)