diff --git a/backtesting/lib.py b/backtesting/lib.py index fd05026a..1677616e 100644 --- a/backtesting/lib.py +++ b/backtesting/lib.py @@ -608,7 +608,7 @@ def _mp_task_run(args): data_shm, strategy, bt_kwargs, run_kwargs = args dfs, shms = zip(*(SharedMemoryManager.shm2df(i) for i in data_shm)) try: - return [stats.filter(regex='^[^_]') if stats['# Trades'] else None + return [stats.filter(regex='^[^_]') for stats in (Backtest(df, strategy, **bt_kwargs).run(**run_kwargs) for df in dfs)] finally: diff --git a/backtesting/test/_test.py b/backtesting/test/_test.py index 63045ce1..f1f12247 100644 --- a/backtesting/test/_test.py +++ b/backtesting/test/_test.py @@ -4,6 +4,7 @@ import sys import time import unittest +import warnings from concurrent.futures.process import ProcessPoolExecutor from contextlib import contextmanager from glob import glob @@ -982,6 +983,50 @@ def test_MultiBacktest(self): print(start_method, time.monotonic() - start_time) plot_heatmaps(heatmap.mean(axis=1), open_browser=False) + def test_MultiBacktest_keeps_zero_trade_runs(self): + datasets = [GOOG[:-4], GOOG[:-3], GOOG[:-2], GOOG[:-1], GOOG] + cases = { + 'all_false': ([False, False, False, False, False], [0, 0, 0, 0, 0]), + 'first_true_rest_false': ([True, False, False, False, False], [1, 0, 0, 0, 0]), + 'first_false_second_true': ([False, True, False, False, False], [0, 1, 0, 0, 0]), + } + + for name, (will_buys, expected_trades) in cases.items(): + class TestStrat(Strategy): + def init(self): + self.will_buy = will_buys[len(self.data.index) - 2144] + self.has_bought = False + + def next(self): + if not self.will_buy: + return + if self.position: + self.position.close() + if not self.has_bought: + self.buy() + self.has_bought = True + + with self.subTest(case=name), warnings.catch_warnings(): + warnings.filterwarnings( + 'ignore', + message='If you want to use multi-process optimization', + category=RuntimeWarning, + ) + result = MultiBacktest( + datasets, + TestStrat, + cash=10_000, + commission=.002, + exclusive_orders=True, + ).run() + + self.assertIsInstance(result, pd.DataFrame) + self.assertEqual(result.columns.tolist(), [0, 1, 2, 3, 4]) + self.assertIn('# Trades', result.index) + self.assertEqual(result.loc['# Trades'].astype(int).tolist(), expected_trades) + self.assertFalse(any(isinstance(value, pd.Series) for value in result.to_numpy().ravel())) + self.assertIn('Equity Final [$]', result.index) + class TestUtil(TestCase): def test_as_str(self):