diff --git a/cubed/runtime/executors/beam.py b/cubed/runtime/executors/beam.py
index e6861e452..d997cd866 100644
--- a/cubed/runtime/executors/beam.py
+++ b/cubed/runtime/executors/beam.py
@@ -83,7 +83,7 @@ class BeamExecutor(DagExecutor):
     """An execution engine that uses Apache Beam."""
 
     def __init__(self, **kwargs):
-        self.kwargs = kwargs
+        super().__init__(**kwargs)
 
     @property
     def name(self) -> str:
diff --git a/cubed/runtime/executors/coiled.py b/cubed/runtime/executors/coiled.py
index 751c4b3ac..8c176e9fd 100644
--- a/cubed/runtime/executors/coiled.py
+++ b/cubed/runtime/executors/coiled.py
@@ -22,7 +22,7 @@ class CoiledExecutor(DagExecutor):
     """An execution engine that uses Coiled Functions."""
 
     def __init__(self, **kwargs):
-        self.kwargs = kwargs
+        super().__init__(**kwargs)
 
     @property
     def name(self) -> str:
diff --git a/cubed/runtime/executors/dask.py b/cubed/runtime/executors/dask.py
index d4bfbc8a8..09eb23fe0 100644
--- a/cubed/runtime/executors/dask.py
+++ b/cubed/runtime/executors/dask.py
@@ -59,7 +59,7 @@ class DaskExecutor(DagExecutor):
     """An execution engine that uses Dask Distributed's async API."""
 
     def __init__(self, **kwargs):
-        self.kwargs = kwargs
+        super().__init__(**kwargs)
 
     @property
     def name(self) -> str:
diff --git a/cubed/runtime/executors/lithops.py b/cubed/runtime/executors/lithops.py
index f0904a9d7..f15ebf5dd 100644
--- a/cubed/runtime/executors/lithops.py
+++ b/cubed/runtime/executors/lithops.py
@@ -257,7 +257,7 @@ class LithopsExecutor(DagExecutor):
     """An execution engine that uses Lithops."""
 
     def __init__(self, **kwargs):
-        self.kwargs = kwargs
+        super().__init__(**kwargs)
 
     @property
     def name(self) -> str:
diff --git a/cubed/runtime/executors/local.py b/cubed/runtime/executors/local.py
index 2148cf12d..dbb5cfe8e 100644
--- a/cubed/runtime/executors/local.py
+++ b/cubed/runtime/executors/local.py
@@ -114,7 +114,7 @@ class ThreadsExecutor(DagExecutor):
     """An execution engine that uses Python asyncio."""
 
     def __init__(self, **kwargs):
-        self.kwargs = kwargs
+        super().__init__(**kwargs)
 
         # Tell NumPy to use a single thread
         # from https://stackoverflow.com/questions/30791550/limit-number-of-threads-in-numpy
@@ -204,7 +204,7 @@ class ProcessesExecutor(DagExecutor):
     """An execution engine that uses local processes."""
 
     def __init__(self, **kwargs):
-        self.kwargs = kwargs
+        super().__init__(**kwargs)
 
         # Tell NumPy to use a single thread
         # from https://stackoverflow.com/questions/30791550/limit-number-of-threads-in-numpy
diff --git a/cubed/runtime/executors/modal.py b/cubed/runtime/executors/modal.py
index fc3d2fd69..d0b837971 100644
--- a/cubed/runtime/executors/modal.py
+++ b/cubed/runtime/executors/modal.py
@@ -141,7 +141,7 @@ class ModalExecutor(DagExecutor):
     """An execution engine that uses Modal's async API."""
 
     def __init__(self, **kwargs):
-        self.kwargs = kwargs
+        super().__init__(**kwargs)
 
     @property
     def name(self) -> str:
diff --git a/cubed/runtime/executors/ray.py b/cubed/runtime/executors/ray.py
index def8ab7d4..8ca739721 100644
--- a/cubed/runtime/executors/ray.py
+++ b/cubed/runtime/executors/ray.py
@@ -15,7 +15,7 @@ class RayExecutor(DagExecutor):
    """An execution engine that uses Ray."""
 
     def __init__(self, **kwargs):
-        self.kwargs = kwargs
+        super().__init__(**kwargs)
 
     @property
     def name(self) -> str:
diff --git a/cubed/runtime/executors/spark.py b/cubed/runtime/executors/spark.py
index bdd13236f..dffc9bf17 100644
--- a/cubed/runtime/executors/spark.py
+++ b/cubed/runtime/executors/spark.py
@@ -20,7 +20,6 @@ class SparkExecutor(DagExecutor):
     MIN_MEMORY_MiB = 512
 
     def __init__(self, **kwargs):
-        self._callbacks = None
         super().__init__(**kwargs)
 
     @property
@@ -57,9 +56,6 @@ def execute_dag(
         compute_id: Optional[str] = None,
         **kwargs: Any,
     ):
-        # Store callbacks for later use during computation
-        self._callbacks = callbacks
-
         # Configure Spark memory settings from Spec if provided
         spark_builder = SparkSession.builder
         if spec is not None and hasattr(spec, "allowed_mem") and spec.allowed_mem:
diff --git a/cubed/runtime/types.py b/cubed/runtime/types.py
index 9a3636450..52858314a 100644
--- a/cubed/runtime/types.py
+++ b/cubed/runtime/types.py
@@ -8,6 +8,18 @@
 
 
 class DagExecutor:
+    def __init__(self, **kwargs):
+        self.kwargs = kwargs
+
+    def __eq__(self, other):
+        if isinstance(other, DagExecutor):
+            return self.name == other.name and self.kwargs == other.kwargs
+        else:
+            return False
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(kwargs={self.kwargs})"
+
     @property
     def name(self) -> str:
         raise NotImplementedError  # pragma: no cover
diff --git a/cubed/spec.py b/cubed/spec.py
index da6b9c8ec..22d0ba6bb 100644
--- a/cubed/spec.py
+++ b/cubed/spec.py
@@ -125,14 +125,14 @@ def zarr_compressor(self) -> Union[dict, str, None]:
         return self._zarr_compressor
 
     @property
-    def intermediate_store(self) -> Union[dict, str, None]:
+    def intermediate_store(self) -> Union[T_Store, None]:
         """The Zarr store for intermediate data. Takes precedence over ``work_dir``."""
         return self._intermediate_store
 
     def __repr__(self) -> str:
         return (
-            f"cubed.Spec(work_dir={self._work_dir}, intermediate_store={self._intermediate_store}, allowed_mem={self._allowed_mem}, "
-            f"reserved_mem={self._reserved_mem}, executor={self._executor}, storage_options={self._storage_options}, zarr_compressor={self._zarr_compressor})"
+            f"cubed.Spec(work_dir={self.work_dir}, intermediate_store={self.intermediate_store}, allowed_mem={self.allowed_mem}, "
+            f"reserved_mem={self.reserved_mem}, executor={self.executor}, storage_options={self.storage_options}, zarr_compressor={self.zarr_compressor})"
         )
 
     def __eq__(self, other):
diff --git a/cubed/tests/test_executor_features.py b/cubed/tests/test_executor_features.py
index 24160029c..0a319f1b5 100644
--- a/cubed/tests/test_executor_features.py
+++ b/cubed/tests/test_executor_features.py
@@ -73,6 +73,22 @@ def mock_apply_blockwise(*args, **kwargs):
     return apply_blockwise(*args, **kwargs)
 
 
+def test_equality():
+    executor = create_executor("threads")
+    assert executor == create_executor("threads")
+
+    executor_2_max_workers = create_executor(
+        "threads", executor_options=dict(max_workers=2)
+    )
+    assert executor_2_max_workers == create_executor(
+        "threads", executor_options=dict(max_workers=2)
+    )
+    assert executor_2_max_workers != create_executor("threads")
+    assert executor_2_max_workers != create_executor(
+        "threads", executor_options=dict(max_workers=1)
+    )
+
+
 # see tests/runtime for more tests for retries for other executors
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="measuring memory does not run on windows"