Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,8 @@ def execute(

if callbacks is not None:
event = ComputeStartEventWithPlan(compute_id, dag, self)
[callback.on_compute_start(event) for callback in callbacks]
for callback in callbacks:
callback.on_compute_start(event)
executor.execute_dag(
dag,
compute_id=compute_id,
Expand All @@ -601,7 +602,8 @@ def execute(
)
if callbacks is not None:
event = ComputeEndEvent(compute_id, dag)
[callback.on_compute_end(event) for callback in callbacks]
for callback in callbacks:
callback.on_compute_end(event)

def visualize(
self,
Expand Down
11 changes: 6 additions & 5 deletions cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Iterator
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import toolz
import zarr
Expand Down Expand Up @@ -355,7 +355,7 @@ def general_blockwise(
output_chunk_memory = 0
target_arrays = []

numblocks0 = None
numblocks0: Optional[Tuple[int, ...]] = None
for i, target_store in enumerate(target_stores):
chunks_normal = normalize_chunks(chunkss[i], shape=shapes[i], dtype=dtypes[i])
chunksize = to_chunksize(chunks_normal)
Expand Down Expand Up @@ -465,7 +465,7 @@ def can_fuse_primitive_ops(
def can_fuse_multiple_primitive_ops(
name: str,
primitive_op: PrimitiveOperation,
predecessor_primitive_ops: List[PrimitiveOperation],
predecessor_primitive_ops: List[Optional[PrimitiveOperation]],
*,
max_total_num_input_blocks: Optional[int] = None,
) -> bool:
Expand Down Expand Up @@ -531,7 +531,7 @@ def can_fuse_multiple_primitive_ops(
return False


def peak_projected_mem(primitive_ops):
def peak_projected_mem(primitive_ops: Iterable[Optional[PrimitiveOperation]]) -> int:
"""Calculate the peak projected memory for running a series of primitive ops
and retaining their return values in memory."""
memory_modeller = MemoryModeller()
Expand Down Expand Up @@ -612,7 +612,8 @@ def fused_func(*args):


def fuse_multiple(
primitive_op: PrimitiveOperation, *predecessor_primitive_ops: PrimitiveOperation
primitive_op: PrimitiveOperation,
*predecessor_primitive_ops: Optional[PrimitiveOperation],
) -> PrimitiveOperation:
"""
Fuse a blockwise operation and its predecessors into a single operation, avoiding writing to (or reading from) the targets of the predecessor operations.
Expand Down
2 changes: 1 addition & 1 deletion cubed/primitive/rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _setup_array_rechunk(

read_chunks, int_chunks, write_chunks = rechunking_plan(
shape,
source_chunks,
source_chunks, # type: ignore
target_chunks,
itemsize(dtype),
max_mem,
Expand Down
6 changes: 3 additions & 3 deletions cubed/runtime/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def async_map_unordered(
batch_size: Optional[int] = None,
return_stats: bool = False,
name: Optional[str] = None,
**kwargs,
**kwargs: Any,
) -> AsyncIterator[Any]:
"""
Asynchronous parallel map over an iterable input, with support for backups and batching.
Expand Down Expand Up @@ -128,7 +128,7 @@ async def async_map_dag(
dag: MultiDiGraph,
callbacks: Optional[Sequence[Callback]] = None,
compute_arrays_in_parallel: Optional[bool] = None,
**kwargs,
**kwargs: Any,
) -> None:
"""
Asynchronous parallel map over multiple pipelines from a DAG, with support for backups and batching.
Expand Down Expand Up @@ -170,7 +170,7 @@ def pipeline_to_stream(
create_futures_func: Callable,
name: str,
pipeline: CubedPipeline,
**kwargs,
**kwargs: Any,
) -> Stream:
"""
Turn a pipeline into an asynchronous stream of results.
Expand Down
6 changes: 4 additions & 2 deletions cubed/runtime/create.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Optional
from typing import Any, Dict, Optional

from cubed.runtime.types import Executor


def create_executor(name: str, executor_options: Optional[dict] = None) -> Executor:
def create_executor(
name: str, executor_options: Optional[Dict[Any, Any]] = None
) -> Executor:
"""Create an executor from an executor name."""
executor_options = executor_options or {}
if name == "beam":
Expand Down
2 changes: 1 addition & 1 deletion cubed/runtime/executors/coiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def make_coiled_function(func, name, coiled_kwargs):
class CoiledExecutor(DagExecutor):
"""An execution engine that uses Coiled Functions."""

def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

@property
Expand Down
2 changes: 1 addition & 1 deletion cubed/runtime/executors/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def create_futures_func(input, **kwargs):
class DaskExecutor(DagExecutor):
"""An execution engine that uses Dask Distributed's async API."""

def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

@property
Expand Down
2 changes: 1 addition & 1 deletion cubed/runtime/executors/lithops.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def standardise_lithops_stats(name: str, future: RetryingFuture) -> Dict[str, An
class LithopsExecutor(DagExecutor):
"""An execution engine that uses Lithops."""

def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

@property
Expand Down
19 changes: 10 additions & 9 deletions cubed/runtime/executors/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def execute_dag(
callbacks: Optional[Sequence[Callback]] = None,
spec: Optional[Spec] = None,
compute_id: Optional[str] = None,
**kwargs,
**kwargs: Any,
) -> None:
for name, node in visit_nodes(dag):
handle_operation_start_callbacks(callbacks, name)
Expand All @@ -57,7 +57,8 @@ def execute_dag(
)
if callbacks is not None:
event = TaskEndEvent(name=name, result=result)
[callback.on_task_end(event) for callback in callbacks]
for callback in callbacks:
callback.on_task_end(event)
handle_operation_end_callbacks(callbacks, name)


Expand All @@ -81,7 +82,7 @@ def unpickle_and_call(f, inp, **kwargs):
return f(inp, **kwargs)


def check_runtime_memory(spec, max_workers):
def check_runtime_memory(spec: Optional[Spec], max_workers: int) -> None:
allowed_mem = spec.allowed_mem if spec is not None else None
total_mem = psutil.virtual_memory().total
if allowed_mem is not None:
Expand Down Expand Up @@ -113,7 +114,7 @@ def create_futures_func(input, **kwargs):
class ThreadsExecutor(DagExecutor):
"""An execution engine that uses Python asyncio."""

def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

# Tell NumPy to use a single thread
Expand All @@ -133,7 +134,7 @@ def execute_dag(
callbacks: Optional[Sequence[Callback]] = None,
spec: Optional[Spec] = None,
compute_id: Optional[str] = None,
**kwargs,
**kwargs: Any,
) -> None:
merged_kwargs = {**self.kwargs, **kwargs}
asyncio_run(
Expand All @@ -152,7 +153,7 @@ async def _async_execute_dag(
callbacks: Optional[Sequence[Callback]] = None,
spec: Optional[Spec] = None,
compute_arrays_in_parallel: Optional[bool] = None,
**kwargs,
**kwargs: Any,
) -> None:
max_workers = kwargs.pop("max_workers", os.cpu_count())
if spec is not None:
Expand Down Expand Up @@ -203,7 +204,7 @@ def create_futures_func(input, **kwargs):
class ProcessesExecutor(DagExecutor):
"""An execution engine that uses local processes."""

def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

# Tell NumPy to use a single thread
Expand All @@ -223,7 +224,7 @@ def execute_dag(
callbacks: Optional[Sequence[Callback]] = None,
spec: Optional[Spec] = None,
compute_id: Optional[str] = None,
**kwargs,
**kwargs: Any,
) -> None:
merged_kwargs = {**self.kwargs, **kwargs}
asyncio_run(
Expand All @@ -242,7 +243,7 @@ async def _async_execute_dag(
callbacks: Optional[Sequence[Callback]] = None,
spec: Optional[Spec] = None,
compute_arrays_in_parallel: Optional[bool] = None,
**kwargs,
**kwargs: Any,
) -> None:
max_workers = kwargs.pop("max_workers", os.cpu_count())
if spec is not None:
Expand Down
2 changes: 1 addition & 1 deletion cubed/runtime/executors/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def create_futures_func(input, **kwargs):
class ModalExecutor(DagExecutor):
"""An execution engine that uses Modal's async API."""

def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

@property
Expand Down
4 changes: 2 additions & 2 deletions cubed/runtime/executors/ray.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import Optional, Sequence
from typing import Any, Optional, Sequence

import ray
from networkx import MultiDiGraph
Expand All @@ -14,7 +14,7 @@
class RayExecutor(DagExecutor):
"""An execution engine that uses Ray."""

def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

@property
Expand Down
2 changes: 1 addition & 1 deletion cubed/runtime/executors/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class SparkExecutor(DagExecutor):
# Minimum memory allowed for Spark (512MB)
MIN_MEMORY_MiB = 512

def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

@property
Expand Down
8 changes: 4 additions & 4 deletions cubed/runtime/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Any, Dict
from typing import Any, Dict, Iterator, List, Tuple

import networkx as nx


def skip_node(name, dag, nodes: Dict[str, Any]) -> bool:
def skip_node(name: str, dag: nx.MultiDiGraph, nodes: Dict[str, Any]) -> bool:
"""
Return True if the array for a node doesn't have a pipeline to compute it,
or if it is marked as already computed.
Expand All @@ -15,7 +15,7 @@ def skip_node(name, dag, nodes: Dict[str, Any]) -> bool:
return nodes[name].get("computed", False)


def visit_nodes(dag):
def visit_nodes(dag: nx.MultiDiGraph) -> Iterator[Tuple[Any, Any]]:
"""Return a generator that visits the nodes in the DAG in topological order."""
nodes = {n: d for (n, d) in dag.nodes(data=True)}
for name in list(nx.topological_sort(dag)):
Expand All @@ -24,7 +24,7 @@ def visit_nodes(dag):
yield name, nodes[name]


def visit_node_generations(dag):
def visit_node_generations(dag: nx.MultiDiGraph) -> Iterator[List[Tuple[Any, Any]]]:
"""Return a generator that visits the nodes in the DAG in groups of topological generations."""
nodes = {n: d for (n, d) in dag.nodes(data=True)}
for names in nx.topological_generations(dag):
Expand Down
14 changes: 7 additions & 7 deletions cubed/runtime/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __repr__(self) -> str:
def name(self) -> str:
raise NotImplementedError # pragma: no cover

def execute_dag(self, dag: MultiDiGraph, **kwargs) -> None:
def execute_dag(self, dag: MultiDiGraph, **kwargs: Any) -> None:
raise NotImplementedError # pragma: no cover


Expand All @@ -37,7 +37,7 @@ class CubedPipeline:

function: Callable[..., Any]
name: str
mappable: Iterable
mappable: Iterable[Any]
config: Config


Expand Down Expand Up @@ -134,7 +134,7 @@ def register(self) -> None:
def unregister(self) -> None:
Callback.active.remove(self)

def on_compute_start(self, event):
def on_compute_start(self, event: ComputeStartEvent) -> None:
"""Called when the computation is about to start.

Parameters
Expand All @@ -144,7 +144,7 @@ def on_compute_start(self, event):
"""
pass # pragma: no cover

def on_compute_end(self, ComputeEndEvent):
def on_compute_end(self, event: ComputeEndEvent) -> None:
"""Called when the computation has finished.

Parameters
Expand All @@ -154,13 +154,13 @@ def on_compute_end(self, ComputeEndEvent):
"""
pass # pragma: no cover

def on_operation_start(self, event):
def on_operation_start(self, event: OperationStartEvent) -> None:
pass

def on_operation_end(self, event):
def on_operation_end(self, event: OperationEndEvent) -> None:
pass

def on_task_end(self, event):
def on_task_end(self, event: TaskEndEvent) -> None:
"""Called when a task ends.

Parameters
Expand Down
Loading
Loading