@@ -535,6 +535,20 @@ def _wrap_async_generator_result(
535535observe = _decorator .observe
536536
537537
538+ def _get_generator_output (
539+ items : List [Any ],
540+ transform_fn : Optional [Callable [[Iterable ], str ]],
541+ ) -> Any :
542+ output : Any = items
543+
544+ if transform_fn is not None :
545+ output = transform_fn (items )
546+ elif all (isinstance (item , str ) for item in items ):
547+ output = "" .join (items )
548+
549+ return output
550+
551+
538552class _ContextPreservedSyncGeneratorWrapper :
539553 """Sync generator wrapper that ensures each iteration runs in preserved context."""
540554
@@ -560,9 +574,17 @@ def __init__(
560574 self .items : List [Any ] = []
561575 self .span = span
562576 self .transform_fn = transform_fn
577+ self ._finalized = False
563578
564- def __iter__ (self ) -> "_ContextPreservedSyncGeneratorWrapper" :
565- return self
579+ def __iter__ (self ) -> Generator [Any , None , None ]:
580+ try :
581+ while True :
582+ yield self .__next__ ()
583+ except StopIteration :
584+ return
585+ finally :
586+ if not self ._finalized :
587+ self .close ()
566588
567589 def __next__ (self ) -> Any :
568590 try :
@@ -573,25 +595,65 @@ def __next__(self) -> Any:
573595 return item
574596
575597 except StopIteration :
576- # Handle output and span cleanup when generator is exhausted
577- output : Any = self . items
598+ self . _finalize ()
599+ raise # Re-raise StopIteration
578600
579- if self .transform_fn is not None :
580- output = self .transform_fn (self .items )
601+ except (Exception , asyncio .CancelledError ) as e :
602+ self ._finalize (error = e )
603+ raise
581604
582- elif all (isinstance (item , str ) for item in self .items ):
583- output = "" .join (self .items )
605+ def close (self ) -> None :
606+ if self ._finalized :
607+ return
584608
585- self .span .update (output = output ).end ()
609+ try :
610+ close_method = getattr (self .generator , "close" , None )
611+ if callable (close_method ):
612+ self .context .run (close_method )
613+ except (Exception , asyncio .CancelledError ) as e :
614+ self ._finalize (error = e )
615+ raise
586616
587- raise # Re-raise StopIteration
617+ self . _finalize ()
588618
619+ def throw (self , typ : Any , val : Any = None , tb : Any = None ) -> Any :
620+ throw_method = getattr (self .generator , "throw" , None )
621+ if not callable (throw_method ):
622+ raise AttributeError ("Wrapped generator does not support throw()" )
623+
624+ try :
625+ if tb is not None :
626+ item = self .context .run (throw_method , typ , val , tb )
627+ elif val is not None :
628+ item = self .context .run (throw_method , typ , val )
629+ else :
630+ item = self .context .run (throw_method , typ )
631+
632+ self .items .append (item )
633+
634+ return item
635+ except StopIteration :
636+ self ._finalize ()
637+ raise
589638 except (Exception , asyncio .CancelledError ) as e :
639+ self ._finalize (error = e )
640+ raise
641+
642+ def _finalize (self , error : Optional [BaseException ] = None ) -> None :
643+ if self ._finalized :
644+ return
645+
646+ self ._finalized = True
647+
648+ if error is not None :
590649 self .span .update (
591- level = "ERROR" , status_message = str (e ) or type (e ).__name__
650+ level = "ERROR" , status_message = str (error ) or type (error ).__name__
592651 ).end ()
652+ return
593653
594- raise
654+ self .span .update (
655+ output = _get_generator_output (self .items , self .transform_fn )
656+ ).end ()
595657
596658
597659class _ContextPreservedAsyncGeneratorWrapper :
@@ -619,43 +681,93 @@ def __init__(
619681 self .items : List [Any ] = []
620682 self .span = span
621683 self .transform_fn = transform_fn
684+ self ._finalized = False
622685
623686 def __aiter__ (self ) -> "_ContextPreservedAsyncGeneratorWrapper" :
624687 return self
625688
626689 async def __anext__ (self ) -> Any :
627690 try :
628691 # Run the generator's __anext__ in the preserved context
629- try :
630- # Python 3.10+ approach with context parameter
631- item = await asyncio .create_task (
632- self .generator .__anext__ (), # type: ignore
633- context = self .context ,
634- ) # type: ignore
635- except TypeError :
636- # Python < 3.10 fallback - context parameter not supported
637- item = await self .generator .__anext__ ()
692+ item = await self ._run_in_preserved_context (self .generator .__anext__ )
638693
639694 self .items .append (item )
640695
641696 return item
642697
643698 except StopAsyncIteration :
644- # Handle output and span cleanup when generator is exhausted
645- output : Any = self .items
699+ self ._finalize ()
700+ raise # Re-raise StopAsyncIteration
701+ except (Exception , asyncio .CancelledError ) as e :
702+ self ._finalize (error = e )
703+ raise
646704
647- if self . transform_fn is not None :
648- output = self .transform_fn ( self . items )
705+ async def close ( self ) -> None :
706+ await self .aclose ( )
649707
650- elif all (isinstance (item , str ) for item in self .items ):
651- output = "" .join (self .items )
708+ async def aclose (self ) -> None :
709+ if self ._finalized :
710+ return
652711
653- self .span .update (output = output ).end ()
712+ try :
713+ close_method = getattr (self .generator , "aclose" , None )
714+ if callable (close_method ):
715+ await self ._run_in_preserved_context (close_method )
716+ except (Exception , asyncio .CancelledError ) as e :
717+ self ._finalize (error = e )
718+ raise
654719
655- raise # Re-raise StopAsyncIteration
720+ self ._finalize ()
721+
722+ async def athrow (self , typ : Any , val : Any = None , tb : Any = None ) -> Any :
723+ throw_method = getattr (self .generator , "athrow" , None )
724+ if not callable (throw_method ):
725+ raise AttributeError ("Wrapped async generator does not support athrow()" )
726+
727+ try :
728+ if tb is not None :
729+ item = await self ._run_in_preserved_context (
730+ lambda : throw_method (typ , val , tb )
731+ )
732+ elif val is not None :
733+ item = await self ._run_in_preserved_context (
734+ lambda : throw_method (typ , val )
735+ )
736+ else :
737+ item = await self ._run_in_preserved_context (lambda : throw_method (typ ))
738+
739+ self .items .append (item )
740+
741+ return item
742+ except StopAsyncIteration :
743+ self ._finalize ()
744+ raise
656745 except (Exception , asyncio .CancelledError ) as e :
746+ self ._finalize (error = e )
747+ raise
748+
749+ async def _run_in_preserved_context (self , factory : Callable [[], Any ]) -> Any :
750+ awaitable = self .context .run (factory )
751+
752+ try :
753+ task = asyncio .create_task (awaitable , context = self .context ) # type: ignore[call-arg]
754+ except TypeError :
755+ task = self .context .run (asyncio .create_task , awaitable )
756+
757+ return await task
758+
759+ def _finalize (self , error : Optional [BaseException ] = None ) -> None :
760+ if self ._finalized :
761+ return
762+
763+ self ._finalized = True
764+
765+ if error is not None :
657766 self .span .update (
658- level = "ERROR" , status_message = str (e ) or type (e ).__name__
767+ level = "ERROR" , status_message = str (error ) or type (error ).__name__
659768 ).end ()
769+ return
660770
661- raise
771+ self .span .update (
772+ output = _get_generator_output (self .items , self .transform_fn )
773+ ).end ()
0 commit comments