Skip to content

Commit d2e2af9

Browse files
committed
Update condition.py, test_condition.py
1 parent d8f98d5 commit d2e2af9

2 files changed

Lines changed: 62 additions & 42 deletions

File tree

distributed/condition.py

Lines changed: 58 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77

88
from dask.utils import parse_timedelta
99

10-
from distributed.semaphore import Semaphore
10+
from distributed.lock import Lock
1111
from distributed.utils import SyncMethodMixin, TimeoutError, log_errors, wait_for
1212
from distributed.worker import get_client
1313

@@ -17,8 +17,8 @@
1717
class ConditionExtension:
1818
"""Scheduler extension for managing Condition variable notifications
1919
20-
This extension only handles wait/notify coordination.
21-
The underlying lock is a Semaphore managed by SemaphoreExtension.
20+
Coordinates wait/notify between distributed clients.
21+
The lock itself is managed by LockExtension.
2222
"""
2323

2424
def __init__(self, scheduler):
@@ -36,16 +36,13 @@ def __init__(self, scheduler):
3636

3737
@log_errors
3838
async def wait(self, name=None, id=None, timeout=None):
39-
"""Wait to be notified
39+
"""Register waiter and block until notified
4040
41-
Caller must already hold the lock (Semaphore lease).
42-
This only manages the wait/notify Events.
41+
Caller must have released the lock before calling this.
4342
"""
44-
# Create event for this waiter
4543
event = asyncio.Event()
4644
self._waiters[name][id] = event
4745

48-
# Wait on event
4946
future = event.wait()
5047
if timeout is not None:
5148
future = wait_for(future, timeout)
@@ -56,7 +53,6 @@ async def wait(self, name=None, id=None, timeout=None):
5653
except TimeoutError:
5754
result = False
5855
finally:
59-
# Cleanup waiter
6056
self._waiters[name].pop(id, None)
6157
if not self._waiters[name]:
6258
del self._waiters[name]
@@ -65,7 +61,7 @@ async def wait(self, name=None, id=None, timeout=None):
6561

6662
@log_errors
6763
def notify(self, name=None, n=1):
68-
"""Notify n waiters"""
64+
"""Wake up n waiters"""
6965
waiters = self._waiters.get(name, {})
7066
count = 0
7167
for event in list(waiters.values())[:n]:
@@ -75,7 +71,7 @@ def notify(self, name=None, n=1):
7571

7672
@log_errors
7773
def notify_all(self, name=None):
78-
"""Notify all waiters"""
74+
"""Wake up all waiters"""
7975
waiters = self._waiters.get(name, {})
8076
for event in waiters.values():
8177
event.set()
@@ -85,33 +81,36 @@ def notify_all(self, name=None):
8581
class Condition(SyncMethodMixin):
8682
"""Distributed Condition Variable
8783
88-
Combines a Semaphore (lock) with wait/notify coordination.
84+
Combines a Lock with wait/notify coordination across the cluster.
8985
9086
Parameters
9187
----------
9288
name : str, optional
93-
Name of the condition. Same name = shared state.
89+
Name of the condition. Conditions with the same name share state.
9490
client : Client, optional
9591
Client for scheduler communication.
9692
9793
Examples
9894
--------
99-
>>> from distributed import Condition
100-
>>> condition = Condition('my-condition')
95+
Producer-consumer pattern:
96+
97+
>>> condition = Condition('data-ready')
98+
>>> # Consumer
10199
>>> async with condition:
102-
... await condition.wait()
100+
... while not data_available():
101+
... await condition.wait()
102+
... process_data()
103103
104-
>>> # In another worker/client
105-
>>> condition = Condition('my-condition')
104+
>>> # Producer
106105
>>> async with condition:
107-
... condition.notify()
106+
... produce_data()
107+
... condition.notify_all()
108108
"""
109109

110110
def __init__(self, name=None, client=None):
111111
self.name = name or f"condition-{uuid.uuid4().hex}"
112112
self.id = uuid.uuid4().hex
113-
# Use Semaphore(max_leases=1) as the underlying lock
114-
self._lock = Semaphore(max_leases=1, name=f"{self.name}-lock")
113+
self._lock = Lock(name=f"{self.name}-lock")
115114
self._client = client
116115

117116
@property
@@ -136,70 +135,92 @@ def _verify_running(self):
136135
)
137136

138137
async def acquire(self):
139-
"""Acquire underlying lock"""
140-
result = await self._lock.acquire()
141-
return result
138+
"""Acquire the underlying lock"""
139+
return await self._lock.acquire()
142140

143141
async def release(self):
144-
"""Release underlying lock"""
142+
"""Release the underlying lock"""
145143
await self._lock.release()
146144

147145
async def wait(self, timeout=None):
148146
"""Wait until notified
149147
150-
Must be called while lock is held. Releases lock and waits
151-
for notify(), then reacquires lock before returning.
148+
Must be called while lock is held. Atomically releases lock,
149+
waits for notify(), then reacquires lock before returning.
152150
153151
Parameters
154152
----------
155153
timeout : number or string or timedelta, optional
156-
Seconds to wait on the condition in the scheduler.
154+
Maximum time to wait for notification.
157155
158156
Returns
159157
-------
160158
bool
161159
True if notified, False if timeout occurred
160+
161+
Raises
162+
------
163+
RuntimeError
164+
If called without holding the lock
162165
"""
163166
if not self._lock.locked():
164167
raise RuntimeError("wait() called without holding the lock")
165168

166169
self._verify_running()
167170
timeout = parse_timedelta(timeout)
168171

169-
# Release lock
172+
# Atomically: release lock, wait for notify, reacquire lock
170173
await self._lock.release()
171-
172-
# Wait for notification
173174
try:
174175
result = await self.client.scheduler.condition_wait(
175176
name=self.name, id=self.id, timeout=timeout
176177
)
177178
finally:
178-
# Reacquire lock
179179
await self._lock.acquire()
180180

181181
return result
182182

183183
def notify(self, n=1):
184-
"""Wake up one or more waiters"""
184+
"""Wake up one or more waiters
185+
186+
Must be called while holding the lock.
187+
188+
Parameters
189+
----------
190+
n : int, optional
191+
Number of waiters to wake. Default is 1.
192+
193+
Returns
194+
-------
195+
int
196+
Number of waiters actually notified
197+
"""
185198
if not self._lock.locked():
186-
raise RuntimeError("Cannot notify without holding the lock")
199+
raise RuntimeError("notify() called without holding the lock")
187200
self._verify_running()
188201
return self.client.sync(
189202
self.client.scheduler.condition_notify, name=self.name, n=n
190203
)
191204

192205
def notify_all(self):
193-
"""Wake up all waiters"""
206+
"""Wake up all waiters
207+
208+
Must be called while holding the lock.
209+
210+
Returns
211+
-------
212+
int
213+
Number of waiters notified
214+
"""
194215
if not self._lock.locked():
195-
raise RuntimeError("Cannot notify without holding the lock")
216+
raise RuntimeError("notify_all() called without holding the lock")
196217
self._verify_running()
197218
return self.client.sync(
198219
self.client.scheduler.condition_notify_all, name=self.name
199220
)
200221

201222
def locked(self):
202-
"""Return True if lock is held"""
223+
"""Return True if the lock is currently held"""
203224
return self._lock.locked()
204225

205226
async def __aenter__(self):

distributed/tests/test_condition.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -150,7 +150,7 @@ async def test_condition_error_release_without_acquire(c, s, a, b):
150150
"""Test error when releasing without acquiring"""
151151
condition = Condition("test-release-error")
152152

153-
with pytest.raises(RuntimeError, match="Cannot release"):
153+
with pytest.raises(RuntimeError, match="Released too often"):
154154
await condition.release()
155155

156156

@@ -286,17 +286,16 @@ async def test_condition_cleanup(c, s, a, b):
286286
"""Test that condition state is cleaned up after use"""
287287
condition = Condition("cleanup-test")
288288

289-
# Check initial state
290-
assert "cleanup-test" not in s.extensions["conditions"]._lock_holders
289+
# Check initial state - only check waiters since locks are managed by LockExtension
291290
assert "cleanup-test" not in s.extensions["conditions"]._waiters
292291

293292
# Use condition
294293
async with condition:
295294
condition.notify()
296295

297-
# State should be cleaned up
296+
# Waiter state should be cleaned up
298297
await asyncio.sleep(0.1)
299-
assert "cleanup-test" not in s.extensions["conditions"]._lock_holders
298+
assert "cleanup-test" not in s.extensions["conditions"]._waiters
300299

301300

302301
@gen_cluster(client=True)

0 commit comments

Comments
 (0)