Skip to content

Commit d0195ef

Browse files
Copilots3rius
andauthored
fix: wrap __aiter__ with async generator to prevent OTel context leak on break
Agent-Logs-Url: https://github.com/taskiq-python/natsrpy/sessions/12fa2646-b93c-455b-9b63-4a0a6a8d3570 Co-authored-by: s3rius <18153319+s3rius@users.noreply.github.com>
1 parent 426576c commit d0195ef

File tree

1 file changed

+62
-36
lines changed

1 file changed

+62
-36
lines changed

python/natsrpy/instrumentation/nats_core.py

Lines changed: 62 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Awaitable, Callable
1+
from collections.abc import AsyncIterator, Awaitable, Callable
22
from contextlib import AbstractContextManager
33
from contextvars import Token
44
from functools import wraps
@@ -14,6 +14,17 @@
1414
from .span_builder import SpanBuilder
1515

1616

17+
def _cleanup_otel_context(
18+
token: Token[Any] | None,
19+
span_manager: AbstractContextManager[Any] | None,
20+
) -> None:
21+
"""Detach the current OTel context and end the active span, if any."""
22+
if token:
23+
context.detach(token)
24+
if span_manager:
25+
span_manager.__exit__(None, None, None)
26+
27+
1728
class NatsCoreInstrumentator:
1829
"""Instrument core nats methods."""
1930

@@ -37,7 +48,7 @@ def instrument(self) -> None:
3748
def uninstrument() -> None:
3849
"""Remove instrumentaitons from core Nats."""
3950
unwrap(Nats, "publish")
40-
unwrap(IteratorSubscription, "__anext__")
51+
unwrap(IteratorSubscription, "__aiter__")
4152

4253
def _instrument_publish(self) -> None:
4354
def _wrapped_publish(
@@ -83,46 +94,61 @@ def _publish_decorator(
8394

8495
def _instrument_iter_subscription(self) -> None:
8596

86-
current_token: Token[Any] | None = None
87-
span_manager: AbstractContextManager[Any] | None = None
88-
89-
async def _custom_anext(
90-
wrapper: Callable[..., Any],
91-
_: Nats,
92-
args: tuple[Any, ...],
93-
kwargs: dict[str, Any],
94-
) -> Any:
95-
nonlocal current_token
96-
nonlocal span_manager
97+
async def _instrumented_iter(
98+
sub: IteratorSubscription,
99+
) -> AsyncIterator[Message]:
100+
"""Async generator wrapping an iterator subscription with OTel context.
97101
102+
Each ``async for`` loop gets its own generator instance with
103+
independent context state. The ``finally`` block guarantees
104+
cleanup when the loop exits — whether via ``break``, an
105+
exception, or normal ``StopAsyncIteration``.
106+
"""
107+
token: Token[Any] | None = None
108+
span_manager: AbstractContextManager[Any] | None = None
98109
try:
99-
msg = await wrapper(*args, **kwargs)
100-
# For handling StopAsyncIteration error
101-
# and possibly other exceptions.
110+
while True:
111+
# Clean up the *previous* iteration's context before
112+
# waiting for the next message.
113+
_cleanup_otel_context(token, span_manager)
114+
token = None
115+
span_manager = None
116+
117+
try:
118+
msg = await IteratorSubscription.__anext__(sub)
119+
except StopAsyncIteration:
120+
return
121+
122+
if not is_instrumentation_enabled():
123+
yield msg
124+
continue
125+
126+
ctx = propagate.extract(msg.headers)
127+
token = context.attach(ctx)
128+
span = (
129+
SpanBuilder(self.tracer, SpanKind.CONSUMER, "receive")
130+
.with_message(msg)
131+
.build()
132+
)
133+
if span:
134+
span_manager = trace.use_span(span, end_on_exit=True)
135+
span_manager.__enter__()
136+
yield msg
102137
finally:
103-
if current_token:
104-
context.detach(current_token)
105-
if span_manager:
106-
span_manager.__exit__(None, None, None)
107-
108-
if not is_instrumentation_enabled():
109-
return msg
110-
ctx = propagate.extract(msg.headers)
111-
current_token = context.attach(ctx)
112-
span = (
113-
SpanBuilder(self.tracer, SpanKind.CONSUMER, "receive")
114-
.with_message(msg)
115-
.build()
116-
)
117-
if span:
118-
span_manager = trace.use_span(span, end_on_exit=True)
119-
span_manager.__enter__()
120-
return msg
138+
_cleanup_otel_context(token, span_manager)
139+
140+
def _custom_aiter(
141+
wrapper: Any,
142+
instance: IteratorSubscription,
143+
args: tuple[Any, ...],
144+
kwargs: dict[str, Any],
145+
) -> AsyncIterator[Message]:
146+
return _instrumented_iter(instance)
121147

122148
wrap_function_wrapper(
123149
"natsrpy._natsrpy_rs",
124-
"IteratorSubscription.__anext__",
125-
_custom_anext,
150+
"IteratorSubscription.__aiter__",
151+
_custom_aiter,
126152
)
127153

128154
def _instrument_cb_subscription(self) -> None:

0 commit comments

Comments
 (0)