@@ -265,6 +265,30 @@ def test_s_axes_out(dtype, s, axes, func):
265265 assert_allclose (r1 , r2 , rtol = rtol , atol = atol )
266266
267267
268+ @requires_numpy_2
269+ @pytest .mark .parametrize ("dtype" , [complex , float ])
270+ @pytest .mark .parametrize ("axes" , [(0 , 1 , 2 ), (- 1 , - 2 , - 3 ), [1 , 0 , 2 ]])
271+ @pytest .mark .parametrize ("func" , ["fftn" , "ifftn" , "rfftn" ])
272+ def test_s_none_vs_s_full (dtype , axes , func ):
273+ shape = (30 , 20 , 10 )
274+ if dtype is complex and func != "rfftn" :
275+ x = np .random .random (shape ) + 1j * np .random .random (shape )
276+ else :
277+ x = np .random .random (shape )
278+
279+ implied_s = [shape [ax ] for ax in axes ]
280+ if func == "irfftn" :
281+ implied_s [- 1 ] = 2 * (implied_s [- 1 ] - 1 )
282+
283+ r1 = getattr (np .fft , func )(x , axes = axes )
284+ r2 = getattr (mkl_fft , func )(x , axes = axes )
285+ r3 = getattr (mkl_fft , func )(x , s = implied_s , axes = axes )
286+
287+ rtol , atol = _get_rtol_atol (x )
288+ assert_allclose (r1 , r2 , rtol = rtol , atol = atol )
289+ assert_allclose (r1 , r3 , rtol = rtol , atol = atol )
290+
291+
268292@pytest .mark .parametrize ("dtype" , [complex , float ])
269293@pytest .mark .parametrize ("axes" , [(2 , 0 , 2 , 0 ), (0 , 1 , 1 ), (2 , 0 , 1 , 3 , 2 , 1 )])
270294@pytest .mark .parametrize ("func" , ["rfftn" , "irfftn" ])
0 commit comments