Skip to content

Commit c12638f

Browse files
authored
Fix flaky test_call_stack_future (#9277)
1 parent 6f721b8 commit c12638f

1 file changed

Lines changed: 32 additions & 53 deletions

File tree

distributed/tests/test_client.py

Lines changed: 32 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from dataclasses import dataclass
2929
from functools import partial
3030
from operator import add
31-
from threading import Semaphore
3231
from time import sleep
3332
from typing import Any
3433
from unittest import mock
@@ -37,7 +36,7 @@
3736
import pytest
3837
import yaml
3938
from packaging.version import parse as parse_version
40-
from tlz import concat, first, identity, isdistinct, merge, pluck, valmap
39+
from tlz import concat, identity, isdistinct, merge, pluck, valmap
4140

4241
import dask
4342
import dask.bag as db
@@ -3849,7 +3848,7 @@ def test_open_close_many_workers(loop, worker, count, repeat):
38493848
with cluster(nworkers=0, active_rpc_timeout=2) as (s, _):
38503849
gc.collect()
38513850
before = proc.num_fds()
3852-
done = Semaphore(0)
3851+
done = threading.Semaphore(0)
38533852
running = weakref.WeakKeyDictionary()
38543853
workers = set()
38553854
status = True
@@ -5409,63 +5408,43 @@ def test_quiet_quit_when_cluster_leaves(loop_in_thread):
54095408
assert not text
54105409

54115410

5412-
@gen_cluster([("127.0.0.1", 4)] * 2, client=True)
5413-
async def test_call_stack_future(c, s, a, b):
5414-
x = c.submit(slowdec, 1, delay=0.5)
5415-
future = c.submit(slowinc, 1, delay=0.5)
5416-
await asyncio.sleep(0.1)
5417-
results = await asyncio.gather(
5418-
c.call_stack(future), c.call_stack(keys=[future.key])
5419-
)
5420-
assert all(list(first(result.values())) == [future.key] for result in results)
5421-
assert results[0] == results[1]
5422-
result = results[0]
5423-
ts = a.state.tasks.get(future.key)
5424-
if ts is not None and ts.state == "executing":
5425-
w = a
5426-
else:
5427-
w = b
5411+
@gen_cluster(client=True)
5412+
async def test_call_stack(c, s, a, b):
5413+
e1, e2, e3, ew = Event(), Event(), Event(), Event()
54285414

5429-
assert list(result) == [w.address]
5430-
assert list(result[w.address]) == [future.key]
5431-
assert "slowinc" in str(result)
5432-
assert "slowdec" not in str(result)
5415+
def f(es: Event, ew: Event) -> None:
5416+
es.set()
5417+
ew.wait()
54335418

5419+
f1 = c.submit(f, e1, ew, key="f1", workers=[a.address])
5420+
f2 = c.submit(f, e2, ew, key="f2", workers=[b.address])
5421+
d3 = c.persist(delayed(f)(e3, ew, dask_key_name="d3"), workers=[b.address])
5422+
await e1.wait()
5423+
await e2.wait()
5424+
await e3.wait()
54345425

5435-
@gen_cluster([("127.0.0.1", 4)] * 2, client=True)
5436-
async def test_call_stack_all(c, s, a, b):
5437-
future = c.submit(slowinc, 1, delay=0.8)
5438-
while not a.state.executing_count and not b.state.executing_count:
5439-
await asyncio.sleep(0.01)
5440-
result = await c.call_stack()
5441-
w = a if a.state.executing_count else b
5442-
assert list(result) == [w.address]
5443-
assert list(result[w.address]) == [future.key]
5444-
assert "slowinc" in str(result)
5426+
# Test future or keys
5427+
r1a = await c.call_stack(f1)
5428+
r1b = await c.call_stack([f1])
5429+
r1c = await c.call_stack(keys=[f1.key])
54455430

5431+
assert r1a == r1b == r1c
5432+
assert r1a.keys() == {a.address}
5433+
assert r1a[a.address].keys() == {"f1"}
5434+
assert any("event.py" in frame for frame in r1a[a.address]["f1"])
54465435

5447-
@gen_cluster([("127.0.0.1", 4)] * 2, client=True)
5448-
async def test_call_stack_collections(c, s, a, b):
5449-
pytest.importorskip("numpy")
5450-
da = pytest.importorskip("dask.array")
5451-
5452-
x = c.persist(da.random.random(100, chunks=(10,)).map_blocks(slowinc, delay=0.5))
5453-
while not a.state.executing_count and not b.state.executing_count:
5454-
await asyncio.sleep(0.001)
5455-
result = await c.call_stack(x)
5456-
assert result
5436+
# test collection
5437+
r3 = await c.call_stack(d3)
5438+
assert r3.keys() == {b.address}
5439+
assert r3[b.address].keys() == {"d3"}
54575440

5441+
# test all
5442+
r4 = await c.call_stack()
5443+
assert r4.keys() == {a.address, b.address}
5444+
assert r4[a.address].keys() == {"f1"}
5445+
assert r4[b.address].keys() == {"f2", "d3"}
54585446

5459-
@gen_cluster([("127.0.0.1", 4)] * 2, client=True)
5460-
async def test_call_stack_collections_all(c, s, a, b):
5461-
pytest.importorskip("numpy")
5462-
da = pytest.importorskip("dask.array")
5463-
5464-
x = c.persist(da.random.random(100, chunks=(10,)).map_blocks(slowinc, delay=0.5))
5465-
while not a.state.executing_count and not b.state.executing_count:
5466-
await asyncio.sleep(0.001)
5467-
result = await c.call_stack()
5468-
assert result
5447+
await ew.set()
54695448

54705449

54715450
@pytest.mark.skipif(sys.version_info.minor == 11, reason="Profiler disabled")

0 commit comments

Comments
 (0)