@@ -496,6 +496,7 @@ def __init__(self):
496496 self ._token = None
497497 self ._output_chunks = []
498498 self ._trace_initialized = False
499+ self ._captured_context = None # Capture context for ASGI compatibility
499500
500501 def __iter__ (self ):
501502 return self
@@ -517,6 +518,10 @@ def __next__(self):
517518 func_kwargs = func_kwargs ,
518519 context_kwarg = context_kwarg ,
519520 )
521+ # Capture context after trace is initialized for ASGI compatibility
522+ # This ensures we can finalize in the correct context even if
523+ # later __next__ calls happen in different worker threads
524+ self ._captured_context = contextvars .copy_context ()
520525 self ._trace_initialized = True
521526
522527 try :
@@ -525,23 +530,21 @@ def __next__(self):
525530 return chunk
526531 except StopIteration :
527532 # Finalize trace when generator is exhausted
533+ # Use captured context to ensure we have access to the trace
528534 output = _join_output_chunks (self ._output_chunks )
529- _finalize_sync_generator_step (
530- step = self ._step ,
531- token = self ._token ,
532- is_root_step = self ._is_root_step ,
533- step_name = step_name ,
534- inputs = self ._inputs ,
535- output = output ,
536- inference_pipeline_id = inference_pipeline_id ,
537- on_flush_failure = on_flush_failure ,
538- )
539- raise
540- except Exception as exc :
541- # Handle exceptions
542- if self ._step :
543- _log_step_exception (self ._step , exc )
544- output = _join_output_chunks (self ._output_chunks )
535+ if self ._captured_context :
536+ self ._captured_context .run (
537+ _finalize_sync_generator_step ,
538+ step = self ._step ,
539+ token = self ._token ,
540+ is_root_step = self ._is_root_step ,
541+ step_name = step_name ,
542+ inputs = self ._inputs ,
543+ output = output ,
544+ inference_pipeline_id = inference_pipeline_id ,
545+ on_flush_failure = on_flush_failure ,
546+ )
547+ else :
545548 _finalize_sync_generator_step (
546549 step = self ._step ,
547550 token = self ._token ,
@@ -553,6 +556,35 @@ def __next__(self):
553556 on_flush_failure = on_flush_failure ,
554557 )
555558 raise
559+ except Exception as exc :
560+ # Handle exceptions
561+ if self ._step :
562+ _log_step_exception (self ._step , exc )
563+ output = _join_output_chunks (self ._output_chunks )
564+ if self ._captured_context :
565+ self ._captured_context .run (
566+ _finalize_sync_generator_step ,
567+ step = self ._step ,
568+ token = self ._token ,
569+ is_root_step = self ._is_root_step ,
570+ step_name = step_name ,
571+ inputs = self ._inputs ,
572+ output = output ,
573+ inference_pipeline_id = inference_pipeline_id ,
574+ on_flush_failure = on_flush_failure ,
575+ )
576+ else :
577+ _finalize_sync_generator_step (
578+ step = self ._step ,
579+ token = self ._token ,
580+ is_root_step = self ._is_root_step ,
581+ step_name = step_name ,
582+ inputs = self ._inputs ,
583+ output = output ,
584+ inference_pipeline_id = inference_pipeline_id ,
585+ on_flush_failure = on_flush_failure ,
586+ )
587+ raise
556588
557589 return TracedSyncGenerator ()
558590
0 commit comments