
Commit 39589c5

perf(backtest): pin BLAS threads and use spawn ctx for parallel windows backtests

- Use a single 'spawn' multiprocessing context so the worker BLAS / OpenMP / Polars thread limits set in the initializer take effect (with 'fork', numpy/polars are already loaded in the parent and the env vars are ignored).
- Pin the OMP/OPENBLAS/MKL/NUMEXPR/VECLIB/POLARS thread pools to 1 per worker in _init_worker to avoid N^2 thread oversubscription, which was causing severe slowdowns on Windows/WSL.
- Replace the Manager().Value progress counter with mp_ctx.Value (shared memory + semaphore) and inherit it via the worker initializer. Manager proxies do an IPC round-trip per read/write through a separate manager process, which made the progress bar appear frozen and serialized worker writes.
- Cap default n_workers at min(cpu_count - 1, 8) to reduce IPC / BLAS contention; running cpu_count workers is consistently slower on Windows/WSL.
- Tighten tqdm refresh (mininterval=0, miniters=1, monitor wait 0.25s) so the progress bar updates promptly on Windows, where stdout is line-buffered.
- Use lock-protected increments on the shared counter for atomicity.
1 parent 59463ed commit 39589c5

1 file changed: 87 additions, 25 deletions

investing_algorithm_framework/infrastructure/services/backtesting/backtest_service.py
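Taken together, the changes amount to one pattern: a spawn-based pool whose workers pin their math-library thread pools in an initializer, a shared-memory counter that workers bump under its lock, and a monitor thread that mirrors the counter into a tqdm bar. The sketch below is a self-contained toy illustrating that pattern, not the framework's code; every name in it (_init, _work, _monitor, the fake batches) is invented for illustration, and it assumes tqdm is installed.

    import multiprocessing
    import os
    import threading
    from concurrent.futures import ProcessPoolExecutor

    from tqdm import tqdm

    _counter = None  # set once per worker by _init


    def _init(counter):
        # Runs once per worker. Under "spawn" the heavy numeric libraries
        # are not loaded yet, so these caps can still take effect.
        for var in (
            "OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS",
            "NUMEXPR_NUM_THREADS", "VECLIB_MAXIMUM_THREADS",
            "POLARS_MAX_THREADS",
        ):
            os.environ.setdefault(var, "1")
        global _counter
        _counter = counter


    def _work(batch):
        # Stand-in for one batch of strategy backtests.
        for _ in batch:
            sum(i * i for i in range(50_000))  # fake per-strategy work
            with _counter.get_lock():          # atomic shared increment
                _counter.value += 1
        return len(batch)


    if __name__ == "__main__":
        ctx = multiprocessing.get_context("spawn")
        counter = ctx.Value("i", 0)  # shared memory, no Manager process
        batches = [range(5)] * 8
        n_workers = min(max(multiprocessing.cpu_count() - 1, 1), 8)

        pbar = tqdm(total=sum(len(b) for b in batches), unit="strategy",
                    mininterval=0, miniters=1)
        stop = threading.Event()

        def _monitor():
            # Mirror the shared counter into the progress bar.
            while not stop.is_set():
                pbar.n = counter.value
                pbar.refresh()
                stop.wait(0.25)

        threading.Thread(target=_monitor, daemon=True).start()

        with ProcessPoolExecutor(
            max_workers=n_workers,
            mp_context=ctx,
            initializer=_init,
            initargs=(counter,),
        ) as ex:
            results = list(ex.map(_work, batches))

        stop.set()
        pbar.n = counter.value
        pbar.refresh()
        pbar.close()
        print("strategies per batch:", results)

The diff below applies the same pattern inside run_vector_backtests, _init_worker and _run_batch_backtest_worker.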
@@ -31,20 +31,42 @@
 
 logger = logging.getLogger(__name__)
 
-# Module-level global used by worker processes. Set via _init_worker
+# Module-level globals used by worker processes. Set via _init_worker
 # which is called once per worker by ProcessPoolExecutor's initializer.
 _worker_data_provider_service = None
+_worker_progress_counter = None
 
 
-def _init_worker(data_provider_service):
+def _init_worker(data_provider_service, progress_counter=None):
     """Initializer for ProcessPoolExecutor workers.
 
-    Stores the data_provider_service in a module-level global so each
-    worker pickles/unpickles it only once at startup rather than per task.
-    This dramatically reduces overhead on Windows/WSL (spawn start method).
+    Stores the data_provider_service and a shared progress counter in
+    module-level globals so each worker inherits them once at startup
+    rather than pickling them per task. This dramatically reduces
+    overhead on Windows/WSL (spawn start method).
+
+    Also pins BLAS / OpenMP / Polars thread pools to a single thread per
+    worker. Without this each worker tries to use all CPU cores for
+    numpy / pandas / polars operations, causing N² thread oversubscription
+    and severe slowdowns on Windows/WSL. These env vars must be set
+    before numpy / polars are imported, which is why ``spawn`` is used
+    as the start method (with ``fork`` they would have no effect because
+    those libraries are already loaded in the parent).
     """
-    global _worker_data_provider_service
+    # Pin math library thread pools to 1 thread per worker.
+    for var in (
+        "OMP_NUM_THREADS",
+        "OPENBLAS_NUM_THREADS",
+        "MKL_NUM_THREADS",
+        "NUMEXPR_NUM_THREADS",
+        "VECLIB_MAXIMUM_THREADS",
+        "POLARS_MAX_THREADS",
+    ):
+        os.environ.setdefault(var, "1")
+
+    global _worker_data_provider_service, _worker_progress_counter
     _worker_data_provider_service = data_provider_service
+    _worker_progress_counter = progress_counter
 
 
 def _print_progress(message: str, show_progress: bool = False):
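A quick way to convince yourself the caps are picked up when they are set before numpy loads is to start a fresh interpreter with and without the environment variables and inspect its thread pools. This throwaway check is not part of the commit and assumes the optional threadpoolctl package is installed:

    import os
    import subprocess
    import sys

    # Hypothetical check script, not framework code.
    CAPS = {
        "OMP_NUM_THREADS": "1",
        "OPENBLAS_NUM_THREADS": "1",
        "MKL_NUM_THREADS": "1",
        "VECLIB_MAXIMUM_THREADS": "1",
    }
    probe = (
        "import numpy; from threadpoolctl import threadpool_info; "
        "print([(p['internal_api'], p['num_threads'])"
        " for p in threadpool_info()])"
    )
    for label, extra in (("default", {}), ("capped", CAPS)):
        env = {**os.environ, **extra}
        print(label, end=": ", flush=True)
        # Fresh interpreter: env vars are in place before numpy imports.
        subprocess.run([sys.executable, "-c", probe], env=env, check=True)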
@@ -921,9 +943,13 @@ def run_vector_backtests(
 
         if use_parallel:
             # Parallel processing of backtests (batches per worker)
-            # Determine number of workers
+            # Determine number of workers. Cap at 8 by default to
+            # avoid BLAS / IPC contention on Windows/WSL where
+            # cpu_count() workers is usually slower than fewer.
             if n_workers == -1:
-                n_workers = multiprocessing.cpu_count()
+                n_workers = min(
+                    max(multiprocessing.cpu_count() - 1, 1), 8
+                )
 
             # Calculate optimal batch size per worker
             # Each worker processes a batch of strategies
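For reference, the new default evaluates as follows on a few machine sizes (purely illustrative arithmetic, not code from the commit):

    # cpu_count = 1  -> min(max(1 - 1, 1), 8)  = 1
    # cpu_count = 4  -> min(max(4 - 1, 1), 8)  = 3
    # cpu_count = 16 -> min(max(16 - 1, 1), 8) = 8
    for cores in (1, 4, 16, 32):
        print(cores, "->", min(max(cores - 1, 1), 8))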
@@ -948,11 +974,27 @@ def run_vector_backtests(
                 show_progress
             )
 
-            # Shared counter for strategy-level progress
-            # across all workers. Use Manager so the proxy
-            # object can be pickled by ProcessPoolExecutor.
-            manager = multiprocessing.Manager()
-            progress_counter = manager.Value('i', 0)
+            # Use a single ``spawn`` context for everything.
+            # On WSL/Linux the default is ``fork`` which copies
+            # the entire parent process into each worker,
+            # bloating workers and preventing the BLAS thread
+            # env vars set in ``_init_worker`` from taking
+            # effect (numpy/polars are already loaded). With
+            # ``spawn`` workers start with a clean interpreter
+            # and those env vars are honoured.
+            mp_ctx = multiprocessing.get_context("spawn")
+
+            # Shared counter for strategy-level progress across
+            # all workers. Use ``mp_ctx.Value`` (shared memory +
+            # semaphore) instead of ``Manager().Value`` which
+            # is a proxy that performs an IPC round-trip for
+            # every read/write through a separate manager
+            # process. Manager proxies are catastrophically
+            # slow on Windows/WSL when many workers update the
+            # same counter, and they also cause the progress
+            # bar to appear frozen because the monitor thread's
+            # reads queue behind worker writes.
+            progress_counter = mp_ctx.Value('i', 0)
 
             # Copy data provider once and pass via initializer
             # so each worker inherits it at startup instead of
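The cost difference the comment describes can be seen with a small timing script. This micro-benchmark is a hypothetical illustration, not part of the commit; absolute numbers vary by machine, but the Manager proxy is typically orders of magnitude slower because every .value access is a round-trip to the manager process:

    import multiprocessing
    import time


    def _bench(counter, n=20_000):
        # Time n increments of the counter from a single process.
        start = time.perf_counter()
        for _ in range(n):
            counter.value += 1
        return time.perf_counter() - start


    if __name__ == "__main__":
        ctx = multiprocessing.get_context("spawn")

        with multiprocessing.Manager() as manager:
            proxy_time = _bench(manager.Value("i", 0))  # IPC per access
        shared_time = _bench(ctx.Value("i", 0))         # shared memory

        print(f"Manager().Value: {proxy_time:.3f}s  "
              f"ctx.Value: {shared_time:.3f}s")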
@@ -974,11 +1016,15 @@ def run_vector_backtests(
                     None,  # placeholder, worker reads global
                     False,
                     dynamic_position_sizing,
-                    progress_counter,
+                    None,  # progress_counter inherited via init
                 ))
 
             # Start a monitoring thread that updates a
-            # strategy-level progress bar in real time
+            # strategy-level progress bar in real time. Use
+            # mininterval=0 / miniters=1 so the bar refreshes
+            # promptly on Windows/WSL where stdout is often
+            # line-buffered and tqdm's default smoothing can
+            # otherwise make the bar appear frozen.
             total_strategies = len(strategies_to_run)
             pbar = tqdm(
                 total=total_strategies,
@@ -987,28 +1033,35 @@ def run_vector_backtests(
                     f"{start_date} to {end_date}",
                 disable=not show_progress,
                 unit="strategy",
+                mininterval=0,
+                miniters=1,
             )
             stop_event = threading.Event()
 
             def _monitor_progress():
                 while not stop_event.is_set():
                     pbar.n = progress_counter.value
                     pbar.refresh()
-                    stop_event.wait(0.5)
+                    stop_event.wait(0.25)
 
             monitor = threading.Thread(
                 target=_monitor_progress, daemon=True
             )
             monitor.start()
 
-            # Execute batches in parallel.
-            # Use initializer to pass data_provider_service
-            # once per worker process rather than pickling it
-            # with every submitted task.
+            # Execute batches in parallel using the spawn pool
+            # created above. The shared ``progress_counter``
+            # (a ``mp_ctx.Value``) and ``shared_data_provider``
+            # are passed through the initializer so they are
+            # inherited once per worker rather than pickled per
+            # task.
             with ProcessPoolExecutor(
                 max_workers=n_workers,
+                mp_context=mp_ctx,
                 initializer=_init_worker,
-                initargs=(shared_data_provider,),
+                initargs=(
+                    shared_data_provider, progress_counter,
+                ),
             ) as ex:
                 # Submit all batch tasks
                 futures = [
@@ -1066,7 +1119,6 @@ def _monitor_progress():
             pbar.n = progress_counter.value
             pbar.refresh()
             pbar.close()
-            manager.shutdown()
 
             # Save remaining batch and create checkpoint files when
             # storage directory provided
@@ -1810,6 +1862,12 @@ def _run_batch_backtest_worker(args):
     if data_provider_service is None:
         data_provider_service = _worker_data_provider_service
 
+    # In parallel mode the progress counter is inherited via the
+    # worker initializer (multiprocessing.Value, shared memory).
+    # Fall back to the per-task argument for backward compatibility.
+    if progress_counter is None:
+        progress_counter = _worker_progress_counter
+
     vector_backtest_service = VectorBacktestService(
         data_provider_service=data_provider_service
     )
@@ -1852,9 +1910,12 @@ def _run_batch_backtest_worker(args):
             batch_results.append(backtest)
 
             # Increment shared progress counter so the
-            # main process can track per-strategy progress
+            # main process can track per-strategy progress.
+            # ``multiprocessing.Value`` lives in shared memory; the
+            # lock makes ``+= 1`` atomic across workers.
             if progress_counter is not None:
-                progress_counter.value += 1
+                with progress_counter.get_lock():
+                    progress_counter.value += 1
 
         except Exception as e:
             if continue_on_error:
@@ -1865,7 +1926,8 @@ def _run_batch_backtest_worker(args):
                 # Still increment counter for failed strategies
                 # so progress total stays accurate
                 if progress_counter is not None:
-                    progress_counter.value += 1
+                    with progress_counter.get_lock():
+                        progress_counter.value += 1
                 continue
             else:
                 raise
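Why the lock matters: "value += 1" on a shared Value is a read-modify-write, so concurrent workers can lose increments even though each individual read and write is synchronized. A hypothetical standalone demo (not framework code) that shows the lost updates and the fix:

    import multiprocessing


    def _bump(counter, use_lock, n=50_000):
        # Increment the shared counter n times, with or without the lock.
        for _ in range(n):
            if use_lock:
                with counter.get_lock():
                    counter.value += 1
            else:
                counter.value += 1  # racy: increments can be lost


    if __name__ == "__main__":
        ctx = multiprocessing.get_context("spawn")
        for use_lock in (False, True):
            counter = ctx.Value("i", 0)
            procs = [
                ctx.Process(target=_bump, args=(counter, use_lock))
                for _ in range(4)
            ]
            for p in procs:
                p.start()
            for p in procs:
                p.join()
            print(f"lock={use_lock}: expected 200000, got {counter.value}")

With the lock held around the whole increment the final count matches the expected total, which is what keeps the progress bar total accurate for both successful and failed strategies.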
