3636from scheduler .ProcessTask import ProcessTask
3737from scheduler .Task import Task
3838from scheduler .ThreadTask import ThreadTask
39- from scheduler .utils import SchedulerException , TaskFailedException
39+ from scheduler .utils import SchedulerException , TaskFailedException , StdOut
4040
4141
4242class Scheduler :
@@ -62,6 +62,7 @@ def __init__(
6262 shared_memory_threshold : int = 1e7 ,
6363 run_in_thread : bool = False ,
6464 raise_exceptions : bool = False ,
65+ capture_stdout : bool = False ,
6566 ):
6667 """
6768 :param progress_callback: a function taking the number of finished tasks and the total number of tasks, which is
@@ -72,13 +73,17 @@ def __init__(
7273 to be below the threshold, the number of simultaneous tasks will be increased
7374 :param cpu_update_interval: the time, in seconds, between consecutive CPU usage checks when `dynamic` is enabled
7475 :param shared_memory: whether to use shared memory if possible
75- :param shared_memory_threshold: the minimum size of a Numpy array which will cause it to be transferred using shared memory if possible
76- :param run_in_thread: if True, a single task will be run in a thread instead of a process.
76+ :param shared_memory_threshold: the minimum size of a Numpy array which will cause it to be transferred
77+ using shared memory if possible
78+ :param run_in_thread: if True, a single task will be run in a thread instead of a process. This reduces
79+ the overhead (caused by spawning processes instead of forking) on Windows/macOS systems
7780 :param raise_exceptions: if True, Exceptions raised in processes will also be raised in
7881 the process which the Scheduler was started in.
79- This reduces the overhead (caused by spawning processes instead of forking) on Windows/macOS systems
82+ :param capture_stdout: if True, `stdout` from processes will be captured and written to the main process'
83+ `stdout`.
8084 """
8185 self .raise_exceptions = raise_exceptions
86+ self .capture_stdout = capture_stdout
8287
8388 self .run_in_thread = run_in_thread
8489 if self .run_in_thread and shared_memory :
@@ -178,7 +183,13 @@ def add(
178183 queue = MTQueue ()
179184 exc_queue = MTQueue ()
180185
181- _args = (queue , self .mgr , self .shared_memory_threshold , exc_queue ) + args
186+ _args = (
187+ queue ,
188+ self .mgr ,
189+ self .shared_memory_threshold ,
190+ exc_queue ,
191+ None ,
192+ ) + args
182193 _wrapper = functools .partial (wrapper , target )
183194
184195 task = ThreadTask (
@@ -190,12 +201,25 @@ def add(
190201 else :
191202 queue = queue_type ()
192203 exc_queue = queue_type ()
204+ stdout_queue = queue_type ()
193205
194- _args = (queue , self .mgr , self .shared_memory_threshold , exc_queue ) + args
206+ _args = (
207+ queue ,
208+ self .mgr ,
209+ self .shared_memory_threshold ,
210+ exc_queue ,
211+ stdout_queue ,
212+ ) + args
195213 _wrapper = functools .partial (wrapper , target )
196214
197215 process = process_type (target = _wrapper , args = _args )
198- task = ProcessTask (process , queue , exc_queue = exc_queue , subtasks = subtasks )
216+ task = ProcessTask (
217+ process ,
218+ queue ,
219+ exc_queue = exc_queue ,
220+ stdout_queue = stdout_queue ,
221+ subtasks = subtasks ,
222+ )
199223
200224 self .tasks .append (task )
201225
@@ -337,6 +361,7 @@ def terminate(self) -> None:
337361 """Terminates all running tasks by killing their processes."""
338362 if not self .terminated :
339363 [t .terminate () for t in self .tasks ]
364+ [self .stdout (t ) for t in self .tasks ]
340365 self .terminated = True
341366
342367 self ._shutdown ()
@@ -412,8 +437,11 @@ def _update(self) -> None:
412437
413438 for t in self .running_tasks :
414439 t .update ()
440+ self .stdout (t )
441+
415442 if self .raise_exceptions and t .failed :
416443 self .failed = True
444+
417445 if t .exception_tb :
418446 raise TaskFailedException (t .exception_tb )
419447 else :
@@ -431,6 +459,13 @@ def _update(self) -> None:
431459 if schedule_new_tasks :
432460 self ._schedule_tasks ()
433461
462+ def stdout (self , task : Task ) -> None :
463+ if self .capture_stdout :
464+ text = task .get_stdout ()
465+
466+ if text :
467+ sys .stdout .write (text )
468+
434469 def _start (self ) -> None :
435470 """
436471 Starts the scheduler running its tasks.
@@ -514,16 +549,29 @@ def wrapper(
514549 manager : Optional ["SharedMemoryManager" ],
515550 threshold : int ,
516551 exc_queue : Queue = None ,
552+ stdout_queue : Queue = None ,
517553 * args : Any ,
518554) -> None :
519555 """
520556 Wrapper which calls a function with its specified arguments and puts the output in a queue.
521557
558+ This function will be the Callable executed in a process/thread.
559+
522560 :param function: the function which will be executed
523561 :param queue: a Queue object which may be used to transfer data between processes
524562 :param manager: a SharedMemoryManager or None; used to handle shared memory between processes
563+ :param threshold:
564+ :param exc_queue:
565+ :param stdout_queue:
525566 """
567+ stdout = None
568+
526569 try :
570+ if stdout_queue :
571+ stdout = StdOut (stdout_queue )
572+ sys .stdout = stdout
573+ sys .stderr = stdout
574+
527575 result = function (* args )
528576 out = []
529577
@@ -542,6 +590,9 @@ def wrapper(
542590 else :
543591 out = tuple (out )
544592
593+ if stdout :
594+ stdout .update (force = True )
595+
545596 queue .put (out )
546597
547598 except Exception as e :
0 commit comments