diff --git a/bindings/distrdf/python/DistRDF/Backends/Dask/Backend.py b/bindings/distrdf/python/DistRDF/Backends/Dask/Backend.py
index 84a9a0bbd4abc..5e3da6ad5a836 100644
--- a/bindings/distrdf/python/DistRDF/Backends/Dask/Backend.py
+++ b/bindings/distrdf/python/DistRDF/Backends/Dask/Backend.py
@@ -104,6 +104,20 @@ def __init__(self, daskclient: Optional[Client] = None):
         # N is the number of cores on the local machine.
         self.client = (daskclient if daskclient is not None else
                        Client(LocalCluster(n_workers=os.cpu_count(), threads_per_worker=1, processes=True)))
+
+        # Refuse workers that run more than one thread. A scheduler that
+        # reports no "workers" entry is accepted as-is, and construction
+        # deliberately falls through (no early return) so that any later
+        # initialization is not skipped.
+        # NOTE(review): recent distributed releases may cap the number of
+        # workers returned by scheduler_info() by default - confirm whether
+        # all workers are inspected here.
+        workers = self.client.scheduler_info().get("workers") or {}
+        if any(worker.get("nthreads", 1) > 1 for worker in workers.values()):
+            raise RuntimeError(
+                "DistRDF with Dask does not support threaded workers. "
+                "Please use processes=True and threads_per_worker=1."
+            )
 
     def optimize_npartitions(self) -> int:
         """
diff --git a/roottest/python/distrdf/backends/check_backend.py b/roottest/python/distrdf/backends/check_backend.py
index b7ae271ca26ea..f8d15ee313dff 100644
--- a/roottest/python/distrdf/backends/check_backend.py
+++ b/roottest/python/distrdf/backends/check_backend.py
@@ -59,6 +59,50 @@ def test_optimize_npartitions(self, payload):
             backend = Backend.SparkBackend(sparkcontext=connection)
             assert backend.optimize_npartitions() == 2
 
+    def test_dask_backend_handles_missing_workers(self, payload):
+        """
+        Check that DaskBackend initialization succeeds when scheduler_info
+        does not provide worker information.
+        """
+        connection, backend = payload
+
+        if backend != "dask":
+            pytest.skip("Only relevant for the Dask backend")
+
+        from ROOT._distrdf.Backends.Dask import Backend
+
+        original_scheduler_info = connection.scheduler_info
+
+        try:
+            # Simulate a scheduler that reports no "workers" entry at all.
+            connection.scheduler_info = lambda *args, **kwargs: {}
+
+            dask_backend = Backend.DaskBackend(daskclient=connection)
+            assert dask_backend.client is connection
+
+            df = ROOT.RDataFrame(10, executor=connection)
+            assert df.Count().GetValue() == 10
+        finally:
+            # Put the real method back so later users of this connection are unaffected.
+            connection.scheduler_info = original_scheduler_info
+
+    def test_dask_backend_rejects_threaded_workers(self):
+        """
+        Check that DaskBackend rejects threaded workers.
+        """
+        from dask.distributed import Client, LocalCluster
+        from ROOT._distrdf.Backends.Dask import Backend
+
+        expected_message = "DistRDF with Dask does not support threaded workers"
+
+        # Context managers close the client first and then the cluster, and
+        # the cluster is still closed if creating the client itself fails.
+        with LocalCluster(n_workers=1, threads_per_worker=2, processes=False) as cluster:
+            with Client(cluster) as client:
+                with pytest.raises(RuntimeError, match=expected_message):
+                    Backend.DaskBackend(daskclient=client)
+
+
 
 class TestInitialization:
     """Check initialization method in the Dask backend"""