@@ -992,6 +992,41 @@ def test_full_dtype_inference():
992992 assert np .issubdtype (dpt .full (10 , 0.3 - 2j ).dtype , np .complexfloating )
993993
994994
995+ def test_full_fill_array ():
996+ q = get_queue_or_skip ()
997+
998+ Xnp = np .array ([1 , 2 , 3 ], dtype = "i4" )
999+ X = dpt .asarray (Xnp , sycl_queue = q )
1000+
1001+ shape = (3 , 3 )
1002+ Y = dpt .full (shape , X )
1003+ Ynp = np .full (shape , Xnp )
1004+
1005+ assert Y .dtype == Ynp .dtype
1006+ assert Y .usm_type == "device"
1007+ assert np .array_equal (dpt .asnumpy (Y ), Ynp )
1008+
1009+
1010+ def test_full_compute_follows_data ():
1011+ q1 = get_queue_or_skip ()
1012+ q2 = get_queue_or_skip ()
1013+
1014+ X = dpt .arange (10 , dtype = "i4" , sycl_queue = q1 , usm_type = "shared" )
1015+ Y = dpt .full (10 , X [3 ])
1016+
1017+ assert Y .dtype == X .dtype
1018+ assert Y .usm_type == X .usm_type
1019+ assert dpctl .utils .get_execution_queue ((Y .sycl_queue , X .sycl_queue ))
1020+ assert np .array_equal (dpt .asnumpy (Y ), np .full (10 , 3 , dtype = "i4" ))
1021+
1022+ Y = dpt .full (10 , X [3 ], dtype = "f4" , sycl_queue = q2 , usm_type = "host" )
1023+
1024+ assert Y .dtype == dpt .dtype ("f4" )
1025+ assert Y .usm_type == "host"
1026+ assert dpctl .utils .get_execution_queue ((Y .sycl_queue , q2 ))
1027+ assert np .array_equal (dpt .asnumpy (Y ), np .full (10 , 3 , dtype = "f4" ))
1028+
1029+
9951030@pytest .mark .parametrize (
9961031 "dt" ,
9971032 _all_dtypes [1 :],
0 commit comments