|
1 | | -import contextlib |
2 | | -import sys |
3 | 1 | from collections.abc import AsyncGenerator, Awaitable, Callable |
4 | 2 | from functools import wraps |
5 | 3 | from types import TracebackType |
@@ -62,6 +60,26 @@ async def __anext__(self) -> Any: |
62 | 60 | return next_msg |
63 | 61 |
|
64 | 62 |
|
| 63 | +class SubscriptionCtxProxy(ObjectProxy): # type: ignore |
| 64 | + """Proxy object for subscription context manager.""" |
| 65 | + |
| 66 | + def __init__(self, wrapped: Any, tracer: Tracer) -> None: |
| 67 | + super().__init__(wrapped) |
| 68 | + self._self_tracer = tracer |
| 69 | + self._self_sub = None |
| 70 | + |
| 71 | + async def __aenter__(self) -> Any: |
| 72 | + sub = await self.__wrapped__.__aenter__() |
| 73 | + if isinstance(sub, IteratorSubscription): |
| 74 | + sub = IterableSubscriptionProxy(sub, self._self_tracer) |
| 75 | + self._self_sub = sub |
| 76 | + return sub |
| 77 | + |
| 78 | + def __aexit__(self, *args: Any, **kwargs: dict[Any, Any]) -> Any: |
| 79 | + if self._self_sub and isinstance(self._self_sub, IterableSubscriptionProxy): |
| 80 | + self._self_sub.__cancel_ctx__(*args, **kwargs) |
| 81 | + |
| 82 | + |
65 | 83 | class NatsCoreInstrumentator: |
66 | 84 | """Instrument core nats methods.""" |
67 | 85 |
|
@@ -127,7 +145,7 @@ def _publish_decorator( |
127 | 145 |
|
128 | 146 | wrap_function_wrapper("natsrpy._natsrpy_rs", "Nats.publish", _publish_decorator) |
129 | 147 |
|
130 | | - def _instrument_subscriptions(self) -> None: # noqa: C901 |
| 148 | + def _instrument_subscriptions(self) -> None: |
131 | 149 | """Create instrumentation for.""" |
132 | 150 |
|
133 | 151 | def callback_wrapper( |
@@ -170,27 +188,15 @@ def process_args( |
170 | 188 | callback = callback_wrapper(callback) |
171 | 189 | return (subject, callback, queue) |
172 | 190 |
|
173 | | - @contextlib.asynccontextmanager |
174 | | - async def wrapper( |
| 191 | + def wrapper( |
175 | 192 | wrapper: Any, |
176 | 193 | _: Nats, |
177 | 194 | args: tuple[Any, ...], |
178 | 195 | kwargs: dict[str, Any], |
179 | 196 | ) -> AsyncGenerator[Any, None]: |
180 | | - |
181 | | - async with wrapper(*process_args(*args, **kwargs)) as original_sub: |
182 | | - if isinstance(original_sub, IteratorSubscription): |
183 | | - ret = IterableSubscriptionProxy(original_sub, self.tracer) |
184 | | - else: |
185 | | - ret = original_sub |
186 | | - try: |
187 | | - yield ret |
188 | | - except BaseException: |
189 | | - if isinstance(ret, IterableSubscriptionProxy): |
190 | | - ret.__cancel_ctx__(*sys.exc_info()) |
191 | | - raise |
192 | | - finally: |
193 | | - if isinstance(ret, IterableSubscriptionProxy): |
194 | | - ret.__cancel_ctx__() |
| 197 | + return SubscriptionCtxProxy( |
| 198 | + wrapper(*process_args(*args, **kwargs)), |
| 199 | + self.tracer, |
| 200 | + ) |
195 | 201 |
|
196 | 202 | wrap_function_wrapper("natsrpy._natsrpy_rs", "Nats.subscribe", wrapper) |
0 commit comments