Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions bindings/distrdf/python/DistRDF/Backends/Dask/Backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,20 @@ def __init__(self, daskclient: Optional[Client] = None):
# N is the number of cores on the local machine.
self.client = (daskclient if daskclient is not None else
Client(LocalCluster(n_workers=os.cpu_count(), threads_per_worker=1, processes=True)))

workers = self.client.scheduler_info().get("workers", None)

if workers is None:
return

for worker in workers.values():
threads = worker.get("nthreads", 1)

if threads > 1:
raise RuntimeError(
"DistRDF with Dask does not support threaded workers. "
"Please use processes=True and threads_per_worker=1."
)

def optimize_npartitions(self) -> int:
"""
Expand Down
44 changes: 44 additions & 0 deletions roottest/python/distrdf/backends/check_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,50 @@ def test_optimize_npartitions(self, payload):
backend = Backend.SparkBackend(sparkcontext=connection)
assert backend.optimize_npartitions() == 2

def test_dask_backend_handles_missing_workers(self, payload):
    """
    Check that a DaskBackend can be built, and a computation still runs,
    when scheduler_info() returns a dictionary without a "workers" entry.
    """
    connection, backend_name = payload

    # Only the Dask flavour of the payload is exercised here.
    if backend_name == "dask":
        from ROOT._distrdf.Backends.Dask import Backend

        saved_scheduler_info = connection.scheduler_info

        try:
            # Stand-in that reports no information at all about workers.
            connection.scheduler_info = lambda: {}

            dask_backend = Backend.DaskBackend(daskclient=connection)
            assert dask_backend.client is connection

            # Count 10 entries using the patched connection as executor.
            dataframe = ROOT.RDataFrame(10, executor=connection)
            assert dataframe.Count().GetValue() == 10

        finally:
            # Put the real method back whatever the outcome above.
            connection.scheduler_info = saved_scheduler_info
Comment on lines +63 to +86

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an interesting test, but I would appreciate it if you also added some actual computation, so we can check that, on top of having a valid connection, RDataFrame still works with it.

Afterwards, there needs to be another test, namely one that checks the new RuntimeError introduced with this PR.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a distributed RDataFrame computation to the missing-workers test to verify that execution still works when scheduler_info() does not provide worker information. Also added a new test that checks the expected RuntimeError is raised when using threaded Dask workers.


def test_dask_backend_rejects_threaded_workers(self):
    """
    Check that DaskBackend rejects threaded workers.

    A LocalCluster with a single worker running two threads is started,
    and constructing a DaskBackend from a client connected to it is
    expected to raise a RuntimeError.
    """
    from dask.distributed import Client, LocalCluster
    from ROOT._distrdf.Backends.Dask import Backend

    # Context managers ensure the cluster is closed even if connecting the
    # client fails; on exit the client is closed first, then the cluster.
    with LocalCluster(n_workers=1, threads_per_worker=2, processes=False) as cluster:
        with Client(cluster) as client:
            with pytest.raises(RuntimeError, match="DistRDF with Dask does not support threaded workers"):
                Backend.DaskBackend(daskclient=client)



class TestInitialization:
"""Check initialization method in the Dask backend"""
Expand Down
Loading