diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a998dc1ce..282a799a6 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -24,6 +24,8 @@ jobs: python-version: "3.12" - os: "ubuntu-latest" python-version: "3.13" + - os: "ubuntu-latest" + python-version: "3.14" steps: - name: Checkout source diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 71cff5be2..fbf66233f 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -5,7 +5,7 @@ from functools import partial from itertools import product from numbers import Integral, Number -from typing import TYPE_CHECKING, Any, Callable, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Sequence, Tuple, Union from warnings import warn import numpy as np @@ -237,10 +237,19 @@ def key_function(out_key): raise ValueError( f"Source array shape {source.shape} does not match region shape {indexer.shape}" ) - # TODO(#800): make Zarr indexer pickle-able so we don't have to materialize all the block IDs - output_blocks = map( - lambda chunk_projection: list(chunk_projection[0]), list(indexer) - ) + + # use this wrapper to avoid generator pickle error + class OutputBlocksIterable(Iterable[List[int]]): + def __init__(self, region, shape, chunks): + self.region = region + self.shape = shape + self.chunks = chunks + + def __iter__(self): + indexer = _create_zarr_indexer(region, shape, chunks) + return map(lambda chunk_projection: list(chunk_projection[0]), indexer) + + output_blocks = OutputBlocksIterable(region, shape, chunks) out = general_blockwise( identity, diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index ed4d5e79e..fd23ccebe 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -20,7 +20,14 @@ from cubed.runtime.types import CubedPipeline from cubed.storage.store import is_storage_array from cubed.storage.zarr import LazyZarrArray, T_ZarrArray, lazy_zarr_array -from cubed.types import T_Chunks, T_DType, T_RegularChunks, T_Shape, T_Store +from cubed.types import ( + T_Chunks, + T_DType, + T_RectangularChunks, + T_RegularChunks, + T_Shape, + T_Store, +) from cubed.utils import ( array_memory, chunk_memory, @@ -415,17 +422,14 @@ def general_blockwise( ) # this must be an iterator of lists, not of tuples, otherwise lithops breaks - if output_blocks is None: - output_blocks = map( - list, itertools.product(*[range(len(c)) for c in chunks_normal]) - ) + mappable = output_blocks if output_blocks is not None else ChunkKeys(chunks_normal) if num_tasks is None: num_tasks = math.prod(len(c) for c in chunks_normal) pipeline = CubedPipeline( apply_blockwise, gensym("apply_blockwise"), - output_blocks, + mappable, spec, ) return PrimitiveOperation( @@ -442,6 +446,17 @@ def general_blockwise( ) +# use this wrapper to avoid itertools.product pickle error +class ChunkKeys(Iterable[List[int]]): + def __init__(self, chunks_normal: T_RectangularChunks): + self.chunks_normal = chunks_normal + + def __iter__(self): + return map( + list, itertools.product(*[range(len(c)) for c in self.chunks_normal]) + ) + + # Code for fusing blockwise operations diff --git a/pyproject.toml b/pyproject.toml index 422000e23..fef84eee5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", ] requires-python = ">=3.11" dependencies = [