Skip to content

Commit edf45e3

Browse files
committed
fix(async-streaming): harden context preservation
1 parent c5dc24d commit edf45e3

File tree

4 files changed

+443
-37
lines changed

4 files changed

+443
-37
lines changed

langfuse/_client/observe.py

Lines changed: 143 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,20 @@ def _wrap_async_generator_result(
535535
observe = _decorator.observe
536536

537537

538+
def _get_generator_output(
539+
items: List[Any],
540+
transform_fn: Optional[Callable[[Iterable], str]],
541+
) -> Any:
542+
output: Any = items
543+
544+
if transform_fn is not None:
545+
output = transform_fn(items)
546+
elif all(isinstance(item, str) for item in items):
547+
output = "".join(items)
548+
549+
return output
550+
551+
538552
class _ContextPreservedSyncGeneratorWrapper:
539553
"""Sync generator wrapper that ensures each iteration runs in preserved context."""
540554

@@ -560,9 +574,17 @@ def __init__(
560574
self.items: List[Any] = []
561575
self.span = span
562576
self.transform_fn = transform_fn
577+
self._finalized = False
563578

564-
def __iter__(self) -> "_ContextPreservedSyncGeneratorWrapper":
565-
return self
579+
def __iter__(self) -> Generator[Any, None, None]:
580+
try:
581+
while True:
582+
yield self.__next__()
583+
except StopIteration:
584+
return
585+
finally:
586+
if not self._finalized:
587+
self.close()
566588

567589
def __next__(self) -> Any:
568590
try:
@@ -573,25 +595,65 @@ def __next__(self) -> Any:
573595
return item
574596

575597
except StopIteration:
576-
# Handle output and span cleanup when generator is exhausted
577-
output: Any = self.items
598+
self._finalize()
599+
raise # Re-raise StopIteration
578600

579-
if self.transform_fn is not None:
580-
output = self.transform_fn(self.items)
601+
except (Exception, asyncio.CancelledError) as e:
602+
self._finalize(error=e)
603+
raise
581604

582-
elif all(isinstance(item, str) for item in self.items):
583-
output = "".join(self.items)
605+
def close(self) -> None:
606+
if self._finalized:
607+
return
584608

585-
self.span.update(output=output).end()
609+
try:
610+
close_method = getattr(self.generator, "close", None)
611+
if callable(close_method):
612+
self.context.run(close_method)
613+
except (Exception, asyncio.CancelledError) as e:
614+
self._finalize(error=e)
615+
raise
586616

587-
raise # Re-raise StopIteration
617+
self._finalize()
588618

619+
def throw(self, typ: Any, val: Any = None, tb: Any = None) -> Any:
620+
throw_method = getattr(self.generator, "throw", None)
621+
if not callable(throw_method):
622+
raise AttributeError("Wrapped generator does not support throw()")
623+
624+
try:
625+
if tb is not None:
626+
item = self.context.run(throw_method, typ, val, tb)
627+
elif val is not None:
628+
item = self.context.run(throw_method, typ, val)
629+
else:
630+
item = self.context.run(throw_method, typ)
631+
632+
self.items.append(item)
633+
634+
return item
635+
except StopIteration:
636+
self._finalize()
637+
raise
589638
except (Exception, asyncio.CancelledError) as e:
639+
self._finalize(error=e)
640+
raise
641+
642+
def _finalize(self, error: Optional[BaseException] = None) -> None:
643+
if self._finalized:
644+
return
645+
646+
self._finalized = True
647+
648+
if error is not None:
590649
self.span.update(
591-
level="ERROR", status_message=str(e) or type(e).__name__
650+
level="ERROR", status_message=str(error) or type(error).__name__
592651
).end()
652+
return
593653

594-
raise
654+
self.span.update(
655+
output=_get_generator_output(self.items, self.transform_fn)
656+
).end()
595657

596658

597659
class _ContextPreservedAsyncGeneratorWrapper:
@@ -619,43 +681,93 @@ def __init__(
619681
self.items: List[Any] = []
620682
self.span = span
621683
self.transform_fn = transform_fn
684+
self._finalized = False
622685

623686
def __aiter__(self) -> "_ContextPreservedAsyncGeneratorWrapper":
624687
return self
625688

626689
async def __anext__(self) -> Any:
627690
try:
628691
# Run the generator's __anext__ in the preserved context
629-
try:
630-
# Python 3.10+ approach with context parameter
631-
item = await asyncio.create_task(
632-
self.generator.__anext__(), # type: ignore
633-
context=self.context,
634-
) # type: ignore
635-
except TypeError:
636-
# Python < 3.10 fallback - context parameter not supported
637-
item = await self.generator.__anext__()
692+
item = await self._run_in_preserved_context(self.generator.__anext__)
638693

639694
self.items.append(item)
640695

641696
return item
642697

643698
except StopAsyncIteration:
644-
# Handle output and span cleanup when generator is exhausted
645-
output: Any = self.items
699+
self._finalize()
700+
raise # Re-raise StopAsyncIteration
701+
except (Exception, asyncio.CancelledError) as e:
702+
self._finalize(error=e)
703+
raise
646704

647-
if self.transform_fn is not None:
648-
output = self.transform_fn(self.items)
705+
async def close(self) -> None:
706+
await self.aclose()
649707

650-
elif all(isinstance(item, str) for item in self.items):
651-
output = "".join(self.items)
708+
async def aclose(self) -> None:
709+
if self._finalized:
710+
return
652711

653-
self.span.update(output=output).end()
712+
try:
713+
close_method = getattr(self.generator, "aclose", None)
714+
if callable(close_method):
715+
await self._run_in_preserved_context(close_method)
716+
except (Exception, asyncio.CancelledError) as e:
717+
self._finalize(error=e)
718+
raise
654719

655-
raise # Re-raise StopAsyncIteration
720+
self._finalize()
721+
722+
async def athrow(self, typ: Any, val: Any = None, tb: Any = None) -> Any:
723+
throw_method = getattr(self.generator, "athrow", None)
724+
if not callable(throw_method):
725+
raise AttributeError("Wrapped async generator does not support athrow()")
726+
727+
try:
728+
if tb is not None:
729+
item = await self._run_in_preserved_context(
730+
lambda: throw_method(typ, val, tb)
731+
)
732+
elif val is not None:
733+
item = await self._run_in_preserved_context(
734+
lambda: throw_method(typ, val)
735+
)
736+
else:
737+
item = await self._run_in_preserved_context(lambda: throw_method(typ))
738+
739+
self.items.append(item)
740+
741+
return item
742+
except StopAsyncIteration:
743+
self._finalize()
744+
raise
656745
except (Exception, asyncio.CancelledError) as e:
746+
self._finalize(error=e)
747+
raise
748+
749+
async def _run_in_preserved_context(self, factory: Callable[[], Any]) -> Any:
750+
awaitable = self.context.run(factory)
751+
752+
try:
753+
task = asyncio.create_task(awaitable, context=self.context) # type: ignore[call-arg]
754+
except TypeError:
755+
task = self.context.run(asyncio.create_task, awaitable)
756+
757+
return await task
758+
759+
def _finalize(self, error: Optional[BaseException] = None) -> None:
760+
if self._finalized:
761+
return
762+
763+
self._finalized = True
764+
765+
if error is not None:
657766
self.span.update(
658-
level="ERROR", status_message=str(e) or type(e).__name__
767+
level="ERROR", status_message=str(error) or type(error).__name__
659768
).end()
769+
return
660770

661-
raise
771+
self.span.update(
772+
output=_get_generator_output(self.items, self.transform_fn)
773+
).end()

langfuse/openai.py

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,6 +1010,7 @@ def __init__(
10101010
self.response = response
10111011
self.generation = generation
10121012
self.completion_start_time: Optional[datetime] = None
1013+
self._finalized = False
10131014

10141015
def __iter__(self) -> Any:
10151016
try:
@@ -1039,12 +1040,31 @@ def __next__(self) -> Any:
10391040
raise
10401041

10411042
def __enter__(self) -> Any:
1042-
return self.__iter__()
1043+
return self
10431044

10441045
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
1045-
pass
1046+
self.close()
1047+
1048+
def close(self) -> None:
1049+
if self._finalized:
1050+
return
1051+
1052+
close_method = getattr(self.response, "close", None)
1053+
if callable(close_method):
1054+
try:
1055+
close_method()
1056+
finally:
1057+
self._finalize()
1058+
return
1059+
1060+
self._finalize()
10461061

10471062
def _finalize(self) -> None:
1063+
if self._finalized:
1064+
return
1065+
1066+
self._finalized = True
1067+
10481068
try:
10491069
model, completion, usage, metadata = (
10501070
_extract_streamed_response_api_response(self.items)
@@ -1081,6 +1101,7 @@ def __init__(
10811101
self.response = response
10821102
self.generation = generation
10831103
self.completion_start_time: Optional[datetime] = None
1104+
self._finalized = False
10841105

10851106
async def __aiter__(self) -> Any:
10861107
try:
@@ -1110,12 +1131,17 @@ async def __anext__(self) -> Any:
11101131
raise
11111132

11121133
async def __aenter__(self) -> Any:
1113-
return self.__aiter__()
1134+
return self
11141135

11151136
async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
1116-
pass
1137+
await self.aclose()
11171138

11181139
async def _finalize(self) -> None:
1140+
if self._finalized:
1141+
return
1142+
1143+
self._finalized = True
1144+
11191145
try:
11201146
model, completion, usage, metadata = (
11211147
_extract_streamed_response_api_response(self.items)
@@ -1142,11 +1168,40 @@ async def close(self) -> None:
11421168
11431169
Automatically called if the response body is read to completion.
11441170
"""
1145-
await self.response.close()
1171+
if self._finalized:
1172+
return
1173+
1174+
close_method = getattr(self.response, "close", None)
1175+
if callable(close_method):
1176+
try:
1177+
await close_method()
1178+
finally:
1179+
await self._finalize()
1180+
return
1181+
1182+
await self._finalize()
11461183

11471184
async def aclose(self) -> None:
11481185
"""Close the response and release the connection.
11491186
11501187
Automatically called if the response body is read to completion.
11511188
"""
1152-
await self.response.aclose()
1189+
if self._finalized:
1190+
return
1191+
1192+
close_method = getattr(self.response, "aclose", None)
1193+
if callable(close_method):
1194+
try:
1195+
await close_method()
1196+
finally:
1197+
await self._finalize()
1198+
else:
1199+
close_method = getattr(self.response, "close", None)
1200+
if callable(close_method):
1201+
try:
1202+
await close_method()
1203+
finally:
1204+
await self._finalize()
1205+
return
1206+
1207+
await self._finalize()

0 commit comments

Comments
 (0)