@@ -128,8 +128,9 @@ class _OpenMessage:
128128 open_blocks : dict [int , str ] = field (default_factory = dict )
129129 text_contents : dict [int , str ] = field (default_factory = dict )
130130
131- def __init__ (self , * , thread_id : str ) -> None :
131+ def __init__ (self , * , thread_id : str , run_id : str ) -> None :
132132 self .thread_id = thread_id
133+ self .run_id = run_id
133134 self ._open_message_ids : dict [str , _ProtocolMessageStreamState ._OpenMessage ] = {}
134135 self .saw_live_messages = False
135136
@@ -148,6 +149,7 @@ def _finish_blocks(
148149 self .thread_id ,
149150 index = index ,
150151 namespace = effective_namespace ,
152+ run_id = self .run_id ,
151153 )
152154 del state .open_blocks [index ]
153155
@@ -166,6 +168,7 @@ def _publish_text_block(
166168 index = index ,
167169 content = {"type" : "text" , "text" : "" },
168170 namespace = effective_namespace ,
171+ run_id = self .run_id ,
169172 )
170173 state .open_blocks [index ] = "text"
171174 previous_text = state .text_contents .get (index , "" )
@@ -178,6 +181,7 @@ def _publish_text_block(
178181 index = index ,
179182 delta = {"type" : "text-delta" , "text" : delta_text },
180183 namespace = effective_namespace ,
184+ run_id = self .run_id ,
181185 )
182186 state .text_contents [index ] = text
183187
@@ -197,13 +201,15 @@ def _publish_nontext_block(
197201 index = index ,
198202 content = block ,
199203 namespace = effective_namespace ,
204+ run_id = self .run_id ,
200205 )
201206 if final :
202207 publish_content_block_finish (
203208 self .thread_id ,
204209 index = index ,
205210 content = block ,
206211 namespace = effective_namespace ,
212+ run_id = self .run_id ,
207213 )
208214 return
209215 state .open_blocks [index ] = str (block .get ("type" , "block" ))
@@ -214,13 +220,15 @@ def _publish_nontext_block(
214220 index = index ,
215221 delta = block ,
216222 namespace = effective_namespace ,
223+ run_id = self .run_id ,
217224 )
218225 if final :
219226 publish_content_block_finish (
220227 self .thread_id ,
221228 index = index ,
222229 content = block ,
223230 namespace = effective_namespace ,
231+ run_id = self .run_id ,
224232 )
225233 del state .open_blocks [index ]
226234
@@ -239,6 +247,7 @@ def publish_blocks(
239247 message_id = message_id ,
240248 role = role ,
241249 namespace = namespace ,
250+ run_id = self .run_id ,
242251 )
243252 state = self ._OpenMessage (
244253 role = role ,
@@ -317,7 +326,7 @@ def finish_all(self, *, namespace: list[str] | None = None) -> None:
317326 state = self ._open_message_ids .pop (message_id )
318327 message_namespace = state .namespace or namespace
319328 self ._finish_blocks (state , namespace = message_namespace )
320- publish_message_complete (self .thread_id , namespace = message_namespace )
329+ publish_message_complete (self .thread_id , namespace = message_namespace , run_id = self . run_id )
321330
322331 async def afinish_blocks (
323332 self ,
@@ -334,6 +343,7 @@ async def afinish_blocks(
334343 self .thread_id ,
335344 index = index ,
336345 namespace = effective_namespace ,
346+ run_id = self .run_id ,
337347 )
338348 del state .open_blocks [index ]
339349
@@ -352,6 +362,7 @@ async def apublish_text_block(
352362 index = index ,
353363 content = {"type" : "text" , "text" : "" },
354364 namespace = effective_namespace ,
365+ run_id = self .run_id ,
355366 )
356367 state .open_blocks [index ] = "text"
357368 previous_text = state .text_contents .get (index , "" )
@@ -364,6 +375,7 @@ async def apublish_text_block(
364375 index = index ,
365376 delta = {"type" : "text-delta" , "text" : delta_text },
366377 namespace = effective_namespace ,
378+ run_id = self .run_id ,
367379 )
368380 state .text_contents [index ] = text
369381
@@ -383,13 +395,15 @@ async def apublish_nontext_block(
383395 index = index ,
384396 content = block ,
385397 namespace = effective_namespace ,
398+ run_id = self .run_id ,
386399 )
387400 if final :
388401 await apublish_content_block_finish (
389402 self .thread_id ,
390403 index = index ,
391404 content = block ,
392405 namespace = effective_namespace ,
406+ run_id = self .run_id ,
393407 )
394408 return
395409 state .open_blocks [index ] = str (block .get ("type" , "block" ))
@@ -400,13 +414,15 @@ async def apublish_nontext_block(
400414 index = index ,
401415 delta = block ,
402416 namespace = effective_namespace ,
417+ run_id = self .run_id ,
403418 )
404419 if final :
405420 await apublish_content_block_finish (
406421 self .thread_id ,
407422 index = index ,
408423 content = block ,
409424 namespace = effective_namespace ,
425+ run_id = self .run_id ,
410426 )
411427 del state .open_blocks [index ]
412428
@@ -425,6 +441,7 @@ async def apublish_blocks(
425441 message_id = message_id ,
426442 role = role ,
427443 namespace = namespace ,
444+ run_id = self .run_id ,
428445 )
429446 state = self ._OpenMessage (
430447 role = role ,
@@ -503,7 +520,7 @@ async def afinish_all(self, *, namespace: list[str] | None = None) -> None:
503520 state = self ._open_message_ids .pop (message_id )
504521 message_namespace = state .namespace or namespace
505522 await self .afinish_blocks (state , namespace = message_namespace )
506- await apublish_message_complete (self .thread_id , namespace = message_namespace )
523+ await apublish_message_complete (self .thread_id , namespace = message_namespace , run_id = self . run_id )
507524
508525
509526def _protocol_blocks_for_message (message : BaseMessage ) -> list [dict [str , Any ]]:
@@ -734,7 +751,7 @@ async def execute_run(
734751 result : Any = None
735752 interrupt_chunk : Any = None
736753 interrupt_namespace : list [str ] | None = None
737- protocol_messages = _ProtocolMessageStreamState (thread_id = thread_id )
754+ protocol_messages = _ProtocolMessageStreamState (thread_id = thread_id , run_id = run_id )
738755 async for stream_event in graph .astream_events (invocation , config , version = "v2" ):
739756 protocol_namespace = _protocol_namespace_for_event (stream_event )
740757 for event_name , event_payload in _translate_stream_events (stream_event ):
@@ -804,7 +821,12 @@ async def execute_run(
804821 if isinstance (normalized_chunk , dict ):
805822 normalized_chunk .pop ("__interrupt__" , None )
806823 if normalized_chunk :
807- await apublish_updates_event (thread_id , values = normalized_chunk , namespace = protocol_namespace )
824+ await apublish_updates_event (
825+ thread_id ,
826+ values = normalized_chunk ,
827+ namespace = protocol_namespace ,
828+ run_id = run_id ,
829+ )
808830 if stream_event .get ("event" ) == "on_chain_end" and _is_root_stream_event (stream_event ):
809831 data = stream_event .get ("data" , {})
810832 if isinstance (data , dict ) and "output" in data :
@@ -818,7 +840,12 @@ async def execute_run(
818840 else :
819841 await apublish_message_transcript (thread_id , run_id = run_id , messages = messages )
820842 await protocol_messages .afinish_all ()
821- await apublish_values_event (thread_id , values = normalized_result , namespace = protocol_namespace )
843+ await apublish_values_event (
844+ thread_id ,
845+ values = normalized_result ,
846+ namespace = protocol_namespace ,
847+ run_id = run_id ,
848+ )
822849
823850 if interrupt_chunk is not None :
824851 if isinstance (result , dict ):
0 commit comments