|
28 | 28 | from dataclasses import dataclass |
29 | 29 | from functools import partial |
30 | 30 | from operator import add |
31 | | -from threading import Semaphore |
32 | 31 | from time import sleep |
33 | 32 | from typing import Any |
34 | 33 | from unittest import mock |
|
37 | 36 | import pytest |
38 | 37 | import yaml |
39 | 38 | from packaging.version import parse as parse_version |
40 | | -from tlz import concat, first, identity, isdistinct, merge, pluck, valmap |
| 39 | +from tlz import concat, identity, isdistinct, merge, pluck, valmap |
41 | 40 |
|
42 | 41 | import dask |
43 | 42 | import dask.bag as db |
@@ -3849,7 +3848,7 @@ def test_open_close_many_workers(loop, worker, count, repeat): |
3849 | 3848 | with cluster(nworkers=0, active_rpc_timeout=2) as (s, _): |
3850 | 3849 | gc.collect() |
3851 | 3850 | before = proc.num_fds() |
3852 | | - done = Semaphore(0) |
| 3851 | + done = threading.Semaphore(0) |
3853 | 3852 | running = weakref.WeakKeyDictionary() |
3854 | 3853 | workers = set() |
3855 | 3854 | status = True |
@@ -5409,63 +5408,43 @@ def test_quiet_quit_when_cluster_leaves(loop_in_thread): |
5409 | 5408 | assert not text |
5410 | 5409 |
|
5411 | 5410 |
|
5412 | | -@gen_cluster([("127.0.0.1", 4)] * 2, client=True) |
5413 | | -async def test_call_stack_future(c, s, a, b): |
5414 | | - x = c.submit(slowdec, 1, delay=0.5) |
5415 | | - future = c.submit(slowinc, 1, delay=0.5) |
5416 | | - await asyncio.sleep(0.1) |
5417 | | - results = await asyncio.gather( |
5418 | | - c.call_stack(future), c.call_stack(keys=[future.key]) |
5419 | | - ) |
5420 | | - assert all(list(first(result.values())) == [future.key] for result in results) |
5421 | | - assert results[0] == results[1] |
5422 | | - result = results[0] |
5423 | | - ts = a.state.tasks.get(future.key) |
5424 | | - if ts is not None and ts.state == "executing": |
5425 | | - w = a |
5426 | | - else: |
5427 | | - w = b |
| 5411 | +@gen_cluster(client=True) |
| 5412 | +async def test_call_stack(c, s, a, b): |
| 5413 | + e1, e2, e3, ew = Event(), Event(), Event(), Event() |
5428 | 5414 |
|
5429 | | - assert list(result) == [w.address] |
5430 | | - assert list(result[w.address]) == [future.key] |
5431 | | - assert "slowinc" in str(result) |
5432 | | - assert "slowdec" not in str(result) |
| 5415 | + def f(es: Event, ew: Event) -> None: |
| 5416 | + es.set() |
| 5417 | + ew.wait() |
5433 | 5418 |
|
| 5419 | + f1 = c.submit(f, e1, ew, key="f1", workers=[a.address]) |
| 5420 | + f2 = c.submit(f, e2, ew, key="f2", workers=[b.address]) |
| 5421 | + d3 = c.persist(delayed(f)(e3, ew, dask_key_name="d3"), workers=[b.address]) |
| 5422 | + await e1.wait() |
| 5423 | + await e2.wait() |
| 5424 | + await e3.wait() |
5434 | 5425 |
|
5435 | | -@gen_cluster([("127.0.0.1", 4)] * 2, client=True) |
5436 | | -async def test_call_stack_all(c, s, a, b): |
5437 | | - future = c.submit(slowinc, 1, delay=0.8) |
5438 | | - while not a.state.executing_count and not b.state.executing_count: |
5439 | | - await asyncio.sleep(0.01) |
5440 | | - result = await c.call_stack() |
5441 | | - w = a if a.state.executing_count else b |
5442 | | - assert list(result) == [w.address] |
5443 | | - assert list(result[w.address]) == [future.key] |
5444 | | - assert "slowinc" in str(result) |
| 5426 | + # Test future or keys |
| 5427 | + r1a = await c.call_stack(f1) |
| 5428 | + r1b = await c.call_stack([f1]) |
| 5429 | + r1c = await c.call_stack(keys=[f1.key]) |
5445 | 5430 |
|
| 5431 | + assert r1a == r1b == r1c |
| 5432 | + assert r1a.keys() == {a.address} |
| 5433 | + assert r1a[a.address].keys() == {"f1"} |
| 5434 | + assert any("event.py" in frame for frame in r1a[a.address]["f1"]) |
5446 | 5435 |
|
5447 | | -@gen_cluster([("127.0.0.1", 4)] * 2, client=True) |
5448 | | -async def test_call_stack_collections(c, s, a, b): |
5449 | | - pytest.importorskip("numpy") |
5450 | | - da = pytest.importorskip("dask.array") |
5451 | | - |
5452 | | - x = c.persist(da.random.random(100, chunks=(10,)).map_blocks(slowinc, delay=0.5)) |
5453 | | - while not a.state.executing_count and not b.state.executing_count: |
5454 | | - await asyncio.sleep(0.001) |
5455 | | - result = await c.call_stack(x) |
5456 | | - assert result |
| 5436 | + # test collection |
| 5437 | + r3 = await c.call_stack(d3) |
| 5438 | + assert r3.keys() == {b.address} |
| 5439 | + assert r3[b.address].keys() == {"d3"} |
5457 | 5440 |
|
| 5441 | + # test all |
| 5442 | + r4 = await c.call_stack() |
| 5443 | + assert r4.keys() == {a.address, b.address} |
| 5444 | + assert r4[a.address].keys() == {"f1"} |
| 5445 | + assert r4[b.address].keys() == {"f2", "d3"} |
5458 | 5446 |
|
5459 | | -@gen_cluster([("127.0.0.1", 4)] * 2, client=True) |
5460 | | -async def test_call_stack_collections_all(c, s, a, b): |
5461 | | - pytest.importorskip("numpy") |
5462 | | - da = pytest.importorskip("dask.array") |
5463 | | - |
5464 | | - x = c.persist(da.random.random(100, chunks=(10,)).map_blocks(slowinc, delay=0.5)) |
5465 | | - while not a.state.executing_count and not b.state.executing_count: |
5466 | | - await asyncio.sleep(0.001) |
5467 | | - result = await c.call_stack() |
5468 | | - assert result |
| 5447 | + await ew.set() |
5469 | 5448 |
|
5470 | 5449 |
|
5471 | 5450 | @pytest.mark.skipif(sys.version_info.minor == 11, reason="Profiler disabled") |
|
0 commit comments