44import numpy as np
55
66from ._array_object import Array
7- from ._devices import ALL_DEVICES , Device
7+ from ._devices import ALL_DEVICES , Device , device_supports_dtype
88from ._data_type_functions import astype
99from ._dtypes import (
1010 DType ,
1414 complex64 ,
1515 float32 ,
1616)
17+ from ._info import __array_namespace_info__
1718from ._flags import requires_extension
1819
1920
@@ -269,6 +270,15 @@ def fftfreq(
269270 np_result = np .fft .fftfreq (n , d = d )
270271 if dtype :
271272 np_result = np_result .astype (dtype ._np_dtype )
273+
274+ if not device_supports_dtype (device , DType (np_result .dtype )):
275+ if dtype :
276+ # user input unsupported
277+ raise ValueError (f"Device { device !r} does not support dtype={ dtype !r} ." )
278+
279+ dt = __array_namespace_info__ ().default_dtypes (device = device )["real floating" ]
280+ np_result = np_result .astype (dt ._np_dtype )
281+
272282 return Array ._new (np_result , device = device )
273283
274284@requires_extension ('fft' )
@@ -293,6 +303,15 @@ def rfftfreq(
293303 np_result = np .fft .rfftfreq (n , d = d )
294304 if dtype :
295305 np_result = np_result .astype (dtype ._np_dtype )
306+
307+ if not device_supports_dtype (device , DType (np_result .dtype )):
308+ if dtype :
309+ # user input unsupported
310+ raise ValueError (f"Device { device !r} does not support dtype={ dtype !r} ." )
311+
312+ dt = __array_namespace_info__ ().default_dtypes (device = device )["real floating" ]
313+ np_result = np_result .astype (dt ._np_dtype )
314+
296315 return Array ._new (np_result , device = device )
297316
298317@requires_extension ('fft' )
0 commit comments