Skip to content

Commit ca2beb1

Browse files
authored
Add a blocks property to Cubed arrays to allow block-level indexing (#737)
* Add a `blocks` property to Cubed arrays to allow block-level indexing This is similar to Dask and Zarr. * Disallow multiple integer array indexes for `blocks` Fix typo Handle missing else case
1 parent 689db10 commit ca2beb1

3 files changed

Lines changed: 132 additions & 0 deletions

File tree

cubed/core/array.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ def zarray(self):
5757
"""The underlying Zarr array. May only be used during the computation once the array has been created."""
5858
return open_if_lazy_zarr_array(self._zarray)
5959

60+
@property
61+
def blocks(self):
62+
"""An array-like interface to the blocks of an array."""
63+
from cubed.core.indexing import BlockView
64+
65+
return BlockView(self)
66+
6067
@property
6168
def chunkmem(self):
6269
"""Amount of memory in bytes that a single chunk uses."""

cubed/core/indexing.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import math
2+
from typing import TYPE_CHECKING
3+
4+
import ndindex
5+
import numpy as np
6+
from toolz import map
7+
8+
from cubed.core.ops import general_blockwise
9+
10+
if TYPE_CHECKING:
11+
from cubed.array_api.array_object import Array
12+
13+
14+
class BlockView:
15+
"""An array-like interface to the blocks of an array."""
16+
17+
def __init__(self, array: "Array"):
18+
self.array = array
19+
20+
def __getitem__(self, key) -> "Array":
21+
if not isinstance(key, tuple):
22+
key = (key,)
23+
24+
# Canonicalize index
25+
idx = ndindex.ndindex(key)
26+
idx = idx.expand(self.array.numblocks)
27+
28+
if any(isinstance(ia, ndindex.Newaxis) for ia in idx.args):
29+
raise ValueError("Slicing with xp.newaxis is not supported")
30+
31+
if sum(1 for ia in idx.args if isinstance(ia, ndindex.IntegerArray)) > 1:
32+
raise NotImplementedError("Only one integer array index is allowed.")
33+
34+
# convert Integer to Slice so we don't lose dimensions
35+
def convert_integer_index_to_slice(ia):
36+
if isinstance(ia, ndindex.Integer):
37+
return ndindex.Slice(ia.raw, ia.raw + 1)
38+
return ia
39+
40+
idx = ndindex.Tuple(*(convert_integer_index_to_slice(ia) for ia in idx.args))
41+
42+
chunks = tuple(
43+
tuple(np.array(ch)[ia].tolist())
44+
for ia, ch in zip(idx.raw, self.array.chunks)
45+
)
46+
shape = tuple(map(sum, chunks))
47+
48+
identity = lambda a: a
49+
50+
def get_dim_index(ia, i):
51+
if isinstance(ia, ndindex.Slice):
52+
step = ia.step or 1
53+
return ia.start + (step * i)
54+
elif isinstance(ia, ndindex.IntegerArray):
55+
return ia.raw[i]
56+
else:
57+
raise NotImplementedError(
58+
"Only integer, slice, or int array indexes are supported."
59+
)
60+
61+
def key_function(out_key):
62+
out_coords = out_key[1:]
63+
in_coords = tuple(
64+
get_dim_index(ia, bi) for ia, bi in zip(idx.args, out_coords)
65+
)
66+
return ((self.array.name, *in_coords),)
67+
68+
out = general_blockwise(
69+
identity,
70+
key_function,
71+
self.array,
72+
shapes=[shape],
73+
dtypes=[self.array.dtype],
74+
chunkss=[chunks],
75+
)
76+
77+
from cubed import Array
78+
79+
assert isinstance(out, Array) # single output
80+
return out
81+
82+
@property
83+
def size(self) -> int:
84+
"""
85+
The total number of blocks in the array.
86+
"""
87+
return math.prod(self.shape)
88+
89+
@property
90+
def shape(self) -> tuple[int, ...]:
91+
"""
92+
The number of blocks per axis.
93+
"""
94+
return self.array.numblocks

cubed/tests/test_indexing.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,34 @@ def test_multiple_int_array_indexes(spec):
7878
)
7979
with pytest.raises(NotImplementedError):
8080
a[[1, 2, 1], [2, 1, 0]]
81+
82+
83+
def test_blocks():
84+
# based on dask tests
85+
x = xp.arange(10, chunks=2)
86+
assert x.blocks.shape == (5,)
87+
assert x.blocks.size == 5
88+
89+
assert_array_equal(x.blocks[0], x[:2])
90+
assert_array_equal(x.blocks[-1], x[-2:])
91+
assert_array_equal(x.blocks[:3], x[:6])
92+
assert_array_equal(x.blocks[[0, 1, 2]], x[:6])
93+
assert_array_equal(x.blocks[[3, 0, 2]], np.array([6, 7, 0, 1, 4, 5]))
94+
95+
x = cubed.random.random((20, 20), chunks=(4, 5))
96+
assert x.blocks.shape == (5, 4)
97+
assert x.blocks.size == 20
98+
assert_array_equal(x.blocks[0], x[:4])
99+
assert_array_equal(x.blocks[0, :3], x[:4, :15])
100+
assert_array_equal(x.blocks[:, :3], x[:, :15])
101+
102+
x = xp.ones((40, 40, 40), chunks=(10, 10, 10))
103+
assert_array_equal(x.blocks[0, :, 0], np.ones((10, 40, 10)))
104+
105+
x = xp.ones((2, 2), chunks=1)
106+
with pytest.raises(ValueError, match="newaxis is not supported"):
107+
x.blocks[xp.newaxis, :, :]
108+
with pytest.raises(NotImplementedError):
109+
x.blocks[[0, 1], [0, 1]]
110+
with pytest.raises(IndexError, match="out of bounds"):
111+
x.blocks[100, 100]

0 commit comments

Comments
 (0)