2121_logger = logging .getLogger (__name__ )
2222
2323
24- class Broadcast (Generic [ChannelMessageT ]):
24+ class Broadcast ( # pylint: disable=too-many-instance-attributes
25+ Generic [ChannelMessageT ]
26+ ):
2527 """A channel that deliver all messages to all receivers.
2628
2729 # Description
@@ -184,7 +186,13 @@ async def main() -> None:
184186 ```
185187 """
186188
187- def __init__ (self , * , name : str , resend_latest : bool = False ) -> None :
189+ def __init__ (
190+ self ,
191+ * ,
192+ name : str ,
193+ resend_latest : bool = False ,
194+ auto_close : bool = False ,
195+ ) -> None :
188196 """Initialize this channel.
189197
190198 Args:
@@ -197,6 +205,8 @@ def __init__(self, *, name: str, resend_latest: bool = False) -> None:
197205 wait for the next message on the channel to arrive. It is safe to be
198206 set in data/reporting channels, but is not recommended for use in
199207 channels that stream control instructions.
208+ auto_close: If True, the channel will be closed when all senders or all
209+ receivers are closed.
200210 """
201211 self ._name : str = name
202212 """The name of the broadcast channel.
@@ -221,6 +231,9 @@ def __init__(self, *, name: str, resend_latest: bool = False) -> None:
221231 self ._latest : ChannelMessageT | None = None
222232 """The latest message sent to the channel."""
223233
234+ self ._auto_close_enabled : bool = auto_close
235+ """Whether to close the channel when all senders or all receivers are closed."""
236+
224237 self .resend_latest : bool = resend_latest
225238 """Whether to resend the latest message to new receivers.
226239
@@ -367,6 +380,10 @@ async def send(self, message: _T, /) -> None:
367380 raise SenderError ("The channel was closed" , self ) from ChannelClosedError (
368381 self ._channel
369382 )
383+ if self ._channel ._auto_close_enabled and len (self ._channel ._receivers ) == 0 :
384+ raise SenderError ("The channel was closed" , self ) from ChannelClosedError (
385+ self ._channel
386+ )
370387 self ._channel ._latest = message
371388 stale_refs = []
372389 for _hash , recv_ref in self ._channel ._receivers .items ():
@@ -394,6 +411,12 @@ async def aclose(self) -> None:
394411 self ._closed = True
395412 self ._channel ._sender_count -= 1
396413
414+ if (
415+ self ._channel ._sender_count == 0 # pylint: disable=protected-access
416+ and self ._channel ._auto_close_enabled # pylint: disable=protected-access
417+ ):
418+ await self ._channel .aclose ()
419+
397420 def __del__ (self ) -> None :
398421 """Clean up this sender."""
399422 if not self ._closed :
@@ -527,6 +550,11 @@ async def ready(self) -> bool:
527550 while len (self ._q ) == 0 :
528551 if self ._channel ._closed or self ._closed :
529552 return False
553+ if self ._channel ._auto_close_enabled and (
554+ self ._channel ._sender_count == 0 or len (self ._channel ._receivers ) == 0
555+ ):
556+ await self ._channel .aclose ()
557+ return False
530558 async with self ._channel ._recv_cv :
531559 await self ._channel ._recv_cv .wait ()
532560 return True
0 commit comments