Skip to content

Commit 0d3f37c

Browse files
authored
Merge pull request #481 from coding-kitties/dev
feat: Jupyter Notebook / IPython Magic Integration (#454)
2 parents 81f14ae + 43a2906 commit 0d3f37c

9 files changed

Lines changed: 707 additions & 43 deletions

File tree

investing_algorithm_framework/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,4 +254,13 @@
254254
"VolumeBasedFill",
255255
"FXRateProvider",
256256
"StaticFXRateProvider",
257+
"load_ipython_extension",
257258
]
259+
260+
261+
def load_ipython_extension(ipython):
262+
"""Allow ``%load_ext investing_algorithm_framework`` to register
263+
the ``%backtest`` / ``%%backtest`` magic commands."""
264+
from investing_algorithm_framework.notebook import load_ipython_extension \
265+
as _load
266+
_load(ipython)

investing_algorithm_framework/app/app.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,11 @@ def initialize_config(self):
279279

280280
if SQLALCHEMY_DATABASE_URI not in config \
281281
or config[SQLALCHEMY_DATABASE_URI] is None:
282-
path = "sqlite:///" + os.path.join(
282+
db_path = os.path.join(
283283
configuration_service.config[DATABASE_DIRECTORY_PATH],
284284
configuration_service.config[DATABASE_NAME]
285285
)
286+
path = "sqlite:///" + db_path.replace("\\", "/")
286287
configuration_service.add_value(SQLALCHEMY_DATABASE_URI, path)
287288

288289
def initialize_backtest_config(
@@ -394,7 +395,7 @@ def initialize_storage(self, remove_database_if_exists: bool = False):
394395
os.remove(database_path)
395396

396397
# Create the sqlalchemy database uri
397-
path = f"sqlite:///{database_path}"
398+
path = "sqlite:///" + database_path.replace("\\", "/")
398399
self.set_config(SQLALCHEMY_DATABASE_URI, path)
399400

400401
# Setup sql if needed

investing_algorithm_framework/app/reporting/backtest_report.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import os
22
import csv
33
import base64
4+
import tempfile
45
import webbrowser
56
import logging
67
from dataclasses import dataclass, field
8+
from pathlib import Path
79
from typing import List, Union
810
from datetime import datetime, timedelta
911

@@ -185,12 +187,14 @@ def show(self, backtest_date_range=None, browser=False):
185187
if not self.html_report:
186188
self.html_report = self._build_html()
187189

188-
path = "/tmp/backtest_report.html"
190+
path = os.path.join(tempfile.gettempdir(), "backtest_report.html")
189191
with open(path, "w", encoding="utf-8") as f:
190192
f.write(self.html_report)
191193

194+
file_uri = Path(path).as_uri()
195+
192196
if browser:
193-
webbrowser.open(f"file://{path}")
197+
webbrowser.open(file_uri)
194198
return
195199

196200
try:
@@ -203,7 +207,7 @@ def show(self, backtest_date_range=None, browser=False):
203207
except (NameError, ImportError):
204208
pass
205209

206-
webbrowser.open(f"file://{path}")
210+
webbrowser.open(file_uri)
207211

208212
def save(self, path):
209213
if not self.html_report:

investing_algorithm_framework/app/reporting/backtest_report_old.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import os
3+
import tempfile
34
from pathlib import Path
45
import webbrowser
56
from dataclasses import dataclass
@@ -117,12 +118,14 @@ def show(
117118
self._create_html_report(backtest_date_range)
118119

119120
# Save the html report to a tmp location
120-
path = "/tmp/backtest_report.html"
121+
path = os.path.join(tempfile.gettempdir(), "backtest_report.html")
121122
with open(path, "w") as html_file:
122123
html_file.write(self.html_report)
123124

125+
file_uri = Path(path).as_uri()
126+
124127
if browser:
125-
webbrowser.open(f"file://{path}")
128+
webbrowser.open(file_uri)
126129

127130
def in_jupyter_notebook():
128131
try:
@@ -134,7 +137,7 @@ def in_jupyter_notebook():
134137
if in_jupyter_notebook():
135138
display(HTML(self.html_report))
136139
else:
137-
webbrowser.open(f"file://{path}")
140+
webbrowser.open(file_uri)
138141

139142
def _create_html_report(self, backtest_date_range: BacktestDateRange):
140143
"""

investing_algorithm_framework/infrastructure/services/backtesting/backtest_service.py

Lines changed: 127 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
import gc
22
import json
33
import logging
4+
import multiprocessing
45
import os
5-
import numpy as np
6-
import pandas as pd
7-
import polars as pl
6+
import threading
87
from collections import defaultdict
8+
from concurrent.futures import ProcessPoolExecutor, as_completed
99
from datetime import datetime, timedelta, timezone
1010
from pathlib import Path
11-
from typing import Dict, List, Union, Optional, Callable
11+
from typing import Callable, Dict, List, Optional, Union
12+
13+
import numpy as np
14+
import pandas as pd
15+
import polars as pl
1216

1317
from investing_algorithm_framework.domain import BacktestRun, TimeUnit, \
1418
OperationalException, BacktestDateRange, Backtest, combine_backtests, \
@@ -27,6 +31,21 @@
2731

2832
logger = logging.getLogger(__name__)
2933

34+
# Module-level global used by worker processes. Set via _init_worker
35+
# which is called once per worker by ProcessPoolExecutor's initializer.
36+
_worker_data_provider_service = None
37+
38+
39+
def _init_worker(data_provider_service):
40+
"""Initializer for ProcessPoolExecutor workers.
41+
42+
Stores the data_provider_service in a module-level global so each
43+
worker pickles/unpickles it only once at startup rather than per task.
44+
This dramatically reduces overhead on Windows/WSL (spawn start method).
45+
"""
46+
global _worker_data_provider_service
47+
_worker_data_provider_service = data_provider_service
48+
3049

3150
def _print_progress(message: str, show_progress: bool = False):
3251
"""
@@ -902,10 +921,6 @@ def run_vector_backtests(
902921

903922
if use_parallel:
904923
# Parallel processing of backtests (batches per worker)
905-
import multiprocessing
906-
from concurrent.futures import \
907-
ProcessPoolExecutor, as_completed
908-
909924
# Determine number of workers
910925
if n_workers == -1:
911926
n_workers = multiprocessing.cpu_count()
@@ -933,6 +948,19 @@ def run_vector_backtests(
933948
show_progress
934949
)
935950

951+
# Shared counter for strategy-level progress
952+
# across all workers. Use Manager so the proxy
953+
# object can be pickled by ProcessPoolExecutor.
954+
manager = multiprocessing.Manager()
955+
progress_counter = manager.Value('i', 0)
956+
957+
# Copy data provider once and pass via initializer
958+
# so each worker inherits it at startup instead of
959+
# pickling it per task (major speedup on Windows/WSL
960+
# where spawn is used instead of fork).
961+
shared_data_provider = \
962+
self._data_provider_service.copy()
963+
936964
worker_args = []
937965

938966
for batch in strategy_batches:
@@ -943,13 +971,45 @@ def run_vector_backtests(
943971
snapshot_interval,
944972
risk_free_rate,
945973
continue_on_error,
946-
self._data_provider_service.copy(),
974+
None, # placeholder, worker reads global
947975
False,
948-
dynamic_position_sizing
976+
dynamic_position_sizing,
977+
progress_counter,
949978
))
950979

951-
# Execute batches in parallel
952-
with ProcessPoolExecutor(max_workers=n_workers) as ex:
980+
# Start a monitoring thread that updates a
981+
# strategy-level progress bar in real time
982+
total_strategies = len(strategies_to_run)
983+
pbar = tqdm(
984+
total=total_strategies,
985+
colour="green",
986+
desc="Running backtests for "
987+
f"{start_date} to {end_date}",
988+
disable=not show_progress,
989+
unit="strategy",
990+
)
991+
stop_event = threading.Event()
992+
993+
def _monitor_progress():
994+
while not stop_event.is_set():
995+
pbar.n = progress_counter.value
996+
pbar.refresh()
997+
stop_event.wait(0.5)
998+
999+
monitor = threading.Thread(
1000+
target=_monitor_progress, daemon=True
1001+
)
1002+
monitor.start()
1003+
1004+
# Execute batches in parallel.
1005+
# Use initializer to pass data_provider_service
1006+
# once per worker process rather than pickling it
1007+
# with every submitted task.
1008+
with ProcessPoolExecutor(
1009+
max_workers=n_workers,
1010+
initializer=_init_worker,
1011+
initargs=(shared_data_provider,),
1012+
) as ex:
9531013
# Submit all batch tasks
9541014
futures = [
9551015
ex.submit(
@@ -961,15 +1021,8 @@ def run_vector_backtests(
9611021
# Track completed batches for periodic cleanup
9621022
completed_count = 0
9631023

964-
# Collect results with progress bar
965-
for future in tqdm(
966-
as_completed(futures),
967-
total=len(futures),
968-
colour="green",
969-
desc="Running backtests for "
970-
f"{start_date} to {end_date}",
971-
disable=not show_progress
972-
):
1024+
# Collect results as batches complete
1025+
for future in as_completed(futures):
9731026
try:
9741027
batch_result = future.result()
9751028
if batch_result:
@@ -1006,6 +1059,15 @@ def run_vector_backtests(
10061059
else:
10071060
raise
10081061

1062+
# Stop the monitoring thread and finalise
1063+
# the progress bar
1064+
stop_event.set()
1065+
monitor.join()
1066+
pbar.n = progress_counter.value
1067+
pbar.refresh()
1068+
pbar.close()
1069+
manager.shutdown()
1070+
10091071
# Save remaining batch and create checkpoint files when
10101072
# storage directory provided
10111073
if backtest_storage_directory is not None:
@@ -1309,8 +1371,6 @@ def run_vector_backtests(
13091371
combined_backtests.append(backtests_list[0])
13101372
else:
13111373
# Combine multiple backtests for the same algorithm
1312-
from investing_algorithm_framework.domain import (
1313-
combine_backtests)
13141374
combined = combine_backtests(backtests_list)
13151375
combined_backtests.append(combined)
13161376

@@ -1709,23 +1769,46 @@ def _run_batch_backtest_worker(args):
17091769
continue_on_error,
17101770
data_provider_service,
17111771
show_progress,
1712-
dynamic_position_sizing
1772+
dynamic_position_sizing,
1773+
progress_counter (optional),
17131774
)
17141775
17151776
Returns:
17161777
List[Backtest]: List of completed backtest results
17171778
"""
1718-
(
1719-
strategy_batch,
1720-
backtest_date_range,
1721-
portfolio_configuration,
1722-
snapshot_interval,
1723-
risk_free_rate,
1724-
continue_on_error,
1725-
data_provider_service,
1726-
show_progress,
1727-
dynamic_position_sizing
1728-
) = args
1779+
# Support both the old (9-element) and the new (10-element) argument-tuple layouts
1780+
if len(args) == 10:
1781+
(
1782+
strategy_batch,
1783+
backtest_date_range,
1784+
portfolio_configuration,
1785+
snapshot_interval,
1786+
risk_free_rate,
1787+
continue_on_error,
1788+
data_provider_service,
1789+
show_progress,
1790+
dynamic_position_sizing,
1791+
progress_counter,
1792+
) = args
1793+
else:
1794+
(
1795+
strategy_batch,
1796+
backtest_date_range,
1797+
portfolio_configuration,
1798+
snapshot_interval,
1799+
risk_free_rate,
1800+
continue_on_error,
1801+
data_provider_service,
1802+
show_progress,
1803+
dynamic_position_sizing,
1804+
) = args
1805+
progress_counter = None
1806+
1807+
# Use the worker-global data provider if none was passed
1808+
# directly (parallel mode passes None and relies on the
1809+
# initializer to set the global once per worker process).
1810+
if data_provider_service is None:
1811+
data_provider_service = _worker_data_provider_service
17291812

17301813
vector_backtest_service = VectorBacktestService(
17311814
data_provider_service=data_provider_service
@@ -1768,12 +1851,21 @@ def _run_batch_backtest_worker(args):
17681851
)
17691852
batch_results.append(backtest)
17701853

1854+
# Increment shared progress counter so the
1855+
# main process can track per-strategy progress
1856+
if progress_counter is not None:
1857+
progress_counter.value += 1
1858+
17711859
except Exception as e:
17721860
if continue_on_error:
17731861
logger.error(
17741862
"Worker error for strategy "
17751863
f"{strategy.algorithm_id}: {e}"
17761864
)
1865+
# Still increment counter for failed strategies
1866+
# so progress total stays accurate
1867+
if progress_counter is not None:
1868+
progress_counter.value += 1
17771869
continue
17781870
else:
17791871
raise

investing_algorithm_framework/notebook/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .magic import load_ipython_extension
2+
3+
__all__ = ["load_ipython_extension"]

0 commit comments

Comments
 (0)