Skip to content

Commit a8645ca

Browse files
Add ability to capture stdout
1 parent 2086544 commit a8645ca

8 files changed

Lines changed: 196 additions & 13 deletions

File tree

scheduler/ProcessTask.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,14 @@
2727

2828
class ProcessTask(Task):
2929
def __init__(
30-
self, process: Process, queue: Queue, exc_queue=None, subtasks: int = 0
30+
self,
31+
process: Process,
32+
queue: Queue,
33+
exc_queue=None,
34+
stdout_queue=None,
35+
subtasks: int = 0,
3136
):
32-
super(ProcessTask, self).__init__(queue, exc_queue, subtasks)
37+
super(ProcessTask, self).__init__(queue, exc_queue, stdout_queue, subtasks)
3338

3439
self.process = process
3540

scheduler/Scheduler.py

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from scheduler.ProcessTask import ProcessTask
3737
from scheduler.Task import Task
3838
from scheduler.ThreadTask import ThreadTask
39-
from scheduler.utils import SchedulerException, TaskFailedException
39+
from scheduler.utils import SchedulerException, TaskFailedException, StdOut
4040

4141

4242
class Scheduler:
@@ -62,6 +62,7 @@ def __init__(
6262
shared_memory_threshold: int = 1e7,
6363
run_in_thread: bool = False,
6464
raise_exceptions: bool = False,
65+
capture_stdout: bool = False,
6566
):
6667
"""
6768
:param progress_callback: a function taking the number of finished tasks and the total number of tasks, which is
@@ -72,13 +73,17 @@ def __init__(
7273
to be below the threshold, the number of simultaneous tasks will be increased
7374
:param cpu_update_interval: the time, in seconds, between consecutive CPU usage checks when `dynamic` is enabled
7475
:param shared_memory: whether to use shared memory if possible
75-
:param shared_memory_threshold: the minimum size of a Numpy array which will cause it to be transferred using shared memory if possible
76-
:param run_in_thread: if True, a single task will be run in a thread instead of a process.
76+
:param shared_memory_threshold: the minimum size of a Numpy array which will cause it to be transferred
77+
using shared memory if possible
78+
:param run_in_thread: if True, a single task will be run in a thread instead of a process. This reduces
79+
the overhead (caused by spawning processes instead of forking) on Windows/macOS systems
7780
:param raise_exceptions: if True, Exceptions raised in processes will also be raised in
7881
the process which the Scheduler was started in.
79-
This reduces the overhead (caused by spawning processes instead of forking) on Windows/macOS systems
82+
:param capture_stdout: if True, `stdout` from processes will be captured and written to the main process'
83+
`stdout`.
8084
"""
8185
self.raise_exceptions = raise_exceptions
86+
self.capture_stdout = capture_stdout
8287

8388
self.run_in_thread = run_in_thread
8489
if self.run_in_thread and shared_memory:
@@ -178,7 +183,13 @@ def add(
178183
queue = MTQueue()
179184
exc_queue = MTQueue()
180185

181-
_args = (queue, self.mgr, self.shared_memory_threshold, exc_queue) + args
186+
_args = (
187+
queue,
188+
self.mgr,
189+
self.shared_memory_threshold,
190+
exc_queue,
191+
None,
192+
) + args
182193
_wrapper = functools.partial(wrapper, target)
183194

184195
task = ThreadTask(
@@ -190,12 +201,25 @@ def add(
190201
else:
191202
queue = queue_type()
192203
exc_queue = queue_type()
204+
stdout_queue = queue_type()
193205

194-
_args = (queue, self.mgr, self.shared_memory_threshold, exc_queue) + args
206+
_args = (
207+
queue,
208+
self.mgr,
209+
self.shared_memory_threshold,
210+
exc_queue,
211+
stdout_queue,
212+
) + args
195213
_wrapper = functools.partial(wrapper, target)
196214

197215
process = process_type(target=_wrapper, args=_args)
198-
task = ProcessTask(process, queue, exc_queue=exc_queue, subtasks=subtasks)
216+
task = ProcessTask(
217+
process,
218+
queue,
219+
exc_queue=exc_queue,
220+
stdout_queue=stdout_queue,
221+
subtasks=subtasks,
222+
)
199223

200224
self.tasks.append(task)
201225

@@ -337,6 +361,7 @@ def terminate(self) -> None:
337361
"""Terminates all running tasks by killing their processes."""
338362
if not self.terminated:
339363
[t.terminate() for t in self.tasks]
364+
[self.stdout(t) for t in self.tasks]
340365
self.terminated = True
341366

342367
self._shutdown()
@@ -412,8 +437,11 @@ def _update(self) -> None:
412437

413438
for t in self.running_tasks:
414439
t.update()
440+
self.stdout(t)
441+
415442
if self.raise_exceptions and t.failed:
416443
self.failed = True
444+
417445
if t.exception_tb:
418446
raise TaskFailedException(t.exception_tb)
419447
else:
@@ -431,6 +459,13 @@ def _update(self) -> None:
431459
if schedule_new_tasks:
432460
self._schedule_tasks()
433461

462+
def stdout(self, task: Task) -> None:
    """
    Forwards any captured `stdout` text of a task to this process' `stdout`.

    Nothing is read from the task unless `capture_stdout` is enabled.

    :param task: the task whose captured output should be written out
    """
    if not self.capture_stdout:
        return

    captured = task.get_stdout()

    # `get_stdout` yields None (or an empty value) when there is nothing to write.
    if captured:
        sys.stdout.write(captured)
468+
434469
def _start(self) -> None:
435470
"""
436471
Starts the scheduler running its tasks.
@@ -514,16 +549,29 @@ def wrapper(
514549
manager: Optional["SharedMemoryManager"],
515550
threshold: int,
516551
exc_queue: Queue = None,
552+
stdout_queue: Queue = None,
517553
*args: Any,
518554
) -> None:
519555
"""
520556
Wrapper which calls a function with its specified arguments and puts the output in a queue.
521557
558+
This function will be the Callable executed in a process/thread.
559+
522560
:param function: the function which will be executed
523561
:param queue: a Queue object which may be used to transfer data between processes
524562
:param manager: a SharedMemoryManager or None; used to handle shared memory between processes
563+
:param threshold:
564+
:param exc_queue:
565+
:param stdout_queue:
525566
"""
567+
stdout = None
568+
526569
try:
570+
if stdout_queue:
571+
stdout = StdOut(stdout_queue)
572+
sys.stdout = stdout
573+
sys.stderr = stdout
574+
527575
result = function(*args)
528576
out = []
529577

@@ -542,6 +590,9 @@ def wrapper(
542590
else:
543591
out = tuple(out)
544592

593+
if stdout:
594+
stdout.update(force=True)
595+
545596
queue.put(out)
546597

547598
except Exception as e:

scheduler/Task.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2121
# SOFTWARE.
2222
from abc import ABC, abstractmethod
23+
from typing import Optional
2324

2425

2526
class Task(ABC):
@@ -28,16 +29,18 @@ class Task(ABC):
2829
passed to the process, so that the process can put its output in the queue.
2930
"""
3031

31-
def __init__(self, queue, exc_queue=None, subtasks: int = 0):
32+
def __init__(self, queue, exc_queue=None, stdout_queue=None, subtasks: int = 0):
3233
"""
33-
3434
:param queue: Queue for data to be passed.
3535
:param exc_queue: Queue for exceptions to be passed.
3636
:param subtasks: Number of subtasks.
3737
"""
3838
self.queue = queue
3939
self.exc_queue = exc_queue
4040

41+
self.stdout_queue = stdout_queue
42+
self.stdout_text = None
43+
4144
self.running = False
4245
self.finished = False
4346

@@ -89,3 +92,29 @@ def has_exception(self) -> bool:
8992
Returns whether the task has raised an exception.
9093
"""
9194
return self.exc_queue and not self.exc_queue.empty()
95+
96+
def has_stdout(self) -> bool:
    """
    Returns whether the task has provided any `stdout`.

    Always returns a real `bool`: False when the task has no `stdout` queue,
    otherwise whether that queue currently reports itself as non-empty.
    """
    # Compare against None explicitly so the result is a bool rather than
    # None or the queue object itself.
    return self.stdout_queue is not None and not self.stdout_queue.empty()
101+
102+
def get_stdout(self) -> Optional[str]:
    """
    Returns the `stdout` text collected from the task so far.

    Every chunk currently available in the `stdout` queue is consumed and
    concatenated (prefixed by any text held in `stdout_text`), so a single
    call is enough to collect everything the task has sent up to this point.
    The queue is read without blocking.

    :return: the collected text, or None if there is nothing to report
    """
    # Imported locally so this method does not rely on a module-level import.
    from queue import Empty

    pieces = [self.stdout_text] if self.stdout_text else []
    self.stdout_text = None

    if self.stdout_queue is not None:
        while True:
            try:
                # Never block: `empty()` followed by `get()` can stall if the
                # queue's state changes between the two calls.
                chunk = self.stdout_queue.get_nowait()
            except Empty:
                break

            if not chunk:
                continue

            # A tuple is treated as a sequence of lines.
            if isinstance(chunk, tuple):
                chunk = "\n".join(chunk)

            pieces.append(f"{chunk}")

    text = "".join(pieces)
    return text if text else None

scheduler/ThreadTask.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828

2929
class ThreadTask(Task):
3030
def __init__(self, thread: Thread, queue: Queue, exc_queue=None, subtasks: int = 0):
31-
super(ThreadTask, self).__init__(queue, exc_queue, subtasks)
31+
# stdout queue is not necessary for thread tasks.
32+
super(ThreadTask, self).__init__(queue, exc_queue, None, subtasks)
3233

3334
self.thread = thread
3435

scheduler/utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,40 @@
1919
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
2020
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2121
# SOFTWARE.
22+
import sys
23+
import time
2224
from multiprocessing import Process
2325

2426
import psutil
2527

2628

29+
class StdOut:
    """
    Minimal file-like object which buffers written text and passes it on through a queue.

    Text is accumulated by `write` and transferred to the queue at most once per
    `period` seconds, or immediately when `update(force=True)` is called.
    """

    def __init__(self, queue):
        """
        :param queue: an object with a `put` method which receives the buffered text
        """
        # Minimum time, in seconds, between two unforced transfers to the queue.
        self.period = 1

        # A monotonic clock is used so that system clock adjustments cannot
        # delay or trigger a transfer.
        self.last_update = time.monotonic()
        self.text = []

        self.queue = queue

    def write(self, text: str) -> None:
        """
        Buffers `text` and transfers the buffer if the update period has elapsed.

        :param text: the text to buffer
        """
        self.text.append(text)
        self.update()

    def update(self, force: bool = False) -> None:
        """
        Transfers the buffered text to the queue as a single string.

        :param force: if True, transfer now even if the update period has not elapsed
        """
        if not self.text:
            return

        if not force and time.monotonic() - self.last_update <= self.period:
            return

        out = "".join(self.text)

        # Whitespace-only output (such as the newline ending a printed line) is
        # real output and must be transferred too; only an empty string is skipped.
        if out:
            self.queue.put(out)

        self.last_update = time.monotonic()
        self.text = []

    def flush(self) -> None:
        """
        Does nothing; present so the object can be used where a stream with `flush` is expected.
        """
        return
54+
55+
2756
def terminate_tree(process: Process):
2857
"""
2958
Terminates a process along with all of its child processes.

test/impl.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@
2121
# SOFTWARE.
2222

2323
import asyncio
24+
import sys
25+
import time
2426
from multiprocessing import Process, Queue
2527
from typing import List, Tuple
2628

2729
import multiprocess
2830

2931
from scheduler.ProcessTask import ProcessTask
30-
from scheduler.utils import SchedulerException, TaskFailedException
32+
from scheduler.utils import TaskFailedException
3133
from test.utils import (
3234
_get_input_output_numpy,
3335
_func_numpy,
@@ -44,6 +46,7 @@
4446
_func_no_return,
4547
_long_task,
4648
_func_raise_exception,
49+
_func_print,
4750
)
4851

4952

@@ -342,3 +345,27 @@ def test_raise_exception(scheduler):
342345
assert isinstance(e, TaskFailedException)
343346

344347
assert scheduler.failed
348+
349+
350+
text = ""
351+
352+
353+
def test_stdout(scheduler):
    """
    Checks that text printed by a task is written to this process' `stdout`.

    `sys.stdout.write` is temporarily replaced by a function which collects the
    written text in the module-level `text` variable.

    :param scheduler: the Scheduler under test; presumably one created with
        `capture_stdout=True` so that task output is forwarded
    """
    global text

    # Start from an empty buffer so repeated runs do not see earlier output.
    text = ""

    expected = "\n".join([f"{i}" for i in range(1000)]) + "\n"

    def temp(_stdout):
        global text
        text += _stdout

    write = sys.stdout.write
    sys.stdout.write = temp

    try:
        scheduler.map_blocking(target=_func_print, args=[(expected,),])

        # Wait for output to arrive, but give up after a while so that a
        # failure to capture anything fails the test instead of hanging it.
        deadline = time.monotonic() + 10
        while not text and time.monotonic() < deadline:
            time.sleep(0.01)
    finally:
        # Always restore the real writer, even if the scheduler raised.
        sys.stdout.write = write

    assert text == expected, "Text from stdout is incorrect."

test/test_stdout.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# MIT License
2+
#
3+
# Copyright (c) 2020 Sam McCormack
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy
6+
# of this software and associated documentation files (the "Software"), to deal
7+
# in the Software without restriction, including without limitation the rights
8+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
# copies of the Software, and to permit persons to whom the Software is
10+
# furnished to do so, subject to the following conditions:
11+
#
12+
# The above copyright notice and this permission notice shall be included in all
13+
# copies or substantial portions of the Software.
14+
#
15+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
# SOFTWARE.
22+
from scheduler.Scheduler import Scheduler
23+
from test import impl
24+
25+
26+
def scheduler():
    """Creates a Scheduler with `capture_stdout` enabled."""
    capturing = Scheduler(capture_stdout=True)
    return capturing
28+
29+
30+
def test_stdout():
    """Runs the shared `stdout` test implementation against a fresh Scheduler."""
    instance = scheduler()
    impl.test_stdout(instance)
32+
33+
34+
if __name__ == "__main__":
35+
test_stdout()

test/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,12 @@ def _func_raise_exception(x, y, z) -> None:
118118
raise TestException("Test exception.")
119119

120120

121+
def _func_print(text: str) -> None:
    """
    Prints the integers 0 to 999, one per line, pausing for a millisecond after each.

    :param text: not used by this function
    """
    for number in range(1000):
        print(number)
        time.sleep(0.001)
125+
126+
121127
class TestException(Exception):
122128
"""
123129
Exception raised for testing.

0 commit comments

Comments
 (0)