Skip to content

Commit d7ff40d

Browse files
committed
Add cross-thread suscription handling
1 parent d98da5d commit d7ff40d

2 files changed

Lines changed: 12 additions & 10 deletions

File tree

aidbox_python_sdk/sdk.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111

1212
logger = logging.getLogger("aidbox_sdk")
1313

14+
# (target_loop, future, counter) per entity for was_subscription_triggered_*
15+
_SubTriggered = dict[str, tuple[asyncio.AbstractEventLoop, asyncio.Future[bool], int]]
16+
1417

1518
class SDK:
1619
def __init__( # noqa: PLR0913
@@ -42,7 +45,7 @@ def __init__( # noqa: PLR0913
4245
self._seeds = seeds or {}
4346
self._migrations = migrations or []
4447
self._app_endpoint_name = f"{settings.APP_ID}-endpoint"
45-
self._sub_triggered = {}
48+
self._sub_triggered: _SubTriggered = {}
4649
self._test_start_txid = None
4750

4851
async def apply_migrations(self, client: AsyncAidboxClient):
@@ -112,14 +115,14 @@ async def handler(event, request):
112115
result = coro_or_result
113116

114117
if entity in self._sub_triggered:
115-
future, counter = self._sub_triggered[entity]
118+
target_loop, future, counter = self._sub_triggered[entity]
116119
if counter > 1:
117-
self._sub_triggered[entity] = (future, counter - 1)
120+
self._sub_triggered[entity] = (target_loop, future, counter - 1)
118121
elif future.done():
119122
pass
120123
# logger.warning('Uncaught subscription for %s', entity)
121124
else:
122-
future.set_result(True)
125+
target_loop.call_soon_threadsafe(future.set_result, True)
123126

124127
return result
125128

@@ -133,14 +136,13 @@ def get_subscription_handler(self, path):
133136

134137
def was_subscription_triggered_n_times(self, entity, counter):
135138
timeout = 10
136-
137139
future = asyncio.Future()
138-
self._sub_triggered[entity] = (future, counter)
139-
asyncio.get_event_loop().call_later(
140+
target_loop = asyncio.get_running_loop()
141+
self._sub_triggered[entity] = (target_loop, future, counter)
142+
target_loop.call_later(
140143
timeout,
141144
lambda: None if future.done() else future.set_exception(Exception()),
142145
)
143-
144146
return future
145147

146148
def was_subscription_triggered(self, entity):

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ pre-commit = ["autohooks.plugins.black", "autohooks.plugins.ruff"]
1212

1313
[tool.black]
1414
line-length = 100
15-
target-version = ['py311']
15+
target-version = ['py39']
1616
exclude = '''
1717
(
1818
/(
@@ -25,7 +25,7 @@ exclude = '''
2525
'''
2626

2727
[tool.ruff]
28-
target-version = "py311"
28+
target-version = "py39"
2929
line-length = 100
3030
extend-exclude = ["example"]
3131

0 commit comments

Comments
 (0)