diff --git a/investing_algorithm_framework/__init__.py b/investing_algorithm_framework/__init__.py index eeb7e0ae..719d12aa 100644 --- a/investing_algorithm_framework/__init__.py +++ b/investing_algorithm_framework/__init__.py @@ -254,4 +254,13 @@ "VolumeBasedFill", "FXRateProvider", "StaticFXRateProvider", + "load_ipython_extension", ] + + +def load_ipython_extension(ipython): + """Allow ``%load_ext investing_algorithm_framework`` to register + the ``%backtest`` / ``%%backtest`` magic commands.""" + from investing_algorithm_framework.notebook import load_ipython_extension \ + as _load + _load(ipython) diff --git a/investing_algorithm_framework/app/app.py b/investing_algorithm_framework/app/app.py index e8b9749f..f18c06f6 100644 --- a/investing_algorithm_framework/app/app.py +++ b/investing_algorithm_framework/app/app.py @@ -279,10 +279,11 @@ def initialize_config(self): if SQLALCHEMY_DATABASE_URI not in config \ or config[SQLALCHEMY_DATABASE_URI] is None: - path = "sqlite:///" + os.path.join( + db_path = os.path.join( configuration_service.config[DATABASE_DIRECTORY_PATH], configuration_service.config[DATABASE_NAME] ) + path = "sqlite:///" + db_path.replace("\\", "/") configuration_service.add_value(SQLALCHEMY_DATABASE_URI, path) def initialize_backtest_config( @@ -394,7 +395,7 @@ def initialize_storage(self, remove_database_if_exists: bool = False): os.remove(database_path) # Create the sqlalchemy database uri - path = f"sqlite:///{database_path}" + path = "sqlite:///" + database_path.replace("\\", "/") self.set_config(SQLALCHEMY_DATABASE_URI, path) # Setup sql if needed diff --git a/investing_algorithm_framework/app/reporting/backtest_report.py b/investing_algorithm_framework/app/reporting/backtest_report.py index b1c2bbcc..2daf83f2 100644 --- a/investing_algorithm_framework/app/reporting/backtest_report.py +++ b/investing_algorithm_framework/app/reporting/backtest_report.py @@ -1,9 +1,11 @@ import os import csv import base64 +import tempfile import webbrowser import logging from dataclasses import dataclass, field +from pathlib import Path from typing import List, Union from datetime import datetime, timedelta @@ -185,12 +187,14 @@ def show(self, backtest_date_range=None, browser=False): if not self.html_report: self.html_report = self._build_html() - path = "/tmp/backtest_report.html" + path = os.path.join(tempfile.gettempdir(), "backtest_report.html") with open(path, "w", encoding="utf-8") as f: f.write(self.html_report) + file_uri = Path(path).as_uri() + if browser: - webbrowser.open(f"file://{path}") + webbrowser.open(file_uri) return try: @@ -203,7 +207,7 @@ def show(self, backtest_date_range=None, browser=False): except (NameError, ImportError): pass - webbrowser.open(f"file://{path}") + webbrowser.open(file_uri) def save(self, path): if not self.html_report: diff --git a/investing_algorithm_framework/app/reporting/backtest_report_old.py b/investing_algorithm_framework/app/reporting/backtest_report_old.py index 4f0b8961..ea6335af 100644 --- a/investing_algorithm_framework/app/reporting/backtest_report_old.py +++ b/investing_algorithm_framework/app/reporting/backtest_report_old.py @@ -1,5 +1,6 @@ import logging import os +import tempfile from pathlib import Path import webbrowser from dataclasses import dataclass @@ -117,12 +118,14 @@ def show( self._create_html_report(backtest_date_range) # Save the html report to a tmp location - path = "/tmp/backtest_report.html" + path = os.path.join(tempfile.gettempdir(), "backtest_report.html") with open(path, "w") as html_file: html_file.write(self.html_report) + file_uri = Path(path).as_uri() + if browser: - webbrowser.open(f"file://{path}") + webbrowser.open(file_uri) def in_jupyter_notebook(): try: @@ -134,7 +137,7 @@ def in_jupyter_notebook(): if in_jupyter_notebook(): display(HTML(self.html_report)) else: - webbrowser.open(f"file://{path}") + webbrowser.open(file_uri) def _create_html_report(self, backtest_date_range: BacktestDateRange): """ diff --git a/investing_algorithm_framework/infrastructure/services/backtesting/backtest_service.py b/investing_algorithm_framework/infrastructure/services/backtesting/backtest_service.py index c326b119..ad669e4e 100644 --- a/investing_algorithm_framework/infrastructure/services/backtesting/backtest_service.py +++ b/investing_algorithm_framework/infrastructure/services/backtesting/backtest_service.py @@ -1,14 +1,18 @@ import gc import json import logging +import multiprocessing import os -import numpy as np -import pandas as pd -import polars as pl +import threading from collections import defaultdict +from concurrent.futures import ProcessPoolExecutor, as_completed from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import Dict, List, Union, Optional, Callable +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import pandas as pd +import polars as pl from investing_algorithm_framework.domain import BacktestRun, TimeUnit, \ OperationalException, BacktestDateRange, Backtest, combine_backtests, \ @@ -27,6 +31,21 @@ logger = logging.getLogger(__name__) +# Module-level global used by worker processes. Set via _init_worker +# which is called once per worker by ProcessPoolExecutor's initializer. +_worker_data_provider_service = None + + +def _init_worker(data_provider_service): + """Initializer for ProcessPoolExecutor workers. + + Stores the data_provider_service in a module-level global so each + worker pickles/unpickles it only once at startup rather than per task. + This dramatically reduces overhead on Windows/WSL (spawn start method). + """ + global _worker_data_provider_service + _worker_data_provider_service = data_provider_service + def _print_progress(message: str, show_progress: bool = False): """ @@ -902,10 +921,6 @@ def run_vector_backtests( if use_parallel: # Parallel processing of backtests (batches per worker) - import multiprocessing - from concurrent.futures import \ - ProcessPoolExecutor, as_completed - # Determine number of workers if n_workers == -1: n_workers = multiprocessing.cpu_count() @@ -933,6 +948,19 @@ def run_vector_backtests( show_progress ) + # Shared counter for strategy-level progress + # across all workers. Use Manager so the proxy + # object can be pickled by ProcessPoolExecutor. + manager = multiprocessing.Manager() + progress_counter = manager.Value('i', 0) + + # Copy data provider once and pass via initializer + # so each worker inherits it at startup instead of + # pickling it per task (major speedup on Windows/WSL + # where spawn is used instead of fork). + shared_data_provider = \ + self._data_provider_service.copy() + worker_args = [] for batch in strategy_batches: @@ -943,13 +971,45 @@ def run_vector_backtests( snapshot_interval, risk_free_rate, continue_on_error, - self._data_provider_service.copy(), + None, # placeholder, worker reads global False, - dynamic_position_sizing + dynamic_position_sizing, + progress_counter, )) - # Execute batches in parallel - with ProcessPoolExecutor(max_workers=n_workers) as ex: + # Start a monitoring thread that updates a + # strategy-level progress bar in real time + total_strategies = len(strategies_to_run) + pbar = tqdm( + total=total_strategies, + colour="green", + desc="Running backtests for " + f"{start_date} to {end_date}", + disable=not show_progress, + unit="strategy", + ) + stop_event = threading.Event() + + def _monitor_progress(): + while not stop_event.is_set(): + pbar.n = progress_counter.value + pbar.refresh() + stop_event.wait(0.5) + + monitor = threading.Thread( + target=_monitor_progress, daemon=True + ) + monitor.start() + + # Execute batches in parallel. + # Use initializer to pass data_provider_service + # once per worker process rather than pickling it + # with every submitted task. + with ProcessPoolExecutor( + max_workers=n_workers, + initializer=_init_worker, + initargs=(shared_data_provider,), + ) as ex: # Submit all batch tasks futures = [ ex.submit( @@ -961,15 +1021,8 @@ def run_vector_backtests( # Track completed batches for periodic cleanup completed_count = 0 - # Collect results with progress bar - for future in tqdm( - as_completed(futures), - total=len(futures), - colour="green", - desc="Running backtests for " - f"{start_date} to {end_date}", - disable=not show_progress - ): + # Collect results as batches complete + for future in as_completed(futures): try: batch_result = future.result() if batch_result: @@ -1006,6 +1059,15 @@ def run_vector_backtests( else: raise + # Stop the monitoring thread and finalise + # the progress bar + stop_event.set() + monitor.join() + pbar.n = progress_counter.value + pbar.refresh() + pbar.close() + manager.shutdown() + # Save remaining batch and create checkpoint files when # storage directory provided if backtest_storage_directory is not None: @@ -1309,8 +1371,6 @@ def run_vector_backtests( combined_backtests.append(backtests_list[0]) else: # Combine multiple backtests for the same algorithm - from investing_algorithm_framework.domain import ( - combine_backtests) combined = combine_backtests(backtests_list) combined_backtests.append(combined) @@ -1709,23 +1769,46 @@ def _run_batch_backtest_worker(args): continue_on_error, data_provider_service, show_progress, - dynamic_position_sizing + dynamic_position_sizing, + progress_counter (optional), ) Returns: List[Backtest]: List of completed backtest results """ - ( - strategy_batch, - backtest_date_range, - portfolio_configuration, - snapshot_interval, - risk_free_rate, - continue_on_error, - data_provider_service, - show_progress, - dynamic_position_sizing - ) = args + # Support both old (9-element) and new (10-element) tuple + if len(args) == 10: + ( + strategy_batch, + backtest_date_range, + portfolio_configuration, + snapshot_interval, + risk_free_rate, + continue_on_error, + data_provider_service, + show_progress, + dynamic_position_sizing, + progress_counter, + ) = args + else: + ( + strategy_batch, + backtest_date_range, + portfolio_configuration, + snapshot_interval, + risk_free_rate, + continue_on_error, + data_provider_service, + show_progress, + dynamic_position_sizing, + ) = args + progress_counter = None + + # Use the worker-global data provider if none was passed + # directly (parallel mode passes None and relies on the + # initializer to set the global once per worker process). + if data_provider_service is None: + data_provider_service = _worker_data_provider_service vector_backtest_service = VectorBacktestService( data_provider_service=data_provider_service @@ -1768,12 +1851,21 @@ def _run_batch_backtest_worker(args): ) batch_results.append(backtest) + # Increment shared progress counter so the + # main process can track per-strategy progress + if progress_counter is not None: + progress_counter.value += 1 + except Exception as e: if continue_on_error: logger.error( "Worker error for strategy " f"{strategy.algorithm_id}: {e}" ) + # Still increment counter for failed strategies + # so progress total stays accurate + if progress_counter is not None: + progress_counter.value += 1 continue else: raise diff --git a/investing_algorithm_framework/notebook/__init__.py b/investing_algorithm_framework/notebook/__init__.py new file mode 100644 index 00000000..d5cef44f --- /dev/null +++ b/investing_algorithm_framework/notebook/__init__.py @@ -0,0 +1,3 @@ +from .magic import load_ipython_extension + +__all__ = ["load_ipython_extension"] diff --git a/investing_algorithm_framework/notebook/magic.py b/investing_algorithm_framework/notebook/magic.py new file mode 100644 index 00000000..7f51f1b3 --- /dev/null +++ b/investing_algorithm_framework/notebook/magic.py @@ -0,0 +1,401 @@ +""" +IPython magic commands for running backtests directly in Jupyter notebooks. + +Usage: + # Load the extension + %load_ext investing_algorithm_framework + + # Cell magic — define and run a backtest inline + %%backtest --start 2023-01-01 --end 2023-12-31 \ + --initial-amount 10000 -o results + from investing_algorithm_framework import TradingStrategy, DataSource + + class MyStrategy(TradingStrategy): + time_unit = "DAY" + interval = 1 + data_sources = [ + DataSource( + identifier="btc", + symbol="BTC/EUR", + data_type="OHLCV", + time_frame="1d", + ) + ] + + def run_strategy(self, context, data): + ... + + # Line magic — run a strategy from an existing file + %backtest strategies/my_strategy.py \ + --start 2023-01-01 --end 2023-12-31 -o results +""" + +import argparse +import os +import shlex +import sys +from datetime import datetime, timezone +from pathlib import Path + +from IPython.core.magic import ( + Magics, + magics_class, + cell_magic, + line_magic, +) + + +def _build_parser(): + """Build the argument parser for the backtest magic.""" + parser = argparse.ArgumentParser( + prog="%%backtest", + description="Run a backtest from a Jupyter notebook cell.", + add_help=False, + ) + parser.add_argument( + "strategy_path", + nargs="?", + default=None, + help="Path to a .py file containing a TradingStrategy subclass " + "(line magic only).", + ) + parser.add_argument( + "--start", + required=True, + help="Backtest start date (YYYY-MM-DD or YYYY-MM-DD-HH).", + ) + parser.add_argument( + "--end", + default=None, + help="Backtest end date (YYYY-MM-DD or YYYY-MM-DD-HH). " + "Defaults to now.", + ) + parser.add_argument( + "--initial-amount", + type=float, + default=1000.0, + help="Initial portfolio balance (default: 1000).", + ) + parser.add_argument( + "--market", + default=None, + help="Market identifier (e.g. BITVAVO, COINBASE).", + ) + parser.add_argument( + "--trading-symbol", + default=None, + help="Trading / quote currency (e.g. EUR, USD).", + ) + parser.add_argument( + "-o", + "--output", + default=None, + help="Variable name to store the Backtest result in the " + "notebook namespace.", + ) + parser.add_argument( + "--vectorized", + action="store_true", + default=False, + help="Use vectorized backtesting (run_vector_backtest) instead " + "of event-driven.", + ) + parser.add_argument( + "--show-progress", + action="store_true", + default=False, + help="Show progress bars during the backtest.", + ) + parser.add_argument( + "--show-report", + action="store_true", + default=False, + help="Display an inline HTML report after the backtest completes.", + ) + parser.add_argument( + "--resource-dir", + default=None, + help="Resource directory path. Defaults to ./resources.", + ) + parser.add_argument( + "--risk-free-rate", + type=float, + default=None, + help="Risk-free rate for metrics calculation.", + ) + parser.add_argument( + "--snapshot-interval", + default="DAILY", + help="Snapshot interval (DAILY, TRADE_CLOSE). Default: DAILY.", + ) + parser.add_argument( + "--fill-missing-data", + action="store_true", + default=True, + help="Fill missing time-series entries (default: True).", + ) + parser.add_argument( + "--no-fill-missing-data", + action="store_true", + default=False, + help="Do not fill missing time-series entries.", + ) + return parser + + +def _parse_date(value): + """Parse a date string into a timezone-aware datetime (UTC).""" + for fmt in ("%Y-%m-%d-%H", "%Y-%m-%d"): + try: + dt = datetime.strptime(value, fmt) + # Round to the nearest hour as required by BacktestDateRange + return dt.replace( + minute=0, second=0, microsecond=0, tzinfo=timezone.utc + ) + except ValueError: + continue + raise ValueError( + f"Cannot parse date '{value}'. " + "Expected format: YYYY-MM-DD or YYYY-MM-DD-HH." + ) + + +def _find_strategy_classes(namespace): + """Find all TradingStrategy subclasses in a namespace dict.""" + from investing_algorithm_framework.app.strategy import TradingStrategy + + strategies = [] + + for obj in namespace.values(): + try: + if ( + isinstance(obj, type) + and issubclass(obj, TradingStrategy) + and obj is not TradingStrategy + ): + strategies.append(obj) + except TypeError: + continue + + return strategies + + +def _find_strategy_classes_from_file(file_path): + """Import a .py file and return all TradingStrategy subclasses in it.""" + import importlib.util + + path = Path(file_path).resolve() + + if not path.exists(): + raise FileNotFoundError(f"Strategy file not found: {path}") + + spec = importlib.util.spec_from_file_location( + f"_backtest_magic_{path.stem}", str(path) + ) + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + + return _find_strategy_classes(vars(module)) + + +def _run_backtest(args, strategies): + """Execute the backtest using the framework's App API.""" + from investing_algorithm_framework.create_app import create_app + from investing_algorithm_framework.domain import ( + BacktestDateRange, + PortfolioConfiguration, + SnapshotInterval, + RESOURCE_DIRECTORY, + ) + + start_date = _parse_date(args.start) + end_date = ( + _parse_date(args.end) if args.end else + datetime.now(timezone.utc).replace( + minute=0, second=0, microsecond=0 + ) + ) + + backtest_date_range = BacktestDateRange( + start_date=start_date, end_date=end_date + ) + + snapshot_interval = SnapshotInterval.from_value(args.snapshot_interval) + fill_missing = not args.no_fill_missing_data + + resource_dir = args.resource_dir or os.path.join( + os.getcwd(), "resources" + ) + + config = {RESOURCE_DIRECTORY: resource_dir} + app = create_app(config=config) + + # Configure portfolio if market info is provided + if args.market and args.trading_symbol: + app.add_portfolio_configuration( + PortfolioConfiguration( + initial_balance=args.initial_amount, + market=args.market, + trading_symbol=args.trading_symbol, + ) + ) + + if len(strategies) == 0: + raise RuntimeError( + "No TradingStrategy subclass found. " + "Define a class that inherits from TradingStrategy " + "in your cell or strategy file." + ) + + # Use the first strategy found + strategy_cls = strategies[0] + strategy = strategy_cls() + + # Determine market/trading_symbol from strategy if not from CLI + market = args.market + trading_symbol = args.trading_symbol + + if args.vectorized: + backtest = app.run_vector_backtest( + strategy=strategy, + backtest_date_range=backtest_date_range, + initial_amount=args.initial_amount, + market=market, + trading_symbol=trading_symbol, + snapshot_interval=snapshot_interval, + risk_free_rate=args.risk_free_rate, + show_progress=args.show_progress, + fill_missing_data=fill_missing, + ) + else: + backtest = app.run_backtest( + strategy=strategy, + backtest_date_range=backtest_date_range, + initial_amount=args.initial_amount, + market=market, + trading_symbol=trading_symbol, + snapshot_interval=snapshot_interval, + risk_free_rate=args.risk_free_rate, + show_progress=args.show_progress, + fill_missing_data=fill_missing, + ) + + return backtest + + +@magics_class +class BacktestMagics(Magics): + """IPython magic commands for the investing-algorithm-framework.""" + + @cell_magic + def backtest(self, line, cell): + """ + %%backtest — define a strategy inline and run a backtest. + + Example:: + + %%backtest --start 2023-01-01 --end 2023-12-31 \\ + --initial-amount 10000 --market BITVAVO \\ + --trading-symbol EUR -o results + + from investing_algorithm_framework import ( + TradingStrategy, DataSource, TimeUnit + ) + + class MyStrategy(TradingStrategy): + time_unit = TimeUnit.DAY + interval = 1 + data_sources = [ + DataSource( + identifier="btc", + symbol="BTC/EUR", + data_type="OHLCV", + time_frame="1d", + ) + ] + + def run_strategy(self, context, data): + ... + """ + parser = _build_parser() + + try: + args = parser.parse_args(shlex.split(line)) + except SystemExit: + return + + # Execute the cell code in a fresh namespace that inherits + # from the user's notebook namespace so imports are available + cell_ns = dict(self.shell.user_ns) + exec(compile(cell, "<%%backtest>", "exec"), cell_ns) # noqa: S102 + + strategies = _find_strategy_classes(cell_ns) + backtest = _run_backtest(args, strategies) + + if args.show_report: + self._show_report(backtest) + + if args.output: + self.shell.user_ns[args.output] = backtest + print(f"Backtest result stored in '{args.output}'") + else: + return backtest + + @line_magic + def backtest(self, line): # noqa: F811 + """ + %backtest — run a backtest from an existing strategy file. + + Example:: + + %backtest strategies/my_strategy.py \\ + --start 2023-01-01 --end 2023-12-31 -o results + """ + parser = _build_parser() + + try: + args = parser.parse_args(shlex.split(line)) + except SystemExit: + return + + if args.strategy_path is None: + print( + "Usage: %backtest " + "--start YYYY-MM-DD [options]" + ) + return + + strategies = _find_strategy_classes_from_file(args.strategy_path) + backtest = _run_backtest(args, strategies) + + if args.show_report: + self._show_report(backtest) + + if args.output: + self.shell.user_ns[args.output] = backtest + print(f"Backtest result stored in '{args.output}'") + else: + return backtest + + @staticmethod + def _show_report(backtest): + """Display an inline HTML report in the notebook.""" + try: + from investing_algorithm_framework.app.reporting import ( + BacktestReport, + ) + + report = BacktestReport(backtests=[backtest]) + report.show() + except Exception as e: + print(f"Could not display report: {e}") + + +def load_ipython_extension(ipython): + """ + Entry point called by ``%load_ext investing_algorithm_framework``. + + Registers the ``%backtest`` / ``%%backtest`` magic commands. + """ + ipython.register_magics(BacktestMagics) diff --git a/tests/notebook/__init__.py b/tests/notebook/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/notebook/test_magic.py b/tests/notebook/test_magic.py new file mode 100644 index 00000000..ddf92591 --- /dev/null +++ b/tests/notebook/test_magic.py @@ -0,0 +1,151 @@ +from unittest import TestCase +from unittest.mock import MagicMock + +from investing_algorithm_framework.notebook.magic import ( + _build_parser, + _parse_date, + _find_strategy_classes, + BacktestMagics, + load_ipython_extension, +) + + +class TestParseDate(TestCase): + + def test_parse_date_ymd(self): + dt = _parse_date("2023-06-15") + self.assertEqual(dt.year, 2023) + self.assertEqual(dt.month, 6) + self.assertEqual(dt.day, 15) + self.assertEqual(dt.hour, 0) + self.assertIsNotNone(dt.tzinfo) + + def test_parse_date_ymdh(self): + dt = _parse_date("2023-06-15-14") + self.assertEqual(dt.hour, 14) + self.assertEqual(dt.minute, 0) + self.assertIsNotNone(dt.tzinfo) + + def test_parse_date_invalid(self): + with self.assertRaises(ValueError): + _parse_date("not-a-date") + + def test_parse_date_rounds_to_hour(self): + dt = _parse_date("2023-01-01") + self.assertEqual(dt.minute, 0) + self.assertEqual(dt.second, 0) + self.assertEqual(dt.microsecond, 0) + + +class TestBuildParser(TestCase): + + def test_minimal_args(self): + parser = _build_parser() + args = parser.parse_args(["--start", "2023-01-01"]) + self.assertEqual(args.start, "2023-01-01") + self.assertIsNone(args.end) + self.assertEqual(args.initial_amount, 1000.0) + self.assertIsNone(args.output) + self.assertFalse(args.vectorized) + self.assertFalse(args.show_report) + + def test_full_args(self): + parser = _build_parser() + args = parser.parse_args([ + "--start", "2023-01-01", + "--end", "2023-12-31", + "--initial-amount", "5000", + "--market", "BITVAVO", + "--trading-symbol", "EUR", + "-o", "results", + "--vectorized", + "--show-report", + "--show-progress", + ]) + self.assertEqual(args.initial_amount, 5000.0) + self.assertEqual(args.market, "BITVAVO") + self.assertEqual(args.trading_symbol, "EUR") + self.assertEqual(args.output, "results") + self.assertTrue(args.vectorized) + self.assertTrue(args.show_report) + self.assertTrue(args.show_progress) + + def test_strategy_path_line_magic(self): + parser = _build_parser() + args = parser.parse_args([ + "my_strategy.py", "--start", "2023-01-01" + ]) + self.assertEqual(args.strategy_path, "my_strategy.py") + + def test_risk_free_rate(self): + parser = _build_parser() + args = parser.parse_args([ + "--start", "2023-01-01", + "--risk-free-rate", "0.04", + ]) + self.assertEqual(args.risk_free_rate, 0.04) + + def test_no_fill_missing_data(self): + parser = _build_parser() + args = parser.parse_args([ + "--start", "2023-01-01", + "--no-fill-missing-data", + ]) + self.assertTrue(args.no_fill_missing_data) + + +class TestFindStrategyClasses(TestCase): + + def test_finds_strategy_subclass(self): + from investing_algorithm_framework.app.strategy import TradingStrategy + + class DummyStrategy(TradingStrategy): + time_unit = "DAY" + interval = 1 + + ns = {"DummyStrategy": DummyStrategy, "x": 42, "s": "hello"} + found = _find_strategy_classes(ns) + self.assertEqual(len(found), 1) + self.assertIs(found[0], DummyStrategy) + + def test_finds_multiple_strategies(self): + from investing_algorithm_framework.app.strategy import TradingStrategy + + class StratA(TradingStrategy): + time_unit = "DAY" + interval = 1 + + class StratB(TradingStrategy): + time_unit = "HOUR" + interval = 4 + + ns = {"StratA": StratA, "StratB": StratB} + found = _find_strategy_classes(ns) + self.assertEqual(len(found), 2) + + def test_ignores_base_class(self): + from investing_algorithm_framework.app.strategy import TradingStrategy + + ns = {"TradingStrategy": TradingStrategy} + found = _find_strategy_classes(ns) + self.assertEqual(len(found), 0) + + def test_ignores_non_classes(self): + ns = {"a": 1, "b": "text", "c": [1, 2]} + found = _find_strategy_classes(ns) + self.assertEqual(len(found), 0) + + +class TestLoadExtension(TestCase): + + def test_registers_magics(self): + mock_ipython = MagicMock() + load_ipython_extension(mock_ipython) + mock_ipython.register_magics.assert_called_once_with(BacktestMagics) + + def test_top_level_load_ext(self): + from investing_algorithm_framework import load_ipython_extension \ + as top_load + mock_ipython = MagicMock() + top_load(mock_ipython) + mock_ipython.register_magics.assert_called_once()