diff --git a/cubed/core/plan.py b/cubed/core/plan.py index 03370f862..7101112a5 100644 --- a/cubed/core/plan.py +++ b/cubed/core/plan.py @@ -591,7 +591,8 @@ def execute( if callbacks is not None: event = ComputeStartEventWithPlan(compute_id, dag, self) - [callback.on_compute_start(event) for callback in callbacks] + for callback in callbacks: + callback.on_compute_start(event) executor.execute_dag( dag, compute_id=compute_id, @@ -601,7 +602,8 @@ def execute( ) if callbacks is not None: event = ComputeEndEvent(compute_id, dag) - [callback.on_compute_end(event) for callback in callbacks] + for callback in callbacks: + callback.on_compute_end(event) def visualize( self, diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index fa55a6ce6..ed4d5e79e 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -5,7 +5,7 @@ from collections.abc import Iterator from dataclasses import dataclass from functools import partial -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import toolz import zarr @@ -355,7 +355,7 @@ def general_blockwise( output_chunk_memory = 0 target_arrays = [] - numblocks0 = None + numblocks0: Optional[Tuple[int, ...]] = None for i, target_store in enumerate(target_stores): chunks_normal = normalize_chunks(chunkss[i], shape=shapes[i], dtype=dtypes[i]) chunksize = to_chunksize(chunks_normal) @@ -465,7 +465,7 @@ def can_fuse_primitive_ops( def can_fuse_multiple_primitive_ops( name: str, primitive_op: PrimitiveOperation, - predecessor_primitive_ops: List[PrimitiveOperation], + predecessor_primitive_ops: List[Optional[PrimitiveOperation]], *, max_total_num_input_blocks: Optional[int] = None, ) -> bool: @@ -531,7 +531,7 @@ def can_fuse_multiple_primitive_ops( return False -def peak_projected_mem(primitive_ops): +def peak_projected_mem(primitive_ops: Iterable[Optional[PrimitiveOperation]]) -> int: """Calculate the peak projected memory for running a series of primitive ops and retaining their return values in memory.""" memory_modeller = MemoryModeller() @@ -612,7 +612,8 @@ def fused_func(*args): def fuse_multiple( - primitive_op: PrimitiveOperation, *predecessor_primitive_ops: PrimitiveOperation + primitive_op: PrimitiveOperation, + *predecessor_primitive_ops: Optional[PrimitiveOperation], ) -> PrimitiveOperation: """ Fuse a blockwise operation and its predecessors into a single operation, avoiding writing to (or reading from) the targets of the predecessor operations. diff --git a/cubed/primitive/rechunk.py b/cubed/primitive/rechunk.py index 177ac8a42..2e4678cca 100644 --- a/cubed/primitive/rechunk.py +++ b/cubed/primitive/rechunk.py @@ -131,7 +131,7 @@ def _setup_array_rechunk( read_chunks, int_chunks, write_chunks = rechunking_plan( shape, - source_chunks, + source_chunks, # type: ignore target_chunks, itemsize(dtype), max_mem, diff --git a/cubed/runtime/asyncio.py b/cubed/runtime/asyncio.py index b9965339a..e3f265770 100644 --- a/cubed/runtime/asyncio.py +++ b/cubed/runtime/asyncio.py @@ -39,7 +39,7 @@ async def async_map_unordered( batch_size: Optional[int] = None, return_stats: bool = False, name: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> AsyncIterator[Any]: """ Asynchronous parallel map over an iterable input, with support for backups and batching. @@ -128,7 +128,7 @@ async def async_map_dag( dag: MultiDiGraph, callbacks: Optional[Sequence[Callback]] = None, compute_arrays_in_parallel: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> None: """ Asynchronous parallel map over multiple pipelines from a DAG, with support for backups and batching. @@ -170,7 +170,7 @@ def pipeline_to_stream( create_futures_func: Callable, name: str, pipeline: CubedPipeline, - **kwargs, + **kwargs: Any, ) -> Stream: """ Turn a pipeline into an asynchronous stream of results. diff --git a/cubed/runtime/create.py b/cubed/runtime/create.py index 8f49ce34c..81f5dbc0f 100644 --- a/cubed/runtime/create.py +++ b/cubed/runtime/create.py @@ -1,9 +1,11 @@ -from typing import Optional +from typing import Any, Dict, Optional from cubed.runtime.types import Executor -def create_executor(name: str, executor_options: Optional[dict] = None) -> Executor: +def create_executor( + name: str, executor_options: Optional[Dict[Any, Any]] = None +) -> Executor: """Create an executor from an executor name.""" executor_options = executor_options or {} if name == "beam": diff --git a/cubed/runtime/executors/coiled.py b/cubed/runtime/executors/coiled.py index 8c176e9fd..f7548cc55 100644 --- a/cubed/runtime/executors/coiled.py +++ b/cubed/runtime/executors/coiled.py @@ -21,7 +21,7 @@ def make_coiled_function(func, name, coiled_kwargs): class CoiledExecutor(DagExecutor): """An execution engine that uses Coiled Functions.""" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @property diff --git a/cubed/runtime/executors/dask.py b/cubed/runtime/executors/dask.py index 09eb23fe0..a28ec4b41 100644 --- a/cubed/runtime/executors/dask.py +++ b/cubed/runtime/executors/dask.py @@ -58,7 +58,7 @@ def create_futures_func(input, **kwargs): class DaskExecutor(DagExecutor): """An execution engine that uses Dask Distributed's async API.""" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @property diff --git a/cubed/runtime/executors/lithops.py b/cubed/runtime/executors/lithops.py index f15ebf5dd..47b597b2a 100644 --- a/cubed/runtime/executors/lithops.py +++ b/cubed/runtime/executors/lithops.py @@ -256,7 +256,7 @@ def standardise_lithops_stats(name: str, future: RetryingFuture) -> Dict[str, An class LithopsExecutor(DagExecutor): """An execution engine that uses Lithops.""" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @property diff --git a/cubed/runtime/executors/local.py b/cubed/runtime/executors/local.py index dbb5cfe8e..e26658e98 100644 --- a/cubed/runtime/executors/local.py +++ b/cubed/runtime/executors/local.py @@ -42,7 +42,7 @@ def execute_dag( callbacks: Optional[Sequence[Callback]] = None, spec: Optional[Spec] = None, compute_id: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> None: for name, node in visit_nodes(dag): handle_operation_start_callbacks(callbacks, name) @@ -57,7 +57,8 @@ def execute_dag( ) if callbacks is not None: event = TaskEndEvent(name=name, result=result) - [callback.on_task_end(event) for callback in callbacks] + for callback in callbacks: + callback.on_task_end(event) handle_operation_end_callbacks(callbacks, name) @@ -81,7 +82,7 @@ def unpickle_and_call(f, inp, **kwargs): return f(inp, **kwargs) -def check_runtime_memory(spec, max_workers): +def check_runtime_memory(spec: Optional[Spec], max_workers: int) -> None: allowed_mem = spec.allowed_mem if spec is not None else None total_mem = psutil.virtual_memory().total if allowed_mem is not None: @@ -113,7 +114,7 @@ def create_futures_func(input, **kwargs): class ThreadsExecutor(DagExecutor): """An execution engine that uses Python asyncio.""" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) # Tell NumPy to use a single thread @@ -133,7 +134,7 @@ def execute_dag( callbacks: Optional[Sequence[Callback]] = None, spec: Optional[Spec] = None, compute_id: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> None: merged_kwargs = {**self.kwargs, **kwargs} asyncio_run( @@ -152,7 +153,7 @@ async def _async_execute_dag( callbacks: Optional[Sequence[Callback]] = None, spec: Optional[Spec] = None, compute_arrays_in_parallel: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> None: max_workers = kwargs.pop("max_workers", os.cpu_count()) if spec is not None: @@ -203,7 +204,7 @@ def create_futures_func(input, **kwargs): class ProcessesExecutor(DagExecutor): """An execution engine that uses local processes.""" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) # Tell NumPy to use a single thread @@ -223,7 +224,7 @@ def execute_dag( callbacks: Optional[Sequence[Callback]] = None, spec: Optional[Spec] = None, compute_id: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> None: merged_kwargs = {**self.kwargs, **kwargs} asyncio_run( @@ -242,7 +243,7 @@ async def _async_execute_dag( callbacks: Optional[Sequence[Callback]] = None, spec: Optional[Spec] = None, compute_arrays_in_parallel: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> None: max_workers = kwargs.pop("max_workers", os.cpu_count()) if spec is not None: diff --git a/cubed/runtime/executors/modal.py b/cubed/runtime/executors/modal.py index d0b837971..484d6869c 100644 --- a/cubed/runtime/executors/modal.py +++ b/cubed/runtime/executors/modal.py @@ -140,7 +140,7 @@ def create_futures_func(input, **kwargs): class ModalExecutor(DagExecutor): """An execution engine that uses Modal's async API.""" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @property diff --git a/cubed/runtime/executors/ray.py b/cubed/runtime/executors/ray.py index 8ca739721..153e9f9e8 100644 --- a/cubed/runtime/executors/ray.py +++ b/cubed/runtime/executors/ray.py @@ -1,5 +1,5 @@ import asyncio -from typing import Optional, Sequence +from typing import Any, Optional, Sequence import ray from networkx import MultiDiGraph @@ -14,7 +14,7 @@ class RayExecutor(DagExecutor): """An execution engine that uses Ray.""" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @property diff --git a/cubed/runtime/executors/spark.py b/cubed/runtime/executors/spark.py index dffc9bf17..794aeb9df 100644 --- a/cubed/runtime/executors/spark.py +++ b/cubed/runtime/executors/spark.py @@ -19,7 +19,7 @@ class SparkExecutor(DagExecutor): # Minimum memory allowed for Spark (512MB) MIN_MEMORY_MiB = 512 - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @property diff --git a/cubed/runtime/pipeline.py b/cubed/runtime/pipeline.py index 7d763ed0c..65d15e39c 100644 --- a/cubed/runtime/pipeline.py +++ b/cubed/runtime/pipeline.py @@ -1,9 +1,9 @@ -from typing import Any, Dict +from typing import Any, Dict, Iterator, List, Tuple import networkx as nx -def skip_node(name, dag, nodes: Dict[str, Any]) -> bool: +def skip_node(name: str, dag: nx.MultiDiGraph, nodes: Dict[str, Any]) -> bool: """ Return True if the array for a node doesn't have a pipeline to compute it, or if it is marked as already computed. @@ -15,7 +15,7 @@ def skip_node(name, dag, nodes: Dict[str, Any]) -> bool: return nodes[name].get("computed", False) -def visit_nodes(dag): +def visit_nodes(dag: nx.MultiDiGraph) -> Iterator[Tuple[Any, Any]]: """Return a generator that visits the nodes in the DAG in topological order.""" nodes = {n: d for (n, d) in dag.nodes(data=True)} for name in list(nx.topological_sort(dag)): @@ -24,7 +24,7 @@ def visit_nodes(dag): yield name, nodes[name] -def visit_node_generations(dag): +def visit_node_generations(dag: nx.MultiDiGraph) -> Iterator[List[Tuple[Any, Any]]]: """Return a generator that visits the nodes in the DAG in groups of topological generations.""" nodes = {n: d for (n, d) in dag.nodes(data=True)} for names in nx.topological_generations(dag): diff --git a/cubed/runtime/types.py b/cubed/runtime/types.py index 52858314a..251712298 100644 --- a/cubed/runtime/types.py +++ b/cubed/runtime/types.py @@ -24,7 +24,7 @@ def __repr__(self) -> str: def name(self) -> str: raise NotImplementedError # pragma: no cover - def execute_dag(self, dag: MultiDiGraph, **kwargs) -> None: + def execute_dag(self, dag: MultiDiGraph, **kwargs: Any) -> None: raise NotImplementedError # pragma: no cover @@ -37,7 +37,7 @@ class CubedPipeline: function: Callable[..., Any] name: str - mappable: Iterable + mappable: Iterable[Any] config: Config @@ -134,7 +134,7 @@ def register(self) -> None: def unregister(self) -> None: Callback.active.remove(self) - def on_compute_start(self, event): + def on_compute_start(self, event: ComputeStartEvent) -> None: """Called when the computation is about to start. Parameters @@ -144,7 +144,7 @@ def on_compute_start(self, event): """ pass # pragma: no cover - def on_compute_end(self, ComputeEndEvent): + def on_compute_end(self, event: ComputeEndEvent) -> None: """Called when the computation has finished. Parameters @@ -154,13 +154,13 @@ def on_compute_end(self, ComputeEndEvent): """ pass # pragma: no cover - def on_operation_start(self, event): + def on_operation_start(self, event: OperationStartEvent) -> None: pass - def on_operation_end(self, event): + def on_operation_end(self, event: OperationEndEvent) -> None: pass - def on_task_end(self, event): + def on_task_end(self, event: TaskEndEvent) -> None: """Called when the a task ends. Parameters diff --git a/cubed/runtime/utils.py b/cubed/runtime/utils.py index d664eacb6..694678777 100644 --- a/cubed/runtime/utils.py +++ b/cubed/runtime/utils.py @@ -5,8 +5,14 @@ from functools import partial from itertools import islice from pathlib import Path - -from cubed.runtime.types import OperationEndEvent, OperationStartEvent, TaskEndEvent +from typing import Any, Coroutine, Dict, Iterable, Optional, TypeVar + +from cubed.runtime.types import ( + Callback, + OperationEndEvent, + OperationStartEvent, + TaskEndEvent, +) from cubed.utils import peak_measured_mem try: @@ -14,6 +20,8 @@ except ImportError: memray = None # type: ignore +T = TypeVar("T") + sym_counter = 0 @@ -99,19 +107,27 @@ def profile_memray(func): return partial(execute_with_memray, func) -def handle_operation_start_callbacks(callbacks, name): +def handle_operation_start_callbacks( + callbacks: Optional[Iterable[Callback]], name: str +) -> None: if callbacks is not None: event = OperationStartEvent(name) - [callback.on_operation_start(event) for callback in callbacks] + for callback in callbacks: + callback.on_operation_start(event) -def handle_operation_end_callbacks(callbacks, name): +def handle_operation_end_callbacks(callbacks, name) -> None: if callbacks is not None: event = OperationEndEvent(name) - [callback.on_operation_end(event) for callback in callbacks] + for callback in callbacks: + callback.on_operation_end(event) -def handle_callbacks(callbacks, result, stats): +def handle_callbacks( + callbacks: Optional[Iterable[Callback]], + result: Optional[Any], + stats: Dict[str, Any], +) -> None: """Construct a TaskEndEvent from stats and send to all callbacks.""" if callbacks is not None: @@ -124,7 +140,8 @@ def handle_callbacks(callbacks, result, stats): ) else: event = TaskEndEvent(result=result, **stats) - [callback.on_task_end(event) for callback in callbacks] + for callback in callbacks: + callback.on_task_end(event) def raise_if_computes(): @@ -137,7 +154,7 @@ def raise_if_computes(): # Like asyncio.run(), but works in a Jupyter notebook # Based on https://stackoverflow.com/a/75341431 -def asyncio_run(coro): +def asyncio_run(coro: Coroutine[Any, Any, T]) -> T: try: asyncio.get_running_loop() # Triggers RuntimeError if no running event loop except RuntimeError: diff --git a/cubed/storage/store.py b/cubed/storage/store.py index d66b8e3ec..dfdac6735 100644 --- a/cubed/storage/store.py +++ b/cubed/storage/store.py @@ -1,10 +1,10 @@ -from typing import Optional +from typing import Any, Optional from cubed import config from cubed.types import T_DType, T_RegularChunks, T_Shape, T_Store -def get_storage_name(): +def get_storage_name() -> str: # get storage name from top-level config # e.g. set globally with CUBED_STORAGE_NAME=tensorstore storage_name = config.get("storage_name", None) @@ -20,13 +20,13 @@ def get_storage_name(): return storage_name -def is_storage_array(obj): +def is_storage_array(obj) -> bool: storage_name = get_storage_name() if storage_name == "zarr-python": import zarr - from cubed.storage.stores.zarr_python import ZarrArrayGroup + from cubed.storage.stores.zarr_python import ZarrArrayGroup # type: ignore return isinstance(obj, (zarr.Array, ZarrArrayGroup)) elif storage_name in ("zarr-python-v3", "zarrs-python"): @@ -54,7 +54,7 @@ def open_storage_array( dtype: Optional[T_DType] = None, chunks: Optional[T_RegularChunks] = None, path: Optional[str] = None, - **kwargs, + **kwargs: Any, ): storage_name = get_storage_name() diff --git a/cubed/storage/virtual.py b/cubed/storage/virtual.py index c2f227924..c328dc383 100644 --- a/cubed/storage/virtual.py +++ b/cubed/storage/virtual.py @@ -1,14 +1,17 @@ from numbers import Integral -from typing import Any +from typing import TYPE_CHECKING, Any, Tuple import numpy as np from cubed.backend_array_api import namespace as nxp from cubed.backend_array_api import numpy_array_to_backend_array from cubed.storage.types import ArrayMetadata -from cubed.types import T_DType, T_RegularChunks, T_Shape +from cubed.types import T_DType, T_RegularChunks, T_Shape, T_StandardArray from cubed.utils import array_memory, broadcast_trick, memory_repr +if TYPE_CHECKING: + from zarr.core.indexing import Selection + class VirtualArray(ArrayMetadata): pass @@ -25,7 +28,7 @@ def __init__( ): super().__init__(shape, dtype, chunks) - def __getitem__(self, key): + def __getitem__(self, key: "Selection") -> T_StandardArray: from ndindex import ndindex # import as needed to avoid slow 'import cubed' idx = ndindex[key] @@ -34,7 +37,7 @@ def __getitem__(self, key): return broadcast_trick(nxp.empty)(newshape, dtype=self.dtype) @property - def chunkmem(self): + def chunkmem(self) -> int: # take broadcast trick into account return array_memory(self.dtype, (1,)) @@ -52,7 +55,7 @@ def __init__( super().__init__(shape, dtype, chunks) self.fill_value = fill_value - def __getitem__(self, key): + def __getitem__(self, key: "Selection") -> T_StandardArray: from ndindex import ndindex # import as needed to avoid slow 'import cubed' idx = ndindex[key] @@ -63,7 +66,7 @@ def __getitem__(self, key): ) @property - def chunkmem(self): + def chunkmem(self) -> int: # take broadcast trick into account return array_memory(self.dtype, (1,)) @@ -76,11 +79,12 @@ def __init__(self, shape: T_Shape): chunks = (1,) * len(shape) super().__init__(shape, dtype, chunks) - def __getitem__(self, key): + def __getitem__(self, key: "Selection") -> T_StandardArray: if key == () and self.shape == (): return nxp.asarray(0, dtype=self.dtype) return numpy_array_to_backend_array( - np.ravel_multi_index(_key_to_index_tuple(key), self.shape), dtype=self.dtype + np.ravel_multi_index(_key_to_index_tuple(key), self.shape), + dtype=self.dtype, # type: ignore[arg-type] ) @@ -89,7 +93,7 @@ class VirtualInMemoryArray(VirtualArray): def __init__( self, - array: np.ndarray, # TODO: generalise to array API type + array: T_StandardArray, chunks: T_RegularChunks, max_nbytes: int = 10**6, ): @@ -101,18 +105,18 @@ def __init__( self.array = array super().__init__(array.shape, array.dtype, chunks) - def __getitem__(self, key): + def __getitem__(self, key: "Selection") -> T_StandardArray: return self.array.__getitem__(key) -def _key_to_index_tuple(selection): - if isinstance(selection, slice): +def _key_to_index_tuple(selection: "Selection") -> Tuple[int, ...]: + if isinstance(selection, (slice, Integral)): selection = (selection,) - assert all(isinstance(s, (slice, Integral)) for s in selection) + assert all(isinstance(s, (slice, Integral)) for s in selection) # type: ignore[union-attr] sel = [] - for s in selection: + for s in selection: # type: ignore[union-attr] if isinstance(s, Integral): - sel.append(s) + sel.append(int(s)) elif ( isinstance(s, slice) and s.stop == s.start + 1 @@ -125,7 +129,7 @@ def _key_to_index_tuple(selection): def virtual_empty( - shape: T_Shape, *, dtype: T_DType, chunks: T_RegularChunks, **kwargs + shape: T_Shape, *, dtype: T_DType, chunks: T_RegularChunks, **kwargs: Any ) -> VirtualEmptyArray: return VirtualEmptyArray(shape, dtype, chunks, **kwargs) @@ -136,7 +140,7 @@ def virtual_full( *, dtype: T_DType, chunks: T_RegularChunks, - **kwargs, + **kwargs: Any, ) -> VirtualFullArray: return VirtualFullArray(shape, dtype, chunks, fill_value, **kwargs) @@ -146,7 +150,7 @@ def virtual_offsets(shape: T_Shape) -> VirtualOffsetsArray: def virtual_in_memory( - array: np.ndarray, + array: T_StandardArray, chunks: T_RegularChunks, ) -> VirtualInMemoryArray: return VirtualInMemoryArray(array, chunks) diff --git a/cubed/storage/zarr.py b/cubed/storage/zarr.py index 1f67b613a..e10b34d5c 100644 --- a/cubed/storage/zarr.py +++ b/cubed/storage/zarr.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Any, Optional, Union import zarr @@ -22,8 +22,8 @@ def __init__( dtype: T_DType, chunks: T_RegularChunks, path: Optional[str] = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: """Create a Zarr array lazily in memory.""" super().__init__(shape, dtype, chunks) self.store = store @@ -82,7 +82,7 @@ def lazy_zarr_array( dtype: T_DType, chunks: T_RegularChunks, path: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> LazyZarrArray: return LazyZarrArray( store, diff --git a/cubed/tests/storage/test_virtual.py b/cubed/tests/storage/test_virtual.py index d46deece4..336cce883 100644 --- a/cubed/tests/storage/test_virtual.py +++ b/cubed/tests/storage/test_virtual.py @@ -4,7 +4,8 @@ import numpy as np import pytest -from cubed.storage.virtual import virtual_empty, virtual_offsets +from cubed._testing import assert_array_equal +from cubed.storage.virtual import virtual_empty, virtual_full, virtual_offsets @pytest.mark.parametrize( @@ -26,6 +27,23 @@ def test_virtual_empty(shape, chunks, index): assert v_empty[...].shape == empty[...].shape +@pytest.mark.parametrize( + ("shape", "chunks", "index"), + [ + ((3,), (2,), 2), + ((3, 2), (2, 1), (2, 1)), + ((3, 2), (2, 1), (2, slice(0, 1))), + ((3, 2), (2, 1), (slice(1, 3), 1)), + ((3, 2), (2, 1), (slice(1, 3), slice(0, 1))), + ], +) +def test_virtual_full(shape, chunks, index): + v_full = virtual_full(shape, fill_value=7, dtype=np.int32, chunks=chunks) + full = np.full(shape, fill_value=7, dtype=np.int32) + assert_array_equal(v_full[index], full[index]) + assert_array_equal(v_full[...], full[...]) + + @pytest.mark.parametrize("shape", [(), (3,), (3, 2)]) def test_virtual_offsets(shape): v_offsets = virtual_offsets(shape) @@ -35,6 +53,7 @@ def test_virtual_offsets(shape): # test some length 1 slices if len(shape) == 1: + assert v_offsets[1] == offsets[1] assert v_offsets[1:2] == offsets[1:2] elif len(shape) == 2: assert v_offsets[1:2, 0:1] == offsets[1:2, 0:1] diff --git a/cubed/types.py b/cubed/types.py index 1c3296adf..3f9c8c35b 100644 --- a/cubed/types.py +++ b/cubed/types.py @@ -19,3 +19,6 @@ ] T_Store = Any # TODO: improve this + +# Use https://github.com/data-apis/array-api-typing when ready +T_StandardArray = Any diff --git a/cubed/utils.py b/cubed/utils.py index 5878d406a..2a21ea25c 100644 --- a/cubed/utils.py +++ b/cubed/utils.py @@ -14,7 +14,7 @@ from pathlib import Path from posixpath import join from types import FrameType -from typing import Dict, Optional, Tuple, Union, cast +from typing import Any, Callable, Dict, Optional, Tuple, Union, cast from urllib.parse import quote, unquote, urlsplit, urlunsplit import numpy as np @@ -23,7 +23,14 @@ from cubed.backend_array_api import backend_dtype_to_numpy_dtype from cubed.backend_array_api import namespace as nxp -from cubed.types import T_Chunks, T_DType, T_RectangularChunks, T_RegularChunks, T_Shape +from cubed.types import ( + T_Chunks, + T_DType, + T_RectangularChunks, + T_RegularChunks, + T_Shape, + T_StandardArray, +) from cubed.vendor.dask.array.core import _check_regular_chunks from cubed.vendor.dask.array.core import normalize_chunks as dask_normalize_chunks @@ -360,14 +367,18 @@ def map_nested(func, seq): return func(seq) -def _broadcast_trick_inner(func, shape, *args, **kwargs): +def _broadcast_trick_inner( + func: Callable[..., T_StandardArray], shape: T_Shape, *args: Any, **kwargs: Any +) -> T_StandardArray: # cupy-specific hack. numpy is happy with hardcoded shape=(). null_shape = () if shape == () else 1 return nxp.broadcast_to(func(*args, shape=null_shape, **kwargs), shape) -def broadcast_trick(func): +def broadcast_trick( + func: Callable[..., T_StandardArray], +) -> Callable[..., T_StandardArray]: """Apply Dask's broadcast trick to array API functions that produce arrays containing a single value to save space in memory. @@ -375,11 +386,11 @@ def broadcast_trick(func): """ inner = partial(_broadcast_trick_inner, func) inner.__doc__ = func.__doc__ - inner.__name__ = func.__name__ + inner.__name__ = func.__name__ # type: ignore[attr-defined] return inner -def normalize_shape(shape: Union[int, Tuple[int, ...], None]) -> Tuple[int, ...]: +def normalize_shape(shape: Union[int, Tuple[int, ...], None]) -> T_Shape: """Normalize a `shape` argument to a tuple of ints.""" if shape is None: