Skip to content

Commit 60688c7

Browse files
committed
ENH: fft: fftfreq respect default dtype/device pairs
1 parent 8d4eb67 commit 60688c7

2 files changed

Lines changed: 32 additions & 1 deletion

File tree

array_api_strict/_fft.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55

66
from ._array_object import Array
7-
from ._devices import ALL_DEVICES, Device
7+
from ._devices import ALL_DEVICES, Device, device_supports_dtype
88
from ._data_type_functions import astype
99
from ._dtypes import (
1010
DType,
@@ -14,6 +14,7 @@
1414
complex64,
1515
float32,
1616
)
17+
from ._info import __array_namespace_info__
1718
from ._flags import requires_extension
1819

1920

@@ -269,6 +270,15 @@ def fftfreq(
269270
np_result = np.fft.fftfreq(n, d=d)
270271
if dtype:
271272
np_result = np_result.astype(dtype._np_dtype)
273+
274+
if not device_supports_dtype(device, DType(np_result.dtype)):
275+
if dtype:
276+
# user input unsupported
277+
raise ValueError(f"Device {device!r} does not support dtype={dtype!r}.")
278+
279+
dt = __array_namespace_info__().default_dtypes(device=device)["real floating"]
280+
np_result = np_result.astype(dt._np_dtype)
281+
272282
return Array._new(np_result, device=device)
273283

274284
@requires_extension('fft')
@@ -293,6 +303,15 @@ def rfftfreq(
293303
np_result = np.fft.rfftfreq(n, d=d)
294304
if dtype:
295305
np_result = np_result.astype(dtype._np_dtype)
306+
307+
if not device_supports_dtype(device, DType(np_result.dtype)):
308+
if dtype:
309+
# user input unsupported
310+
raise ValueError(f"Device {device!r} does not support dtype={dtype!r}.")
311+
312+
dt = __array_namespace_info__().default_dtypes(device=device)["real floating"]
313+
np_result = np_result.astype(dt._np_dtype)
314+
296315
return Array._new(np_result, device=device)
297316

298317
@requires_extension('fft')

array_api_strict/tests/test_device_support.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,18 @@ def test_fft_device_support_real(func_name):
3838
assert x.device == y.device
3939

4040

41+
@pytest.mark.parametrize("func_name", ("fftfreq", "rfftfreq"))
42+
def test_fft_default_dtype(func_name):
43+
func = getattr(xp.fft, func_name)
44+
device = xp.Device("F32_device")
45+
res = func(3, device=device)
46+
assert res.device == device
47+
assert res.dtype == xp.__array_namespace_info__().default_dtypes(device=device)["real floating"]
48+
49+
with pytest.raises(ValueError):
50+
func(3, device=device, dtype=xp.float64)
51+
52+
4153
class TestF32Device:
4254
@pytest.mark.parametrize("dtype_str", ["float64", "complex128"])
4355
def test_f64_raises(self, dtype_str):

0 commit comments

Comments
 (0)