2121_logger = logging .getLogger (__name__ )
2222
2323
24- class Broadcast (Generic [ChannelMessageT ]):
24+ class Broadcast ( # pylint: disable=too-many-instance-attributes
25+ Generic [ChannelMessageT ]
26+ ):
2527 """A channel that deliver all messages to all receivers.
2628
2729 # Description
@@ -184,7 +186,13 @@ async def main() -> None:
184186 ```
185187 """
186188
187- def __init__ (self , * , name : str , resend_latest : bool = False ) -> None :
189+ def __init__ (
190+ self ,
191+ * ,
192+ name : str ,
193+ resend_latest : bool = False ,
194+ auto_close : bool = False ,
195+ ) -> None :
188196 """Initialize this channel.
189197
190198 Args:
@@ -197,6 +205,8 @@ def __init__(self, *, name: str, resend_latest: bool = False) -> None:
197205 wait for the next message on the channel to arrive. It is safe to be
198206 set in data/reporting channels, but is not recommended for use in
199207 channels that stream control instructions.
208+ auto_close: If True, the channel will be closed when all senders or all
209+ receivers are closed.
200210 """
201211 self ._name : str = name
202212 """The name of the broadcast channel.
@@ -221,6 +231,9 @@ def __init__(self, *, name: str, resend_latest: bool = False) -> None:
221231 self ._latest : ChannelMessageT | None = None
222232 """The latest message sent to the channel."""
223233
234+ self ._auto_close : bool = auto_close
235+ """Whether to close the channel when all senders or all receivers are closed."""
236+
224237 self .resend_latest : bool = resend_latest
225238 """Whether to resend the latest message to new receivers.
226239
@@ -355,13 +368,20 @@ async def send(self, message: _T, /) -> None:
355368 set as the cause.
356369 SenderClosedError: If this sender was closed.
357370 """
358- if self ._closed :
359- raise SenderClosedError (self )
360371 # pylint: disable=protected-access
361372 if self ._channel ._closed :
362373 raise SenderError ("The channel was closed" , self ) from ChannelClosedError (
363374 self ._channel
364375 )
376+ if self ._channel ._auto_close and (
377+ self ._channel ._sender_count == 0 or len (self ._channel ._receivers ) == 0
378+ ):
379+ await self ._channel .aclose ()
380+ raise SenderError ("The channel was closed" , self ) from ChannelClosedError (
381+ self ._channel
382+ )
383+ if self ._closed :
384+ raise SenderClosedError (self )
365385 self ._channel ._latest = message
366386 stale_refs = []
367387 for _hash , recv_ref in self ._channel ._receivers .items ():
@@ -508,6 +528,11 @@ async def ready(self) -> bool:
508528 while len (self ._q ) == 0 :
509529 if self ._channel ._closed or self ._closed :
510530 return False
531+ if self ._channel ._auto_close and (
532+ self ._channel ._sender_count == 0 or len (self ._channel ._receivers ) == 0
533+ ):
534+ await self ._channel .aclose ()
535+ return False
511536 async with self ._channel ._recv_cv :
512537 await self ._channel ._recv_cv .wait ()
513538 return True
0 commit comments