Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions investing_algorithm_framework/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,4 +254,13 @@
"VolumeBasedFill",
"FXRateProvider",
"StaticFXRateProvider",
"load_ipython_extension",
]


def load_ipython_extension(ipython):
    """Entry point used by ``%load_ext investing_algorithm_framework``.

    Delegates to the :mod:`investing_algorithm_framework.notebook`
    sub-package, which registers the ``%backtest`` / ``%%backtest``
    magic commands with the given IPython shell.
    """
    from investing_algorithm_framework.notebook import (
        load_ipython_extension as _register,
    )
    _register(ipython)
5 changes: 3 additions & 2 deletions investing_algorithm_framework/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,11 @@ def initialize_config(self):

if SQLALCHEMY_DATABASE_URI not in config \
or config[SQLALCHEMY_DATABASE_URI] is None:
path = "sqlite:///" + os.path.join(
db_path = os.path.join(
configuration_service.config[DATABASE_DIRECTORY_PATH],
configuration_service.config[DATABASE_NAME]
)
path = "sqlite:///" + db_path.replace("\\", "/")
configuration_service.add_value(SQLALCHEMY_DATABASE_URI, path)

def initialize_backtest_config(
Expand Down Expand Up @@ -394,7 +395,7 @@ def initialize_storage(self, remove_database_if_exists: bool = False):
os.remove(database_path)

# Create the sqlalchemy database uri
path = f"sqlite:///{database_path}"
path = "sqlite:///" + database_path.replace("\\", "/")
self.set_config(SQLALCHEMY_DATABASE_URI, path)

# Setup sql if needed
Expand Down
10 changes: 7 additions & 3 deletions investing_algorithm_framework/app/reporting/backtest_report.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import csv
import base64
import tempfile
import webbrowser
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Union
from datetime import datetime, timedelta

Expand Down Expand Up @@ -185,12 +187,14 @@ def show(self, backtest_date_range=None, browser=False):
if not self.html_report:
self.html_report = self._build_html()

path = "/tmp/backtest_report.html"
path = os.path.join(tempfile.gettempdir(), "backtest_report.html")
with open(path, "w", encoding="utf-8") as f:
f.write(self.html_report)

file_uri = Path(path).as_uri()

if browser:
webbrowser.open(f"file://{path}")
webbrowser.open(file_uri)
return

try:
Expand All @@ -203,7 +207,7 @@ def show(self, backtest_date_range=None, browser=False):
except (NameError, ImportError):
pass

webbrowser.open(f"file://{path}")
webbrowser.open(file_uri)

def save(self, path):
if not self.html_report:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import tempfile
from pathlib import Path
import webbrowser
from dataclasses import dataclass
Expand Down Expand Up @@ -117,12 +118,14 @@ def show(
self._create_html_report(backtest_date_range)

# Save the html report to a tmp location
path = "/tmp/backtest_report.html"
path = os.path.join(tempfile.gettempdir(), "backtest_report.html")
with open(path, "w") as html_file:
html_file.write(self.html_report)

file_uri = Path(path).as_uri()

if browser:
webbrowser.open(f"file://{path}")
webbrowser.open(file_uri)

def in_jupyter_notebook():
try:
Expand All @@ -134,7 +137,7 @@ def in_jupyter_notebook():
if in_jupyter_notebook():
display(HTML(self.html_report))
else:
webbrowser.open(f"file://{path}")
webbrowser.open(file_uri)

def _create_html_report(self, backtest_date_range: BacktestDateRange):
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import gc
import json
import logging
import multiprocessing
import os
import numpy as np
import pandas as pd
import polars as pl
import threading
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Dict, List, Union, Optional, Callable
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import pandas as pd
import polars as pl

from investing_algorithm_framework.domain import BacktestRun, TimeUnit, \
OperationalException, BacktestDateRange, Backtest, combine_backtests, \
Expand All @@ -27,6 +31,21 @@

logger = logging.getLogger(__name__)

# Module-level global used by worker processes. Set via _init_worker
# which is called once per worker by ProcessPoolExecutor's initializer.
_worker_data_provider_service = None


def _init_worker(data_provider_service):
    """Per-process initializer handed to ``ProcessPoolExecutor``.

    Caches *data_provider_service* in the module-level
    ``_worker_data_provider_service`` global. The executor invokes this
    exactly once when a worker process starts, so the service is
    pickled and transferred a single time per process instead of once
    per submitted task — a major saving where the ``spawn`` start
    method is in use (Windows/WSL).
    """
    global _worker_data_provider_service
    _worker_data_provider_service = data_provider_service


def _print_progress(message: str, show_progress: bool = False):
"""
Expand Down Expand Up @@ -902,10 +921,6 @@ def run_vector_backtests(

if use_parallel:
# Parallel processing of backtests (batches per worker)
import multiprocessing
from concurrent.futures import \
ProcessPoolExecutor, as_completed

# Determine number of workers
if n_workers == -1:
n_workers = multiprocessing.cpu_count()
Expand Down Expand Up @@ -933,6 +948,19 @@ def run_vector_backtests(
show_progress
)

# Shared counter for strategy-level progress
# across all workers. Use Manager so the proxy
# object can be pickled by ProcessPoolExecutor.
manager = multiprocessing.Manager()
progress_counter = manager.Value('i', 0)

# Copy data provider once and pass via initializer
# so each worker inherits it at startup instead of
# pickling it per task (major speedup on Windows/WSL
# where spawn is used instead of fork).
shared_data_provider = \
self._data_provider_service.copy()

worker_args = []

for batch in strategy_batches:
Expand All @@ -943,13 +971,45 @@ def run_vector_backtests(
snapshot_interval,
risk_free_rate,
continue_on_error,
self._data_provider_service.copy(),
None, # placeholder, worker reads global
False,
dynamic_position_sizing
dynamic_position_sizing,
progress_counter,
))

# Execute batches in parallel
with ProcessPoolExecutor(max_workers=n_workers) as ex:
# Start a monitoring thread that updates a
# strategy-level progress bar in real time
total_strategies = len(strategies_to_run)
pbar = tqdm(
total=total_strategies,
colour="green",
desc="Running backtests for "
f"{start_date} to {end_date}",
disable=not show_progress,
unit="strategy",
)
stop_event = threading.Event()

def _monitor_progress():
while not stop_event.is_set():
pbar.n = progress_counter.value
pbar.refresh()
stop_event.wait(0.5)

monitor = threading.Thread(
target=_monitor_progress, daemon=True
)
monitor.start()

# Execute batches in parallel.
# Use initializer to pass data_provider_service
# once per worker process rather than pickling it
# with every submitted task.
with ProcessPoolExecutor(
max_workers=n_workers,
initializer=_init_worker,
initargs=(shared_data_provider,),
) as ex:
# Submit all batch tasks
futures = [
ex.submit(
Expand All @@ -961,15 +1021,8 @@ def run_vector_backtests(
# Track completed batches for periodic cleanup
completed_count = 0

# Collect results with progress bar
for future in tqdm(
as_completed(futures),
total=len(futures),
colour="green",
desc="Running backtests for "
f"{start_date} to {end_date}",
disable=not show_progress
):
# Collect results as batches complete
for future in as_completed(futures):
try:
batch_result = future.result()
if batch_result:
Expand Down Expand Up @@ -1006,6 +1059,15 @@ def run_vector_backtests(
else:
raise

# Stop the monitoring thread and finalise
# the progress bar
stop_event.set()
monitor.join()
pbar.n = progress_counter.value
pbar.refresh()
pbar.close()
manager.shutdown()

# Save remaining batch and create checkpoint files when
# storage directory provided
if backtest_storage_directory is not None:
Expand Down Expand Up @@ -1309,8 +1371,6 @@ def run_vector_backtests(
combined_backtests.append(backtests_list[0])
else:
# Combine multiple backtests for the same algorithm
from investing_algorithm_framework.domain import (
combine_backtests)
combined = combine_backtests(backtests_list)
combined_backtests.append(combined)

Expand Down Expand Up @@ -1709,23 +1769,46 @@ def _run_batch_backtest_worker(args):
continue_on_error,
data_provider_service,
show_progress,
dynamic_position_sizing
dynamic_position_sizing,
progress_counter (optional),
)

Returns:
List[Backtest]: List of completed backtest results
"""
(
strategy_batch,
backtest_date_range,
portfolio_configuration,
snapshot_interval,
risk_free_rate,
continue_on_error,
data_provider_service,
show_progress,
dynamic_position_sizing
) = args
# Support both old (9-element) and new (10-element) tuple
if len(args) == 10:
(
strategy_batch,
backtest_date_range,
portfolio_configuration,
snapshot_interval,
risk_free_rate,
continue_on_error,
data_provider_service,
show_progress,
dynamic_position_sizing,
progress_counter,
) = args
else:
(
strategy_batch,
backtest_date_range,
portfolio_configuration,
snapshot_interval,
risk_free_rate,
continue_on_error,
data_provider_service,
show_progress,
dynamic_position_sizing,
) = args
progress_counter = None

# Use the worker-global data provider if none was passed
# directly (parallel mode passes None and relies on the
# initializer to set the global once per worker process).
if data_provider_service is None:
data_provider_service = _worker_data_provider_service

vector_backtest_service = VectorBacktestService(
data_provider_service=data_provider_service
Expand Down Expand Up @@ -1768,12 +1851,21 @@ def _run_batch_backtest_worker(args):
)
batch_results.append(backtest)

# Increment shared progress counter so the
# main process can track per-strategy progress
if progress_counter is not None:
progress_counter.value += 1

except Exception as e:
if continue_on_error:
logger.error(
"Worker error for strategy "
f"{strategy.algorithm_id}: {e}"
)
# Still increment counter for failed strategies
# so progress total stays accurate
if progress_counter is not None:
progress_counter.value += 1
continue
else:
raise
Expand Down
3 changes: 3 additions & 0 deletions investing_algorithm_framework/notebook/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .magic import load_ipython_extension

__all__ = ["load_ipython_extension"]
Loading
Loading