Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cubed/runtime/executors/beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class BeamExecutor(DagExecutor):
"""An execution engine that uses Apache Beam."""

def __init__(self, **kwargs):
self.kwargs = kwargs
super().__init__(**kwargs)

@property
def name(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion cubed/runtime/executors/coiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class CoiledExecutor(DagExecutor):
"""An execution engine that uses Coiled Functions."""

def __init__(self, **kwargs):
self.kwargs = kwargs
super().__init__(**kwargs)

@property
def name(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion cubed/runtime/executors/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class DaskExecutor(DagExecutor):
"""An execution engine that uses Dask Distributed's async API."""

def __init__(self, **kwargs):
self.kwargs = kwargs
super().__init__(**kwargs)

@property
def name(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion cubed/runtime/executors/lithops.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ class LithopsExecutor(DagExecutor):
"""An execution engine that uses Lithops."""

def __init__(self, **kwargs):
self.kwargs = kwargs
super().__init__(**kwargs)

@property
def name(self) -> str:
Expand Down
4 changes: 2 additions & 2 deletions cubed/runtime/executors/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class ThreadsExecutor(DagExecutor):
"""An execution engine that uses Python asyncio."""

def __init__(self, **kwargs):
self.kwargs = kwargs
super().__init__(**kwargs)

# Tell NumPy to use a single thread
# from https://stackoverflow.com/questions/30791550/limit-number-of-threads-in-numpy
Expand Down Expand Up @@ -204,7 +204,7 @@ class ProcessesExecutor(DagExecutor):
"""An execution engine that uses local processes."""

def __init__(self, **kwargs):
self.kwargs = kwargs
super().__init__(**kwargs)

# Tell NumPy to use a single thread
# from https://stackoverflow.com/questions/30791550/limit-number-of-threads-in-numpy
Expand Down
2 changes: 1 addition & 1 deletion cubed/runtime/executors/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class ModalExecutor(DagExecutor):
"""An execution engine that uses Modal's async API."""

def __init__(self, **kwargs):
self.kwargs = kwargs
super().__init__(**kwargs)

@property
def name(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion cubed/runtime/executors/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class RayExecutor(DagExecutor):
"""An execution engine that uses Ray."""

def __init__(self, **kwargs):
self.kwargs = kwargs
super().__init__(**kwargs)

@property
def name(self) -> str:
Expand Down
4 changes: 0 additions & 4 deletions cubed/runtime/executors/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class SparkExecutor(DagExecutor):
MIN_MEMORY_MiB = 512

def __init__(self, **kwargs):
self._callbacks = None
super().__init__(**kwargs)

@property
Expand Down Expand Up @@ -57,9 +56,6 @@ def execute_dag(
compute_id: Optional[str] = None,
**kwargs: Any,
):
# Store callbacks for later use during computation
self._callbacks = callbacks

# Configure Spark memory settings from Spec if provided
spark_builder = SparkSession.builder
if spec is not None and hasattr(spec, "allowed_mem") and spec.allowed_mem:
Expand Down
12 changes: 12 additions & 0 deletions cubed/runtime/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@


class DagExecutor:
def __init__(self, **kwargs):
    # Store executor construction options; subclasses forward their kwargs
    # here so __eq__ and __repr__ can use them uniformly.
    self.kwargs = kwargs

def __eq__(self, other):
if isinstance(other, DagExecutor):
return self.name == other.name and self.kwargs == other.kwargs
else:
return False

def __repr__(self) -> str:
return f"{self.__class__.__name__}(kwargs={self.kwargs})"

@property
def name(self) -> str:
    """Short identifier for this executor; concrete subclasses must override."""
    raise NotImplementedError # pragma: no cover
Expand Down
6 changes: 3 additions & 3 deletions cubed/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,14 @@ def zarr_compressor(self) -> Union[dict, str, None]:
return self._zarr_compressor

@property
def intermediate_store(self) -> Union[dict, str, None]:
def intermediate_store(self) -> Union[T_Store, None]:
"""The Zarr store for intermediate data. Takes precedence over ``work_dir``."""
return self._intermediate_store

def __repr__(self) -> str:
return (
f"cubed.Spec(work_dir={self._work_dir}, intermediate_store={self._intermediate_store}, allowed_mem={self._allowed_mem}, "
f"reserved_mem={self._reserved_mem}, executor={self._executor}, storage_options={self._storage_options}, zarr_compressor={self._zarr_compressor})"
f"cubed.Spec(work_dir={self.work_dir}, intermediate_store={self.intermediate_store}, allowed_mem={self.allowed_mem}, "
f"reserved_mem={self.reserved_mem}, executor={self.executor}, storage_options={self.storage_options}, zarr_compressor={self.zarr_compressor})"
)

def __eq__(self, other):
Expand Down
16 changes: 16 additions & 0 deletions cubed/tests/test_executor_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,22 @@ def mock_apply_blockwise(*args, **kwargs):
return apply_blockwise(*args, **kwargs)


def test_equality():
    # Executors with the same name and construction options compare equal.
    assert create_executor("threads") == create_executor("threads")

    two_workers = create_executor(
        "threads", executor_options=dict(max_workers=2)
    )
    assert two_workers == create_executor(
        "threads", executor_options=dict(max_workers=2)
    )
    # Differing options — absent, or present with a different value — break
    # equality.
    assert two_workers != create_executor("threads")
    assert two_workers != create_executor(
        "threads", executor_options=dict(max_workers=1)
    )


# see tests/runtime for more tests for retries for other executors
@pytest.mark.skipif(
platform.system() == "Windows", reason="measuring memory does not run on windows"
Expand Down