@@ -36,7 +36,7 @@ def spec(tmp_path):
3636def test_fusion (spec , opt_fn ):
3737 a = xp .asarray ([[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]], chunks = (2 , 2 ), spec = spec )
3838 b = xp .negative (a )
39- c = xp .astype (b , np .float32 )
39+ c = xp .astype (b , xp .float32 )
4040 d = xp .negative (c )
4141
4242 num_arrays = 4 # a, b, c, d
@@ -69,7 +69,7 @@ def test_fusion(spec, opt_fn):
6969def test_fusion_compute_multiple (spec , opt_fn ):
7070 a = xp .asarray ([[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]], chunks = (2 , 2 ), spec = spec )
7171 b = xp .negative (a )
72- c = xp .astype (b , np .float32 )
72+ c = xp .astype (b , xp .float32 )
7373 d = xp .negative (c )
7474
7575 # if we compute c and d then both have to be materialized
@@ -97,7 +97,7 @@ def test_fusion_compute_multiple(spec, opt_fn):
9797def test_fusion_transpose (spec , opt_fn ):
9898 a = xp .asarray ([[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]], chunks = (2 , 2 ), spec = spec )
9999 b = xp .negative (a )
100- c = xp .astype (b , np .float32 )
100+ c = xp .astype (b , xp .float32 )
101101 d = c .T
102102
103103 num_created_arrays = 3 # b, c, d
@@ -191,7 +191,7 @@ def test_no_fusion_multiple_edges(spec):
191191def test_custom_optimize_function (spec ):
192192 a = xp .asarray ([[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]], chunks = (2 , 2 ), spec = spec )
193193 b = xp .negative (a )
194- c = xp .astype (b , np .float32 )
194+ c = xp .astype (b , xp .float32 )
195195 d = xp .negative (c )
196196
197197 num_tasks_with_no_optimization = d .plan ._finalize (optimize_graph = False ).num_tasks ()
@@ -992,7 +992,7 @@ def test_fuse_merge_chunks_binary(spec):
992992def test_fuse_partial_reduce_unary (spec ):
993993 a = xp .ones ((3 , 2 ), chunks = (1 , 2 ), spec = spec )
994994 b = xp .negative (a )
995- c = partial_reduce (b , np .sum , split_every = {0 : 3 })
995+ c = partial_reduce (b , nxp .sum , split_every = {0 : 3 }, dtype = xp . float64 )
996996
997997 opt_fn = fuse_multiple_levels ()
998998
@@ -1017,7 +1017,7 @@ def test_fuse_partial_reduce_binary(spec):
10171017 a = xp .ones ((3 , 2 ), chunks = (1 , 2 ), spec = spec )
10181018 b = xp .ones ((3 , 2 ), chunks = (1 , 2 ), spec = spec )
10191019 c = xp .add (a , b )
1020- d = partial_reduce (c , np .sum , split_every = {0 : 3 })
1020+ d = partial_reduce (c , nxp .sum , split_every = {0 : 3 }, dtype = xp . float64 )
10211021
10221022 opt_fn = fuse_multiple_levels ()
10231023
0 commit comments