88import subprocess
99import threading
1010import uuid
11+ from contextlib import suppress
1112from pathlib import Path
12- from typing import Any , Callable
13+ from typing import TYPE_CHECKING , Any
14+
15+ if TYPE_CHECKING :
16+ from collections .abc import Callable
1317
1418from .constants import TRACE_MARKER
1519from .models import Block , BlockExecutor , CommandLog , ExecutionContext , ExecutionResult , SSHConfig
@@ -316,6 +320,7 @@ def _run_remote_subprocess(
316320 * ,
317321 timeout_seconds : int | None ,
318322 on_command : Callable [[str ], None ] | None = None ,
323+ on_output : Callable [[str ], None ] | None = None ,
319324) -> tuple [str , str , int , bool , bool ]:
320325 """Run an SSH subprocess and preserve partial output on timeout or interruption."""
321326 process = subprocess .Popen (
@@ -330,6 +335,7 @@ def _run_remote_subprocess(
330335 input_text = remote_script ,
331336 timeout_seconds = timeout_seconds ,
332337 on_command = on_command ,
338+ on_output = on_output ,
333339 )
334340
335341
@@ -418,6 +424,7 @@ def execute_local_traced(
418424 context : ExecutionContext ,
419425 no_input : bool = False ,
420426 on_command : Callable [[str ], None ] | None = None ,
427+ on_output : Callable [[str ], None ] | None = None ,
421428) -> ExecutionResult :
422429 """Execute a local block while preserving per-command trace logs."""
423430 if not block .commands :
@@ -453,6 +460,7 @@ def execute_local_traced(
453460 input_text = input_text ,
454461 timeout_seconds = block .timeout_seconds ,
455462 on_command = on_command ,
463+ on_output = on_output ,
456464 )
457465 cleaned_stderr = _strip_trace_markers (stderr ).strip ()
458466 command_logs = _parse_grouped_trace_output (
@@ -505,6 +513,7 @@ def execute_remote(
505513 no_input : bool = False ,
506514 servers : dict [str , dict [str , str ]] | None = None ,
507515 on_command : Callable [[str ], None ] | None = None ,
516+ on_output : Callable [[str ], None ] | None = None ,
508517) -> ExecutionResult :
509518 """Execute a remote block via SSH.
510519
@@ -580,6 +589,7 @@ def execute_remote(
580589 remote_script ,
581590 timeout_seconds = block .timeout_seconds ,
582591 on_command = on_command ,
592+ on_output = on_output ,
583593 )
584594 cleaned_stdout = _strip_trace_markers (stdout )
585595 cleaned_stderr = _strip_trace_markers (stderr )
@@ -779,6 +789,7 @@ def _run_traced_subprocess(
779789 input_text : str | None ,
780790 timeout_seconds : int | None ,
781791 on_command : Callable [[str ], None ] | None = None ,
792+ on_output : Callable [[str ], None ] | None = None ,
782793) -> tuple [str , str , int , bool , bool ]:
783794 """Read traced stdout/stderr streams while surfacing live command markers."""
784795 stdout_chunks : list [str ] = []
@@ -789,25 +800,16 @@ def _run_traced_subprocess(
789800 if stdout_stream is None or stderr_stream is None :
790801 raise subprocess .SubprocessError ("traced subprocess did not expose stdout/stderr pipes" )
791802
792- def read_stdout () -> None :
793- try :
794- for line in iter (stdout_stream .readline , "" ):
795- stdout_chunks .append (line )
796- command = _extract_trace_command_line (line )
797- if command is not None and on_command is not None :
798- on_command (command )
799- finally :
800- stdout_stream .close ()
801-
802- def read_stderr () -> None :
803- try :
804- for line in iter (stderr_stream .readline , "" ):
805- stderr_chunks .append (line )
806- finally :
807- stderr_stream .close ()
808-
809- stdout_thread = threading .Thread (target = read_stdout , daemon = True )
810- stderr_thread = threading .Thread (target = read_stderr , daemon = True )
803+ stdout_thread = threading .Thread (
804+ target = _read_traced_stdout_stream ,
805+ args = (stdout_stream , stdout_chunks , on_command , on_output ),
806+ daemon = True ,
807+ )
808+ stderr_thread = threading .Thread (
809+ target = _read_traced_output_stream ,
810+ args = (stderr_stream , stderr_chunks , on_output ),
811+ daemon = True ,
812+ )
811813 stdout_thread .start ()
812814 stderr_thread .start ()
813815
@@ -816,14 +818,7 @@ def read_stderr() -> None:
816818 stdin_stream = process .stdin
817819
818820 try :
819- if stdin_stream is not None :
820- if input_text is not None :
821- try :
822- stdin_stream .write (input_text )
823- except BrokenPipeError :
824- pass
825- stdin_stream .close ()
826-
821+ _write_process_input (stdin_stream , input_text )
827822 process .wait (timeout = timeout_seconds )
828823 except subprocess .TimeoutExpired :
829824 timed_out = True
@@ -853,6 +848,50 @@ def read_stderr() -> None:
853848 )
854849
855850
851+ def _read_traced_stdout_stream (
852+ stdout_stream : Any ,
853+ stdout_chunks : list [str ],
854+ on_command : Callable [[str ], None ] | None ,
855+ on_output : Callable [[str ], None ] | None ,
856+ ) -> None :
857+ """Read traced stdout, separating command markers from regular output."""
858+ try :
859+ for line in iter (stdout_stream .readline , "" ):
860+ stdout_chunks .append (line )
861+ command = _extract_trace_command_line (line )
862+ if command is not None and on_command is not None :
863+ on_command (command )
864+ elif command is None and on_output is not None :
865+ on_output (line )
866+ finally :
867+ stdout_stream .close ()
868+
869+
870+ def _read_traced_output_stream (
871+ stream : Any ,
872+ chunks : list [str ],
873+ on_output : Callable [[str ], None ] | None ,
874+ ) -> None :
875+ """Read a traced process output stream and forward raw chunks."""
876+ try :
877+ for line in iter (stream .readline , "" ):
878+ chunks .append (line )
879+ if on_output is not None :
880+ on_output (line )
881+ finally :
882+ stream .close ()
883+
884+
885+ def _write_process_input (stdin_stream : Any , input_text : str | None ) -> None :
886+ """Write the traced script to stdin, tolerating early process exits."""
887+ if stdin_stream is None :
888+ return
889+ if input_text is not None :
890+ with suppress (BrokenPipeError ):
891+ stdin_stream .write (input_text )
892+ stdin_stream .close ()
893+
894+
856895def _parse_grouped_trace_output (
857896 combined_output : str ,
858897 * ,
0 commit comments