diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index fd23ccebe..41755dd88 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -456,6 +456,53 @@ def __iter__(self): list, itertools.product(*[range(len(c)) for c in self.chunks_normal]) ) + def range(self, start, stop=None): + """Allow efficient subsetting of the key space by range.""" + if stop is None: + it = product_from(*[range(len(c)) for c in self.chunks_normal], start=start) + else: + it = itertools.islice( + product_from(*[range(len(c)) for c in self.chunks_normal], start=start), + stop - start, + ) + return map(list, it) + + +def product_from(*iterables, start=0): + """Efficient implementation of 'itertools.product' starting at an arbitrary index.""" + pools = [tuple(pool) for pool in iterables] + + if not pools or any(len(pool) == 0 for pool in pools): + return + + # Compute the total number of combinations + total = 1 + for pool in pools: + total *= len(pool) + + if start >= total: + return + + # Decompose `start` into per-pool indices via mixed-radix conversion + indices = [0] * len(pools) + remainder = start + for k in range(len(pools) - 1, -1, -1): + indices[k] = remainder % len(pools[k]) + remainder //= len(pools[k]) + + yield tuple(pool[i] for pool, i in zip(pools, indices)) + + while True: + for k in range(len(pools) - 1, -1, -1): + if indices[k] < len(pools[k]) - 1: + indices[k] += 1 + for j in range(k + 1, len(pools)): + indices[j] = 0 + yield tuple(pools[j][indices[j]] for j in range(len(pools))) + break + else: + return + # Code for fusing blockwise operations diff --git a/cubed/tests/primitive/test_chunk_keys.py b/cubed/tests/primitive/test_chunk_keys.py new file mode 100644 index 000000000..289c4ee33 --- /dev/null +++ b/cubed/tests/primitive/test_chunk_keys.py @@ -0,0 +1,22 @@ +import pytest + +from cubed.primitive.blockwise import ChunkKeys +from cubed.utils import normalize_chunks + + +def test_chunk_keys_iter(): + chunks = (3, 2) + chunks_normal = normalize_chunks(chunks, shape=(4, 5)) + chunk_keys = ChunkKeys(chunks_normal) + assert list(chunk_keys) == [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]] + + +@pytest.mark.parametrize( + ("start", "stop"), [(0, None), (3, None), (3, 3), (3, 4), (3, 5), (3, 6), (5, None)] +) +def test_chunk_keys_range(start, stop): + chunks = (3, 2) + chunks_normal = normalize_chunks(chunks, shape=(4, 5)) + chunk_keys = ChunkKeys(chunks_normal) + all_keys = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]] + assert list(chunk_keys.range(start, stop)) == all_keys[slice(start, stop)]