@@ -39,12 +39,14 @@ class Api:
3939
4040 def __init__ (self , address : str , port : int ):
4141 self ._callbacks : dict [int , Callable [[Any ], None ]] = {}
42+ self ._topicCancelEvents : dict [int , asyncio .Event ] = {}
4243 self ._nextTopicId : int = 0
4344 self ._userOnConnect : Callable [[], Coroutine [Any , Any , None ]] | None = None
45+ self ._userOnDisconnect : Callable [[], None ] | None = None
4446
4547 socket = SocketWrapper (address , port )
46- socket .onConnect (self ._onConnect )
47- socket .onDisconnect (lambda : None )
48+ socket .onConnect (self .__onConnect )
49+ socket .onDisconnect (self . __onDisconnect )
4850 socket .onMessage (self ._handle_message )
4951
5052 self ._socket = socket
@@ -59,13 +61,23 @@ def _handle_message(self, message: str) -> None:
5961 else :
6062 print (f"Error handling message: { messageObject } " )
6163
62- async def _onConnect (self ):
64+ async def __onConnect (self ):
6365 # Send API handshake before any user-registered onConnect
6466 self ._socket .send (json .dumps (ApiVersion ))
6567 # Call user defined onConnect if it exists
6668 if self ._userOnConnect is not None :
6769 await self ._userOnConnect ()
6870
71+ def __onDisconnect (self ) -> None :
72+ # Signal all live topic iterators to stop
73+ for cancelEvent in self ._topicCancelEvents .values ():
74+ cancelEvent .set ()
75+ self ._topicCancelEvents .clear ()
76+ self ._callbacks .clear ()
77+ # Call user defined onDisconnect if it exists
78+ if self ._userOnDisconnect is not None :
79+ self ._userOnDisconnect ()
80+
6981 def onConnect (self , callback : Callable [[], Coroutine [Any , Any , None ]]) -> None :
7082 """
7183 Set the async function to call when a connection is established.
@@ -75,8 +87,8 @@ def onConnect(self, callback: Callable[[], Coroutine[Any, Any, None]]) -> None:
7587 self ._userOnConnect = callback
7688
7789 def onDisconnect (self , callback : Callable [[], None ]):
78- """Set the function to execute when socket is dicsonnected ."""
79- self ._socket . onDisconnect ( callback )
90+ """Set the function to execute when socket is disconnected ."""
91+ self ._userOnDisconnect = callback
8092
8193 async def connect (self ):
8294 """Connect to OpenSpace."""
@@ -114,22 +126,35 @@ def startTopic(self, type: str, payload: Any, cancelPayload: Any = None) -> Topi
114126 self ._callbacks [topicId ] = lambda payload : queue .put_nowait (payload )
115127
116128 cancelEvent = asyncio .Event ()
129+ self ._topicCancelEvents [topicId ] = cancelEvent
117130
118131 async def iterator () -> AsyncGenerator [Any , None ]:
119132 while not cancelEvent .is_set ():
120133 try :
121- # Poll the queue with a timeout to allow checking for cancellation
122- # without blocking indefinitely on queue.get()
123- value = await asyncio .wait_for (queue .get (), timeout = 0.1 )
124- yield value
125- except asyncio .TimeoutError :
126- continue
134+ # Race the queue against both the cancel event so we don't block indefinitely
135+ # when the connection drops
136+ get = asyncio .ensure_future (queue .get ())
137+ cancel_wait = asyncio .ensure_future (cancelEvent .wait ())
138+ done , pending = await asyncio .wait (
139+ [get , cancel_wait ],
140+ return_when = asyncio .FIRST_COMPLETED
141+ )
142+ # Clean up pending tasks to avoid leaks
143+ for task in pending :
144+ task .cancel ()
145+ if cancelEvent .is_set ():
146+ # Topic was cancelled, exit the iterator
147+ break
148+ # If the get completed successfully, we have a new value to yield
149+ if get in done and not get .cancelled ():
150+ yield get .result ()
127151 except Exception as e :
128152 print (f"Error in topic { topicId } iterator: { e } " )
129153 print_exc ()
130154 break
131- # Topic has been canceled, remove callback
155+ # Topic has been canceled, remove callback and cancel event
132156 self ._callbacks .pop (topicId , None )
157+ self ._topicCancelEvents .pop (topicId , None )
133158
134159
135160 def talk (payload : Any ) -> None :
@@ -144,6 +169,7 @@ def cancel () -> None:
144169 talk (cancelPayload )
145170 cancelEvent .set ()
146171 self ._callbacks .pop (topicId , None )
172+ self ._topicCancelEvents .pop (topicId , None )
147173
148174 return Topic (iterator (), talk , cancel )
149175
0 commit comments