Skip to content

Commit 0ebb206

Browse files
authored
Misc dtype fixes (#771)
* Add missing dtypes in `cumulative_sum` and `scan`.
* Use `nxp.bool`, not Python `bool`, in the `all` and `any` functions.
* Use the default array index data type rather than hardcoding `nxp.int64`.
* Simplify `linspace` start/size/step types, since they are just Python types.
* Fix namespace usage in the optimization tests.
1 parent: e50a4a7 · commit: 0ebb206

6 files changed

Lines changed: 26 additions & 18 deletions

File tree

cubed/array_api/creation_functions.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -228,16 +228,12 @@ def linspace(
228228
def _linspace(
229229
x, size, start, step, endpoint, linspace_dtype, device=None, block_id=None
230230
):
231-
dtypes = __array_namespace_info__().default_dtypes(device=device)
232-
233231
bs = x.shape[0]
234232
i = block_id[0]
235233
adjusted_bs = bs - 1 if endpoint else bs
236234

237-
# float_ is a type casting function.
238-
float_ = dtypes["real floating"].type
239-
blockstart = float_(start + (i * size * step))
240-
blockstop = float_(blockstart + float_(adjusted_bs * step))
235+
blockstart = start + (i * size * step)
236+
blockstop = blockstart + adjusted_bs * step
241237
return nxp.linspace(
242238
blockstart, blockstop, bs, endpoint=endpoint, dtype=linspace_dtype
243239
)

cubed/array_api/searching_functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,15 @@ def searchsorted(x1, x2, /, *, side="left", sorter=None):
4949
)
5050

5151
# call nxp.searchsorted for each pair of blocks in x1 and v
52+
dtype = nxp.__array_namespace_info__().default_dtypes(device=x1.device)["indexing"]
5253
out = blockwise(
5354
_searchsorted,
5455
list(range(x2.ndim + 1)),
5556
x1,
5657
[0],
5758
x2,
5859
list(range(1, x2.ndim + 1)),
59-
dtype=nxp.int64, # TODO: index dtype
60+
dtype=dtype,
6061
adjust_chunks={0: 1}, # one row for each block in x1
6162
side=side,
6263
)

cubed/array_api/statistical_functions.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,14 @@ def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False, device
2323
fname="cumulative_sum",
2424
device=device,
2525
)
26-
return scan(x, preop=nxp.sum, func=_cumulative_sum_func, binop=nxp.add, axis=axis)
26+
return scan(
27+
x,
28+
preop=nxp.sum,
29+
func=_cumulative_sum_func,
30+
binop=nxp.add,
31+
axis=axis,
32+
dtype=dtype,
33+
)
2734

2835

2936
def _cumulative_sum_func(a, /, *, axis=None, dtype=None, include_initial=False):

cubed/array_api/utility_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def all(x, /, *, axis=None, keepdims=False, split_every=None):
1010
x,
1111
nxp.all,
1212
axis=axis,
13-
dtype=bool,
13+
dtype=nxp.bool,
1414
keepdims=keepdims,
1515
split_every=split_every,
1616
)
@@ -23,7 +23,7 @@ def any(x, /, *, axis=None, keepdims=False, split_every=None):
2323
x,
2424
nxp.any,
2525
axis=axis,
26-
dtype=bool,
26+
dtype=nxp.bool,
2727
keepdims=keepdims,
2828
split_every=split_every,
2929
)

cubed/core/ops.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,7 +1339,7 @@ def _partial_reduce(arrays, reduce_func=None, initial_func=None, axis=None):
13391339

13401340
def arg_reduction(x, /, arg_func, axis=None, *, keepdims=False, split_every=None):
13411341
"""A reduction that returns the array indexes, not the values."""
1342-
dtype = nxp.int64 # index data type
1342+
dtype = nxp.__array_namespace_info__().default_dtypes(device=x.device)["indexing"]
13431343
intermediate_dtype = [("i", dtype), ("v", x.dtype)]
13441344

13451345
# initial map does arg reduction on each block, and uses block id to find the absolute index within whole array
@@ -1549,7 +1549,9 @@ def scan(
15491549

15501550
# Blelloch (1990) out-of-core algorithm.
15511551
# 1. First, scan blockwise
1552-
scanned = map_blocks(func, array, axis=axis, include_initial=include_initial)
1552+
scanned = map_blocks(
1553+
func, array, dtype=dtype, axis=axis, include_initial=include_initial
1554+
)
15531555
# If there is only a single chunk, we can be done
15541556
if array.numblocks[axis] == 1:
15551557
return scanned
@@ -1567,6 +1569,7 @@ def identity_func(a, **kwargs):
15671569
initial_func=partial(preop, axis=axis, keepdims=True),
15681570
func=identity_func,
15691571
split_every={axis: split_size},
1572+
dtype=dtype,
15701573
combine_sizes={axis: split_size},
15711574
)
15721575

@@ -1581,6 +1584,7 @@ def identity_func(a, **kwargs):
15811584
preop=preop,
15821585
binop=binop,
15831586
axis=axis,
1587+
dtype=dtype,
15841588
include_initial=True,
15851589
)
15861590

cubed/tests/test_optimization.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def spec(tmp_path):
3636
def test_fusion(spec, opt_fn):
3737
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
3838
b = xp.negative(a)
39-
c = xp.astype(b, np.float32)
39+
c = xp.astype(b, xp.float32)
4040
d = xp.negative(c)
4141

4242
num_arrays = 4 # a, b, c, d
@@ -69,7 +69,7 @@ def test_fusion(spec, opt_fn):
6969
def test_fusion_compute_multiple(spec, opt_fn):
7070
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
7171
b = xp.negative(a)
72-
c = xp.astype(b, np.float32)
72+
c = xp.astype(b, xp.float32)
7373
d = xp.negative(c)
7474

7575
# if we compute c and d then both have to be materialized
@@ -97,7 +97,7 @@ def test_fusion_compute_multiple(spec, opt_fn):
9797
def test_fusion_transpose(spec, opt_fn):
9898
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
9999
b = xp.negative(a)
100-
c = xp.astype(b, np.float32)
100+
c = xp.astype(b, xp.float32)
101101
d = c.T
102102

103103
num_created_arrays = 3 # b, c, d
@@ -191,7 +191,7 @@ def test_no_fusion_multiple_edges(spec):
191191
def test_custom_optimize_function(spec):
192192
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
193193
b = xp.negative(a)
194-
c = xp.astype(b, np.float32)
194+
c = xp.astype(b, xp.float32)
195195
d = xp.negative(c)
196196

197197
num_tasks_with_no_optimization = d.plan._finalize(optimize_graph=False).num_tasks()
@@ -992,7 +992,7 @@ def test_fuse_merge_chunks_binary(spec):
992992
def test_fuse_partial_reduce_unary(spec):
993993
a = xp.ones((3, 2), chunks=(1, 2), spec=spec)
994994
b = xp.negative(a)
995-
c = partial_reduce(b, np.sum, split_every={0: 3})
995+
c = partial_reduce(b, nxp.sum, split_every={0: 3}, dtype=xp.float64)
996996

997997
opt_fn = fuse_multiple_levels()
998998

@@ -1017,7 +1017,7 @@ def test_fuse_partial_reduce_binary(spec):
10171017
a = xp.ones((3, 2), chunks=(1, 2), spec=spec)
10181018
b = xp.ones((3, 2), chunks=(1, 2), spec=spec)
10191019
c = xp.add(a, b)
1020-
d = partial_reduce(c, np.sum, split_every={0: 3})
1020+
d = partial_reduce(c, nxp.sum, split_every={0: 3}, dtype=xp.float64)
10211021

10221022
opt_fn = fuse_multiple_levels()
10231023

0 commit comments

Comments
 (0)