77
88from dask .utils import parse_timedelta
99
10- from distributed .semaphore import Semaphore
10+ from distributed .lock import Lock
1111from distributed .utils import SyncMethodMixin , TimeoutError , log_errors , wait_for
1212from distributed .worker import get_client
1313
1717class ConditionExtension :
1818 """Scheduler extension for managing Condition variable notifications
1919
20- This extension only handles wait/notify coordination .
21- The underlying lock is a Semaphore managed by SemaphoreExtension .
20+ Coordinates wait/notify between distributed clients .
21+ The lock itself is managed by LockExtension .
2222 """
2323
2424 def __init__ (self , scheduler ):
@@ -36,16 +36,13 @@ def __init__(self, scheduler):
3636
3737 @log_errors
3838 async def wait (self , name = None , id = None , timeout = None ):
39- """Wait to be notified
39+ """Register waiter and block until notified
4040
41- Caller must already hold the lock (Semaphore lease).
42- This only manages the wait/notify Events.
41+ Caller must have released the lock before calling this.
4342 """
44- # Create event for this waiter
4543 event = asyncio .Event ()
4644 self ._waiters [name ][id ] = event
4745
48- # Wait on event
4946 future = event .wait ()
5047 if timeout is not None :
5148 future = wait_for (future , timeout )
@@ -56,7 +53,6 @@ async def wait(self, name=None, id=None, timeout=None):
5653 except TimeoutError :
5754 result = False
5855 finally :
59- # Cleanup waiter
6056 self ._waiters [name ].pop (id , None )
6157 if not self ._waiters [name ]:
6258 del self ._waiters [name ]
@@ -65,7 +61,7 @@ async def wait(self, name=None, id=None, timeout=None):
6561
6662 @log_errors
6763 def notify (self , name = None , n = 1 ):
68- """Notify n waiters"""
64+ """Wake up n waiters"""
6965 waiters = self ._waiters .get (name , {})
7066 count = 0
7167 for event in list (waiters .values ())[:n ]:
@@ -75,7 +71,7 @@ def notify(self, name=None, n=1):
7571
7672 @log_errors
7773 def notify_all (self , name = None ):
78- """Notify all waiters"""
74+ """Wake up all waiters"""
7975 waiters = self ._waiters .get (name , {})
8076 for event in waiters .values ():
8177 event .set ()
@@ -85,33 +81,36 @@ def notify_all(self, name=None):
8581class Condition (SyncMethodMixin ):
8682 """Distributed Condition Variable
8783
88- Combines a Semaphore (lock) with wait/notify coordination.
84+ Combines a Lock with wait/notify coordination across the cluster .
8985
9086 Parameters
9187 ----------
9288 name : str, optional
93- Name of the condition. Same name = shared state.
89+ Name of the condition. Conditions with the same name share state.
9490 client : Client, optional
9591 Client for scheduler communication.
9692
9793 Examples
9894 --------
99- >>> from distributed import Condition
100- >>> condition = Condition('my-condition')
95+ Producer-consumer pattern:
96+
97+ >>> condition = Condition('data-ready')
98+ >>> # Consumer
10199 >>> async with condition:
102- ... await condition.wait()
100+ ... while not data_available():
101+ ... await condition.wait()
102+ ... process_data()
103103
104- >>> # In another worker/client
105- >>> condition = Condition('my-condition')
104+ >>> # Producer
106105 >>> async with condition:
107- ... condition.notify()
106+ ... produce_data()
107+ ... condition.notify_all()
108108 """
109109
110110 def __init__ (self , name = None , client = None ):
111111 self .name = name or f"condition-{ uuid .uuid4 ().hex } "
112112 self .id = uuid .uuid4 ().hex
113- # Use Semaphore(max_leases=1) as the underlying lock
114- self ._lock = Semaphore (max_leases = 1 , name = f"{ self .name } -lock" )
113+ self ._lock = Lock (name = f"{ self .name } -lock" )
115114 self ._client = client
116115
117116 @property
@@ -136,70 +135,92 @@ def _verify_running(self):
136135 )
137136
138137 async def acquire (self ):
139- """Acquire underlying lock"""
140- result = await self ._lock .acquire ()
141- return result
138+ """Acquire the underlying lock"""
139+ return await self ._lock .acquire ()
142140
143141 async def release (self ):
144- """Release underlying lock"""
142+ """Release the underlying lock"""
145143 await self ._lock .release ()
146144
147145 async def wait (self , timeout = None ):
148146 """Wait until notified
149147
150- Must be called while lock is held. Releases lock and waits
151- for notify(), then reacquires lock before returning.
148+ Must be called while lock is held. Atomically releases lock,
149+ waits for notify(), then reacquires lock before returning.
152150
153151 Parameters
154152 ----------
155153 timeout : number or string or timedelta, optional
156- Seconds to wait on the condition in the scheduler .
154+ Maximum time to wait for notification .
157155
158156 Returns
159157 -------
160158 bool
161159 True if notified, False if timeout occurred
160+
161+ Raises
162+ ------
163+ RuntimeError
164+ If called without holding the lock
162165 """
163166 if not self ._lock .locked ():
164167 raise RuntimeError ("wait() called without holding the lock" )
165168
166169 self ._verify_running ()
167170 timeout = parse_timedelta (timeout )
168171
169- # Release lock
172+ # Atomically: release lock, wait for notify, reacquire lock
170173 await self ._lock .release ()
171-
172- # Wait for notification
173174 try :
174175 result = await self .client .scheduler .condition_wait (
175176 name = self .name , id = self .id , timeout = timeout
176177 )
177178 finally :
178- # Reacquire lock
179179 await self ._lock .acquire ()
180180
181181 return result
182182
183183 def notify (self , n = 1 ):
184- """Wake up one or more waiters"""
184+ """Wake up one or more waiters
185+
186+ Must be called while holding the lock.
187+
188+ Parameters
189+ ----------
190+ n : int, optional
191+ Number of waiters to wake. Default is 1.
192+
193+ Returns
194+ -------
195+ int
196+ Number of waiters actually notified
197+ """
185198 if not self ._lock .locked ():
186- raise RuntimeError ("Cannot notify without holding the lock" )
199+ raise RuntimeError ("notify() called without holding the lock" )
187200 self ._verify_running ()
188201 return self .client .sync (
189202 self .client .scheduler .condition_notify , name = self .name , n = n
190203 )
191204
192205 def notify_all (self ):
193- """Wake up all waiters"""
206+ """Wake up all waiters
207+
208+ Must be called while holding the lock.
209+
210+ Returns
211+ -------
212+ int
213+ Number of waiters notified
214+ """
194215 if not self ._lock .locked ():
195- raise RuntimeError ("Cannot notify without holding the lock" )
216+ raise RuntimeError ("notify_all() called without holding the lock" )
196217 self ._verify_running ()
197218 return self .client .sync (
198219 self .client .scheduler .condition_notify_all , name = self .name
199220 )
200221
201222 def locked (self ):
202- """Return True if lock is held"""
223+ """Return True if the lock is currently held"""
203224 return self ._lock .locked ()
204225
205226 async def __aenter__ (self ):
0 commit comments