2222 zeros ,
2323 zeros_like ,
2424)
25- from .._dtypes import float32 , float64
25+ from .._dtypes import float32 , float64 , bool as xp_bool
2626from .._array_object import Array
2727from .._devices import CPU_DEVICE , ALL_DEVICES , Device
2828from .._info import __array_namespace_info__
@@ -250,6 +250,7 @@ def test_ones_like_etc_correct(self, func):
250250 device = Device ('F32_device' )
251251 b = func (a , device = device )
252252 assert b .dtype == self .info .default_dtypes (device = device )["real floating" ]
253+ assert a .device == device
253254
254255 @pytest .mark .parametrize ("func" , [empty_like , zeros_like , ones_like , _full_like ])
255256 def test_ones_like_etc_incorrect (self , func ):
@@ -277,6 +278,7 @@ def test_eye(self):
277278 device = Device ('F32_device' )
278279 a = eye (3 , device = device )
279280 assert a .dtype == self .info .default_dtypes (device = device )["real floating" ]
281+ assert a .device == device
280282
281283 with pytest .raises ((TypeError , ValueError )):
282284 eye (3 , device = device , dtype = float64 )
@@ -286,6 +288,7 @@ def test_linspace(self):
286288
287289 a = linspace (1 , 10 , 11 , device = device )
288290 assert a .dtype == self .info .default_dtypes (device = device )["real floating" ]
291+ assert a .device == device
289292
290293 a = linspace (1 + 0j , 10 , 11 , device = device )
291294 assert a .dtype == self .info .default_dtypes (device = device )["complex floating" ]
@@ -298,18 +301,59 @@ def test_arange(self):
298301
299302 a = arange (0 , 10 , 1 , device = device )
300303 assert a .dtype == self .info .default_dtypes (device = device )["integral" ]
304+ assert a .device == device
301305
302306 a = arange (0.0 , 10 , 1 , device = device )
303307 assert a .dtype == self .info .default_dtypes (device = device )["real floating" ]
308+ assert a .device == device
304309
305310 with pytest .raises ((TypeError , ValueError )):
306311 arange (0 , 10 , 1 , device = device , dtype = float64 )
307312
308313 with pytest .raises ((TypeError , ValueError )):
309314 arange (0.0 , 10 , 1 , device = device , dtype = float64 )
310315
311- # TODO:
312- # def asarray(
316+ def test_asarray (self ):
317+ device = Device ('F32_device' )
318+
319+ ### asarray(python_object)
320+ for x in (True , [False ,]):
321+ arr = asarray (x , device = device )
322+ assert arr .dtype == xp_bool
323+ assert arr .device == device
324+
325+ for x in [1 , [1 ,]]:
326+ arr = asarray (x , device = device )
327+ assert arr .dtype == self .info .default_dtypes (device = device )['integral' ]
328+ assert arr .device == device
329+
330+ for x in [1.0 , [1.0 ,]]:
331+ arr = asarray (x , device = device )
332+ assert arr .dtype == self .info .default_dtypes (device = device )['real floating' ]
333+ assert arr .device == device
334+
335+ for x in [1j , [1j ,]]:
336+ arr = asarray (x , device = device )
337+ assert arr .dtype == self .info .default_dtypes (device = device )['complex floating' ]
338+ assert arr .device == device
339+
340+ # asarray(python_object, dtype=unsupported_by_device)
341+ with pytest .raises (ValueError , match = "Device" ):
342+ asarray (1 , dtype = float64 , device = device )
343+
344+ ### asarray(array)
345+
346+ # compatible dtypes, device transfer
347+ src = asarray (1 , dtype = float32 , device = Device ('device1' ))
348+ dst = asarray (src , device = device )
349+ assert dst .device == device
350+ assert dst .dtype == float32
351+
352+ # incompatible dtypes, device transfer
353+ src = asarray (1 , dtype = float64 , device = Device ('device1' ))
354+
355+ with pytest .raises (ValueError , match = "Device" ):
356+ asarray (src , device = device )
313357
314358
315359@pytest .mark .parametrize ("api_version" , ['2021.12' , '2022.12' , '2023.12' ])
0 commit comments