Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ jobs:
python-version: "3.12"
- os: "ubuntu-latest"
python-version: "3.13"
- os: "ubuntu-latest"
python-version: "3.14"

steps:
- name: Checkout source
Expand Down
19 changes: 14 additions & 5 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from functools import partial
from itertools import product
from numbers import Integral, Number
from typing import TYPE_CHECKING, Any, Callable, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Sequence, Tuple, Union
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -237,10 +237,19 @@ def key_function(out_key):
raise ValueError(
f"Source array shape {source.shape} does not match region shape {indexer.shape}"
)
# TODO(#800): make Zarr indexer pickle-able so we don't have to materialize all the block IDs
output_blocks = map(
lambda chunk_projection: list(chunk_projection[0]), list(indexer)
)

# use this wrapper to avoid generator pickle error
class OutputBlocksIterable(Iterable[List[int]]):
    """Iterable of output block IDs for a region being written to a store.

    The Zarr indexer itself is not picklable, so instead of materializing
    it eagerly this wrapper stores only the (picklable) inputs and rebuilds
    the indexer lazily on every ``__iter__`` call.
    """

    def __init__(self, region, shape, chunks):
        # region/shape/chunks are forwarded verbatim to _create_zarr_indexer;
        # presumably a selection tuple, array shape, and chunk sizes — confirm
        # against _create_zarr_indexer's signature.
        self.region = region
        self.shape = shape
        self.chunks = chunks

    def __iter__(self):
        # Read from the stored instance attributes, NOT the enclosing scope:
        # the original closed over the outer function's locals, which left
        # self.region/self.shape/self.chunks dead and made the instance
        # depend on an unpicklable closure — defeating the class's purpose.
        indexer = _create_zarr_indexer(self.region, self.shape, self.chunks)
        # Each chunk projection's first element is the chunk coordinates;
        # emit lists (not tuples) — see the lithops note in general_blockwise.
        return map(lambda chunk_projection: list(chunk_projection[0]), indexer)

output_blocks = OutputBlocksIterable(region, shape, chunks)

out = general_blockwise(
identity,
Expand Down
27 changes: 21 additions & 6 deletions cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@
from cubed.runtime.types import CubedPipeline
from cubed.storage.store import is_storage_array
from cubed.storage.zarr import LazyZarrArray, T_ZarrArray, lazy_zarr_array
from cubed.types import T_Chunks, T_DType, T_RegularChunks, T_Shape, T_Store
from cubed.types import (
T_Chunks,
T_DType,
T_RectangularChunks,
T_RegularChunks,
T_Shape,
T_Store,
)
from cubed.utils import (
array_memory,
chunk_memory,
Expand Down Expand Up @@ -415,17 +422,14 @@ def general_blockwise(
)

# this must be an iterator of lists, not of tuples, otherwise lithops breaks
if output_blocks is None:
output_blocks = map(
list, itertools.product(*[range(len(c)) for c in chunks_normal])
)
mappable = output_blocks if output_blocks is not None else ChunkKeys(chunks_normal)
if num_tasks is None:
num_tasks = math.prod(len(c) for c in chunks_normal)

pipeline = CubedPipeline(
apply_blockwise,
gensym("apply_blockwise"),
output_blocks,
mappable,
spec,
)
return PrimitiveOperation(
Expand All @@ -442,6 +446,17 @@ def general_blockwise(
)


# use this wrapper to avoid itertools.product pickle error
class ChunkKeys(Iterable[List[int]]):
    """Lazily enumerate every chunk key for normalized chunks.

    A bare ``itertools.product`` iterator cannot be pickled, so this class
    holds only the chunk specification and produces a fresh iterator each
    time it is iterated.
    """

    def __init__(self, chunks_normal: T_RectangularChunks):
        self.chunks_normal = chunks_normal

    def __iter__(self):
        # One index range per axis; the cartesian product walks every
        # combination of per-axis chunk indices, yielded as lists (not
        # tuples) to match what downstream executors expect.
        axis_ranges = (
            range(len(axis_chunks)) for axis_chunks in self.chunks_normal
        )
        for key in itertools.product(*axis_ranges):
            yield list(key)


# Code for fusing blockwise operations


Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ classifiers = [
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
]
requires-python = ">=3.11"
dependencies = [
Expand Down