Skip to content

Commit d761e90

Browse files
committed
add test for s=None vs equivalent s
1 parent db1a1c2 commit d761e90

1 file changed

Lines changed: 24 additions & 0 deletions

File tree

mkl_fft/tests/test_fftnd.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,30 @@ def test_s_axes_out(dtype, s, axes, func):
265265
assert_allclose(r1, r2, rtol=rtol, atol=atol)
266266

267267

268+
@requires_numpy_2
269+
@pytest.mark.parametrize("dtype", [complex, float])
270+
@pytest.mark.parametrize("axes", [(0, 1, 2), (-1, -2, -3), [1, 0, 2]])
271+
@pytest.mark.parametrize("func", ["fftn", "ifftn", "rfftn"])
272+
def test_s_none_vs_s_full(dtype, axes, func):
273+
shape = (30, 20, 10)
274+
if dtype is complex and func != "rfftn":
275+
x = np.random.random(shape) + 1j * np.random.random(shape)
276+
else:
277+
x = np.random.random(shape)
278+
279+
implied_s = [shape[ax] for ax in axes]
280+
if func == "irfftn":
281+
implied_s[-1] = 2 * (implied_s[-1] - 1)
282+
283+
r1 = getattr(np.fft, func)(x, axes=axes)
284+
r2 = getattr(mkl_fft, func)(x, axes=axes)
285+
r3 = getattr(mkl_fft, func)(x, s=implied_s, axes=axes)
286+
287+
rtol, atol = _get_rtol_atol(x)
288+
assert_allclose(r1, r2, rtol=rtol, atol=atol)
289+
assert_allclose(r1, r3, rtol=rtol, atol=atol)
290+
291+
268292
@pytest.mark.parametrize("dtype", [complex, float])
269293
@pytest.mark.parametrize("axes", [(2, 0, 2, 0), (0, 1, 1), (2, 0, 1, 3, 2, 1)])
270294
@pytest.mark.parametrize("func", ["rfftn", "irfftn"])

0 commit comments

Comments
 (0)