Skip to content

Commit 14c59d2

Browse files
Cleanup live topics on disconnect or if OpenSpace crashes
1 parent bc9d6bd commit 14c59d2

2 files changed

Lines changed: 42 additions & 12 deletions

File tree

src/openspace/src/api.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,14 @@ class Api:
3939

4040
def __init__(self, address: str, port: int):
4141
self._callbacks: dict[int, Callable[[Any], None]] = {}
42+
self._topicCancelEvents: dict[int, asyncio.Event] = {}
4243
self._nextTopicId: int = 0
4344
self._userOnConnect: Callable[[], Coroutine[Any, Any, None]] | None = None
45+
self._userOnDisconnect: Callable[[], None] | None = None
4446

4547
socket = SocketWrapper(address, port)
46-
socket.onConnect(self._onConnect)
47-
socket.onDisconnect(lambda: None)
48+
socket.onConnect(self.__onConnect)
49+
socket.onDisconnect(self.__onDisconnect)
4850
socket.onMessage(self._handle_message)
4951

5052
self._socket = socket
@@ -59,13 +61,23 @@ def _handle_message(self, message: str) -> None:
5961
else:
6062
print(f"Error handling message: {messageObject}")
6163

62-
async def _onConnect(self):
64+
async def __onConnect(self):
6365
# Send API handshake before any user-registered onConnect
6466
self._socket.send(json.dumps(ApiVersion))
6567
# Call user defined onConnect if it exists
6668
if self._userOnConnect is not None:
6769
await self._userOnConnect()
6870

71+
def __onDisconnect(self) -> None:
72+
# Signal all live topic iterators to stop
73+
for cancelEvent in self._topicCancelEvents.values():
74+
cancelEvent.set()
75+
self._topicCancelEvents.clear()
76+
self._callbacks.clear()
77+
# Call user defined onDisconnect if it exists
78+
if self._userOnDisconnect is not None:
79+
self._userOnDisconnect()
80+
6981
def onConnect(self, callback: Callable[[], Coroutine[Any, Any, None]]) -> None:
7082
"""
7183
Set the async function to call when a connection is established.
@@ -75,8 +87,8 @@ def onConnect(self, callback: Callable[[], Coroutine[Any, Any, None]]) -> None:
7587
self._userOnConnect = callback
7688

7789
def onDisconnect(self, callback: Callable[[], None]):
78-
"""Set the function to execute when socket is dicsonnected."""
79-
self._socket.onDisconnect(callback)
90+
"""Set the function to execute when socket is disconnected."""
91+
self._userOnDisconnect = callback
8092

8193
async def connect(self):
8294
"""Connect to OpenSpace."""
@@ -114,22 +126,35 @@ def startTopic(self, type: str, payload: Any, cancelPayload: Any = None) -> Topi
114126
self._callbacks[topicId] = lambda payload: queue.put_nowait(payload)
115127

116128
cancelEvent = asyncio.Event()
129+
self._topicCancelEvents[topicId] = cancelEvent
117130

118131
async def iterator() -> AsyncGenerator[Any, None]:
119132
while not cancelEvent.is_set():
120133
try:
121-
# Poll the queue with a timeout to allow checking for cancellation
122-
# without blocking indefinitely on queue.get()
123-
value = await asyncio.wait_for(queue.get(), timeout=0.1)
124-
yield value
125-
except asyncio.TimeoutError:
126-
continue
134+
# Race the queue against both the cancel event so we don't block indefinitely
135+
# when the connection drops
136+
get = asyncio.ensure_future(queue.get())
137+
cancel_wait = asyncio.ensure_future(cancelEvent.wait())
138+
done, pending = await asyncio.wait(
139+
[get, cancel_wait],
140+
return_when=asyncio.FIRST_COMPLETED
141+
)
142+
# Clean up pending tasks to avoid leaks
143+
for task in pending:
144+
task.cancel()
145+
if cancelEvent.is_set():
146+
# Topic was cancelled, exit the iterator
147+
break
148+
# If the get completed successfully, we have a new value to yield
149+
if get in done and not get.cancelled():
150+
yield get.result()
127151
except Exception as e:
128152
print(f"Error in topic {topicId} iterator: {e}")
129153
print_exc()
130154
break
131-
# Topic has been canceled, remove callback
155+
# Topic has been canceled, remove callback and cancel event
132156
self._callbacks.pop(topicId, None)
157+
self._topicCancelEvents.pop(topicId, None)
133158

134159

135160
def talk(payload: Any) -> None:
@@ -144,6 +169,7 @@ def cancel () -> None:
144169
talk(cancelPayload)
145170
cancelEvent.set()
146171
self._callbacks.pop(topicId, None)
172+
self._topicCancelEvents.pop(topicId, None)
147173

148174
return Topic(iterator(), talk, cancel)
149175

src/openspace/src/socketwrapper.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ async def _handleReceive(self) -> None:
6969
print(f"Connection error: {e}")
7070
print_exc()
7171
break
72+
except Exception as e:
73+
print(f"Unexpected error: {type(e)}: {e}")
74+
print_exc()
75+
break
7276
self.disconnect()
7377

7478
async def connect(self) -> None:

0 commit comments

Comments
 (0)