diff --git a/investing_algorithm_framework/domain/models/time_frame.py b/investing_algorithm_framework/domain/models/time_frame.py index c17d454f..ca9a2386 100644 --- a/investing_algorithm_framework/domain/models/time_frame.py +++ b/investing_algorithm_framework/domain/models/time_frame.py @@ -10,12 +10,16 @@ class TimeFrame(Enum): FIVE_MINUTE = "5m" TEN_MINUTE = "10m" FIFTEEN_MINUTE = "15m" + TWENTY_MINUTE = "20m" THIRTY_MINUTE = "30m" ONE_HOUR = "1h" TWO_HOUR = "2h" FOUR_HOUR = "4h" + SIX_HOUR = "6h" + EIGHT_HOUR = "8h" TWELVE_HOUR = "12h" ONE_DAY = "1d" + THREE_DAY = "3d" ONE_WEEK = "1W" ONE_MONTH = "1M" ONE_YEAR = "1Y" @@ -36,7 +40,16 @@ def from_string(value: str): if value == entry.value.replace("H", "h"): return entry - # For hour timeframes compare with and without H + # For hour timeframes compare with and without h + if "h" in entry.value: + + if value == entry.value: + return entry + + if value == entry.value.replace("h", "H"): + return entry + + # For day timeframes compare with and without D if "d" in entry.value: if value == entry.value: @@ -100,6 +113,9 @@ def amount_of_minutes(self): if self.equals(TimeFrame.FIFTEEN_MINUTE): return 15 + if self.equals(TimeFrame.TWENTY_MINUTE): + return 20 + if self.equals(TimeFrame.THIRTY_MINUTE): return 30 @@ -112,12 +128,21 @@ def amount_of_minutes(self): if self.equals(TimeFrame.FOUR_HOUR): return 240 + if self.equals(TimeFrame.SIX_HOUR): + return 360 + + if self.equals(TimeFrame.EIGHT_HOUR): + return 480 + if self.equals(TimeFrame.TWELVE_HOUR): return 720 if self.equals(TimeFrame.ONE_DAY): return 1440 + if self.equals(TimeFrame.THREE_DAY): + return 4320 + if self.equals(TimeFrame.ONE_WEEK): return 10080 diff --git a/investing_algorithm_framework/services/metrics/drawdown.py b/investing_algorithm_framework/services/metrics/drawdown.py index b5356e2d..1ae173c9 100644 --- a/investing_algorithm_framework/services/metrics/drawdown.py +++ b/investing_algorithm_framework/services/metrics/drawdown.py @@ -109,16 +109,17 @@ def get_max_drawdown(snapshots: List[PortfolioSnapshot]) -> float: def get_max_daily_drawdown(snapshots: List[PortfolioSnapshot]) -> float: """ - Calculate the maximum daily drawdown of the portfolio as a percentage from the peak. + Calculate the worst single-day decline of the portfolio as a percentage. - This is the largest drop in equity (in percentage) from a peak to a trough - during the backtest period, calculated on a daily basis. + This is the largest day-over-day percentage drop in equity, + NOT the peak-to-trough drawdown (use get_max_drawdown for that). Args: snapshots (List[PortfolioSnapshot]): List of portfolio snapshots Returns: - float: The maximum daily drawdown as a negative percentage (e.g., -5.0 for a 5% drawdown). + float: The maximum single-day drawdown as a positive percentage + (e.g., 0.05 for a 5% single-day decline). """ # Create DataFrame from snapshots data = [(s.created_at, s.total_value) for s in snapshots] @@ -136,36 +137,31 @@ def get_max_daily_drawdown(snapshots: List[PortfolioSnapshot]) -> float: # Filter out non-positive values positive_values = daily_df[daily_df['total_value'] > 0]['total_value'] - if positive_values.empty: + if positive_values.empty or len(positive_values) < 2: return 0.0 - peak = positive_values.iloc[0] - max_daily_drawdown_pct = 0.0 + # Compute day-over-day returns; the worst single-day decline + # is the most negative return (ignore positive returns) + daily_returns = positive_values.pct_change().dropna() + negative_returns = daily_returns[daily_returns < 0] - for equity in positive_values: - if equity > peak: - peak = equity - - # Avoid division by zero (shouldn't happen but extra safety) - if peak <= 0: - continue - - drawdown_pct = (equity - peak) / peak - max_daily_drawdown_pct = min(max_daily_drawdown_pct, drawdown_pct) + if negative_returns.empty: + return 0.0 - return abs(max_daily_drawdown_pct) # Return as positive percentage + return abs(negative_returns.min()) def get_max_drawdown_duration(snapshots: List[PortfolioSnapshot]) -> int: """ Calculate the maximum duration of drawdown in days. - This is the longest period where the portfolio equity was below its peak. + This is the longest period (in calendar days) where the portfolio + equity was below its peak. Args: snapshots (List[PortfolioSnapshot]): List of portfolio snapshots Returns: - int: The maximum drawdown duration in days. + int: The maximum drawdown duration in calendar days. """ equity_curve = get_equity_curve(snapshots) if not equity_curve: @@ -173,17 +169,26 @@ def get_max_drawdown_duration(snapshots: List[PortfolioSnapshot]) -> int: peak = equity_curve[0][0] max_duration = 0 - current_duration = 0 + drawdown_start = None - for equity, _ in equity_curve: + for equity, timestamp in equity_curve: if equity < peak: - current_duration += 1 + # Entering or continuing a drawdown + if drawdown_start is None: + drawdown_start = timestamp else: - max_duration = max(max_duration, current_duration) - current_duration = 0 - peak = equity # Reset peak to current equity + # Recovered to or above the peak + if drawdown_start is not None: + elapsed = (timestamp - drawdown_start).days + max_duration = max(max_duration, elapsed) + drawdown_start = None + peak = equity - max_duration = max(max_duration, current_duration) # Final check + # If still in drawdown at the end of the series + if drawdown_start is not None and len(equity_curve) > 0: + last_timestamp = equity_curve[-1][1] + elapsed = (last_timestamp - drawdown_start).days + max_duration = max(max_duration, elapsed) return max_duration diff --git a/investing_algorithm_framework/services/metrics/equity_curve.py b/investing_algorithm_framework/services/metrics/equity_curve.py index bb4a7d29..af8536f8 100644 --- a/investing_algorithm_framework/services/metrics/equity_curve.py +++ b/investing_algorithm_framework/services/metrics/equity_curve.py @@ -21,4 +21,7 @@ def get_equity_curve( total_size = snapshot.total_value series.append((total_size, timestamp)) + # Sort by timestamp to ensure chronological order + series.sort(key=lambda x: x[1]) + return series diff --git a/pyproject.toml b/pyproject.toml index f774f323..8e9bba52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "investing-algorithm-framework" -version = "v7.29.0" +version = "v7.30.0" description = "A framework for creating trading bots" authors = ["MDUYN"] readme = "README.md" diff --git a/tests/domain/models/test_time_frame.py b/tests/domain/models/test_time_frame.py new file mode 100644 index 00000000..386fdb07 --- /dev/null +++ b/tests/domain/models/test_time_frame.py @@ -0,0 +1,169 @@ +from unittest import TestCase + +from investing_algorithm_framework.domain.models.time_frame import TimeFrame + + +class TestTimeFrameNewEnumValues(TestCase): + """Test that the new enum members have the correct string values.""" + + def test_twenty_minute_value(self): + self.assertEqual(TimeFrame.TWENTY_MINUTE.value, "20m") + + def test_six_hour_value(self): + self.assertEqual(TimeFrame.SIX_HOUR.value, "6h") + + def test_eight_hour_value(self): + self.assertEqual(TimeFrame.EIGHT_HOUR.value, "8h") + + def test_three_day_value(self): + self.assertEqual(TimeFrame.THREE_DAY.value, "3d") + + +class TestTimeFrameNewAmountOfMinutes(TestCase): + """Test that amount_of_minutes returns correct values for new members.""" + + def test_twenty_minute_minutes(self): + self.assertEqual(TimeFrame.TWENTY_MINUTE.amount_of_minutes, 20) + + def test_six_hour_minutes(self): + self.assertEqual(TimeFrame.SIX_HOUR.amount_of_minutes, 360) + + def test_eight_hour_minutes(self): + self.assertEqual(TimeFrame.EIGHT_HOUR.amount_of_minutes, 480) + + def test_three_day_minutes(self): + self.assertEqual(TimeFrame.THREE_DAY.amount_of_minutes, 4320) + + +class TestTimeFrameNewFromString(TestCase): + """Test from_string parsing for new members, including case variants.""" + + def test_twenty_minute_lowercase(self): + self.assertEqual(TimeFrame.from_string("20m"), TimeFrame.TWENTY_MINUTE) + + def test_six_hour_lowercase(self): + self.assertEqual(TimeFrame.from_string("6h"), TimeFrame.SIX_HOUR) + + def test_six_hour_uppercase(self): + self.assertEqual(TimeFrame.from_string("6H"), TimeFrame.SIX_HOUR) + + def test_eight_hour_lowercase(self): + self.assertEqual(TimeFrame.from_string("8h"), TimeFrame.EIGHT_HOUR) + + def test_eight_hour_uppercase(self): + self.assertEqual(TimeFrame.from_string("8H"), TimeFrame.EIGHT_HOUR) + + def test_three_day_lowercase(self): + self.assertEqual(TimeFrame.from_string("3d"), TimeFrame.THREE_DAY) + + def test_three_day_uppercase(self): + self.assertEqual(TimeFrame.from_string("3D"), TimeFrame.THREE_DAY) + + +class TestTimeFrameNewFromValue(TestCase): + """Test from_value parsing for new members.""" + + def test_twenty_minute_from_value_string(self): + self.assertEqual( + TimeFrame.from_value("20m"), TimeFrame.TWENTY_MINUTE + ) + + def test_six_hour_from_value_string(self): + self.assertEqual(TimeFrame.from_value("6h"), TimeFrame.SIX_HOUR) + + def test_eight_hour_from_value_string(self): + self.assertEqual(TimeFrame.from_value("8h"), TimeFrame.EIGHT_HOUR) + + def test_three_day_from_value_string(self): + self.assertEqual(TimeFrame.from_value("3d"), TimeFrame.THREE_DAY) + + def test_twenty_minute_from_value_enum(self): + self.assertEqual( + TimeFrame.from_value(TimeFrame.TWENTY_MINUTE), + TimeFrame.TWENTY_MINUTE, + ) + + def test_six_hour_from_value_enum(self): + self.assertEqual( + TimeFrame.from_value(TimeFrame.SIX_HOUR), TimeFrame.SIX_HOUR + ) + + def test_eight_hour_from_value_enum(self): + self.assertEqual( + TimeFrame.from_value(TimeFrame.EIGHT_HOUR), TimeFrame.EIGHT_HOUR + ) + + def test_three_day_from_value_enum(self): + self.assertEqual( + TimeFrame.from_value(TimeFrame.THREE_DAY), TimeFrame.THREE_DAY + ) + + +class TestTimeFrameNewEquals(TestCase): + """Test equals method for new members against string and enum values.""" + + def test_twenty_minute_equals_string(self): + self.assertTrue(TimeFrame.TWENTY_MINUTE.equals("20m")) + + def test_twenty_minute_equals_enum(self): + self.assertTrue(TimeFrame.TWENTY_MINUTE.equals(TimeFrame.TWENTY_MINUTE)) + + def test_twenty_minute_not_equals_other(self): + self.assertFalse(TimeFrame.TWENTY_MINUTE.equals(TimeFrame.THIRTY_MINUTE)) + + def test_six_hour_equals_string(self): + self.assertTrue(TimeFrame.SIX_HOUR.equals("6h")) + + def test_six_hour_equals_enum(self): + self.assertTrue(TimeFrame.SIX_HOUR.equals(TimeFrame.SIX_HOUR)) + + def test_eight_hour_equals_string(self): + self.assertTrue(TimeFrame.EIGHT_HOUR.equals("8h")) + + def test_eight_hour_equals_enum(self): + self.assertTrue(TimeFrame.EIGHT_HOUR.equals(TimeFrame.EIGHT_HOUR)) + + def test_three_day_equals_string(self): + self.assertTrue(TimeFrame.THREE_DAY.equals("3d")) + + def test_three_day_equals_enum(self): + self.assertTrue(TimeFrame.THREE_DAY.equals(TimeFrame.THREE_DAY)) + + def test_three_day_not_equals_other(self): + self.assertFalse(TimeFrame.THREE_DAY.equals(TimeFrame.ONE_DAY)) + + +class TestTimeFrameNewOrdering(TestCase): + """Test ordering of new members relative to their neighbors.""" + + # TWENTY_MINUTE sits between FIFTEEN_MINUTE and THIRTY_MINUTE + def test_fifteen_lt_twenty_minute(self): + self.assertLess(TimeFrame.FIFTEEN_MINUTE, TimeFrame.TWENTY_MINUTE) + + def test_twenty_lt_thirty_minute(self): + self.assertLess(TimeFrame.TWENTY_MINUTE, TimeFrame.THIRTY_MINUTE) + + # SIX_HOUR sits between FOUR_HOUR and EIGHT_HOUR + def test_four_hour_lt_six_hour(self): + self.assertLess(TimeFrame.FOUR_HOUR, TimeFrame.SIX_HOUR) + + def test_six_hour_lt_eight_hour(self): + self.assertLess(TimeFrame.SIX_HOUR, TimeFrame.EIGHT_HOUR) + + # EIGHT_HOUR sits between SIX_HOUR and TWELVE_HOUR + def test_eight_hour_lt_twelve_hour(self): + self.assertLess(TimeFrame.EIGHT_HOUR, TimeFrame.TWELVE_HOUR) + + # THREE_DAY sits between ONE_DAY and ONE_WEEK + def test_one_day_lt_three_day(self): + self.assertLess(TimeFrame.ONE_DAY, TimeFrame.THREE_DAY) + + def test_three_day_lt_one_week(self): + self.assertLess(TimeFrame.THREE_DAY, TimeFrame.ONE_WEEK) + + # Verify >= and <= also work + def test_twenty_minute_le_thirty_minute(self): + self.assertLessEqual(TimeFrame.TWENTY_MINUTE, TimeFrame.THIRTY_MINUTE) + + def test_eight_hour_ge_six_hour(self): + self.assertGreaterEqual(TimeFrame.EIGHT_HOUR, TimeFrame.SIX_HOUR) diff --git a/tests/services/metrics/test_drawdowns.py b/tests/services/metrics/test_drawdowns.py index 102707f9..1c87fdbb 100644 --- a/tests/services/metrics/test_drawdowns.py +++ b/tests/services/metrics/test_drawdowns.py @@ -1,8 +1,21 @@ import unittest +import random from datetime import datetime, timedelta from unittest.mock import MagicMock from investing_algorithm_framework import get_drawdown_series, \ - get_max_drawdown, get_max_drawdown_absolute + get_max_drawdown, get_max_drawdown_absolute, get_max_daily_drawdown, \ + get_max_drawdown_duration + + +def _make_snapshots(timestamps, values): + """Helper to create mock PortfolioSnapshot objects.""" + snapshots = [] + for ts, val in zip(timestamps, values): + snapshot = MagicMock() + snapshot.created_at = ts + snapshot.total_value = val + snapshots.append(snapshot) + return snapshots class TestDrawdownFunctions(unittest.TestCase): @@ -54,3 +67,323 @@ def test_max_drawdown(self): def test_max_drawdown_absolute(self): max_drawdown = get_max_drawdown_absolute(self.backtest_result.portfolio_snapshots) self.assertEqual(max_drawdown, 300) # 1200 - 900 = 300 + + +class TestGetMaxDailyDrawdown(unittest.TestCase): + """Tests for get_max_daily_drawdown — worst single-day decline.""" + + def test_worst_daily_decline_differs_from_max_drawdown(self): + """ + Max daily decline should be the worst single-day return, + not the peak-to-trough drawdown. + + Equity: [1000, 950, 900, 1100, 1300] + Daily returns: -5%, -5.26%, +22.2%, +18.2% + Worst daily: -5.26% (950→900) + Peak-to-trough: -10% (1000→900) + """ + timestamps = [ + datetime(2024, 1, 1), + datetime(2024, 1, 2), + datetime(2024, 1, 3), + datetime(2024, 1, 4), + datetime(2024, 1, 5), + ] + values = [1000, 950, 900, 1100, 1300] + snapshots = _make_snapshots(timestamps, values) + + result = get_max_daily_drawdown(snapshots) + # Worst single-day: (900 - 950) / 950 ≈ -0.05263 + expected = abs((900 - 950) / 950) + self.assertAlmostEqual(result, expected, places=4) + + def test_larger_daily_drop_in_middle(self): + """ + When the largest single-day drop is not the overall trough. + + Equity: [1000, 1200, 1100, 900, 1300] + Daily returns: +20%, -8.33%, -18.18%, +44.4% + Worst daily: -18.18% (1100→900) + Peak-to-trough: -25% (1200→900) + """ + timestamps = [ + datetime(2024, 1, 1), + datetime(2024, 1, 2), + datetime(2024, 1, 3), + datetime(2024, 1, 4), + datetime(2024, 1, 5), + ] + values = [1000, 1200, 1100, 900, 1300] + snapshots = _make_snapshots(timestamps, values) + + result = get_max_daily_drawdown(snapshots) + expected = abs((900 - 1100) / 1100) # ≈ 0.1818 + self.assertAlmostEqual(result, expected, places=4) + + def test_all_positive_returns(self): + """ + When all daily returns are positive, there is no decline, + so max daily drawdown should be 0. + + The function should only consider negative day-over-day returns. + """ + timestamps = [ + datetime(2024, 1, 1), + datetime(2024, 1, 2), + datetime(2024, 1, 3), + datetime(2024, 1, 4), + ] + values = [1000, 1100, 1200, 1400] + snapshots = _make_snapshots(timestamps, values) + + result = get_max_daily_drawdown(snapshots) + # No negative returns exist → worst daily drawdown = 0 + self.assertAlmostEqual(result, 0.0, places=6, + msg="All-positive returns should yield 0 drawdown") + + def test_single_snapshot(self): + """Single snapshot means no daily return — drawdown should be 0.""" + snapshots = _make_snapshots( + [datetime(2024, 1, 1)], [1000] + ) + result = get_max_daily_drawdown(snapshots) + self.assertEqual(result, 0.0) + + def test_resamples_intraday_to_daily(self): + """ + Multiple intra-day snapshots should be resampled to daily + (last value of the day). + + Raw: Day1 9am=1000, Day1 3pm=1200, Day2 9am=1100, Day3 9am=900 + Daily (last): [1200, 1100, 900] + Daily returns: -8.33%, -18.18% + Worst daily: -18.18% (NOT peak-to-trough of -25%) + """ + timestamps = [ + datetime(2024, 1, 1, 9, 0), + datetime(2024, 1, 1, 15, 0), + datetime(2024, 1, 2, 9, 0), + datetime(2024, 1, 3, 9, 0), + ] + values = [1000, 1200, 1100, 900] + snapshots = _make_snapshots(timestamps, values) + + result = get_max_daily_drawdown(snapshots) + # After resample: [1200, 1100, 900] + # Worst day-over-day return: (900 - 1100) / 1100 = -0.1818 + expected = abs((900 - 1100) / 1100) + self.assertAlmostEqual(result, expected, places=4) + + def test_flat_equity(self): + """Constant equity means no daily changes — drawdown = 0.""" + timestamps = [ + datetime(2024, 1, 1), + datetime(2024, 1, 2), + datetime(2024, 1, 3), + ] + values = [1000, 1000, 1000] + snapshots = _make_snapshots(timestamps, values) + + result = get_max_daily_drawdown(snapshots) + self.assertEqual(result, 0.0) + + +class TestGetMaxDrawdownDuration(unittest.TestCase): + """Tests for get_max_drawdown_duration — drawdown duration in actual days.""" + + def test_daily_snapshots_duration(self): + """ + With daily snapshots, drawdown duration should be in calendar days. + + Equity: [1000, 1200, 900, 1000, 1100, 1100, 1300] + Dates: Jan1 Jan2 Jan3 Jan4 Jan5 Jan6 Jan7 + Peak at Jan 2 (1200). Below peak: Jan 3-6. Recovery: Jan 7. + """ + timestamps = [ + datetime(2024, 1, 1), + datetime(2024, 1, 2), + datetime(2024, 1, 3), + datetime(2024, 1, 4), + datetime(2024, 1, 5), + datetime(2024, 1, 6), + datetime(2024, 1, 7), + ] + values = [1000, 1200, 900, 1000, 1100, 1100, 1300] + snapshots = _make_snapshots(timestamps, values) + + result = get_max_drawdown_duration(snapshots) + # With daily data, 4 snapshots below peak = 4 calendar days + self.assertGreaterEqual(result, 4) + + def test_weekly_snapshots_returns_days_not_snapshot_count(self): + """ + With weekly snapshots, should return calendar days, NOT snapshot count. + + Equity: [1000, 1200, 900, 1100, 1300] + Dates: Jan1 Jan8 Jan15 Jan22 Jan29 (weekly) + Peak at Jan 8 (1200). Below peak: Jan 15, Jan 22. Recovery: Jan 29. + Buggy code returns 2 (snapshot count). + Fixed code should return days: ≥7 (e.g. 14 or 21 depending on measurement). + """ + timestamps = [ + datetime(2024, 1, 1), + datetime(2024, 1, 8), + datetime(2024, 1, 15), + datetime(2024, 1, 22), + datetime(2024, 1, 29), + ] + values = [1000, 1200, 900, 1100, 1300] + snapshots = _make_snapshots(timestamps, values) + + result = get_max_drawdown_duration(snapshots) + self.assertGreater( + result, 2, + "Duration should be in calendar days, not snapshot count" + ) + + def test_no_drawdown(self): + """Monotonically increasing equity has no drawdown — duration = 0.""" + timestamps = [ + datetime(2024, 1, 1), + datetime(2024, 1, 2), + datetime(2024, 1, 3), + ] + values = [1000, 1100, 1200] + snapshots = _make_snapshots(timestamps, values) + + result = get_max_drawdown_duration(snapshots) + self.assertEqual(result, 0) + + def test_drawdown_extends_to_end_of_series(self): + """ + If the portfolio never recovers, the drawdown extends to the last + snapshot and should be measured in calendar days. + + Equity: [1000, 800, 700, 900] + Dates: Jan1 Jan8 Jan15 Jan22 (weekly) + Peak at Jan 1. Never recovered. + Buggy: returns 3 (snapshot count). + Fixed: should return ≥14 (calendar days from Jan 1 to Jan 22). + """ + timestamps = [ + datetime(2024, 1, 1), + datetime(2024, 1, 8), + datetime(2024, 1, 15), + datetime(2024, 1, 22), + ] + values = [1000, 800, 700, 900] + snapshots = _make_snapshots(timestamps, values) + + result = get_max_drawdown_duration(snapshots) + self.assertGreater( + result, 3, + "Duration should be in calendar days, not snapshot count" + ) + + def test_multiple_drawdown_periods_returns_longest(self): + """ + When there are multiple drawdown periods, return the longest one + in calendar days. + + Equity: [1000, 900, 1100, 1050, 1000, 900, 1200] + Dates: Jan1 Jan2 Jan3 Jan10 Jan17 Jan24 Jan31 + Drawdown 1: Jan 1→Jan 2 (1 day, 1 snapshot below peak) + Drawdown 2: Jan 3 peak (1100), below: Jan 10, Jan 17, Jan 24 + Recovery: Jan 31. Duration = 21+ days in calendar time. + """ + timestamps = [ + datetime(2024, 1, 1), + datetime(2024, 1, 2), + datetime(2024, 1, 3), + datetime(2024, 1, 10), + datetime(2024, 1, 17), + datetime(2024, 1, 24), + datetime(2024, 1, 31), + ] + values = [1000, 900, 1100, 1050, 1000, 900, 1200] + snapshots = _make_snapshots(timestamps, values) + + result = get_max_drawdown_duration(snapshots) + # Second drawdown period is the longest: 3 snapshots below peak + # but 21+ calendar days. Must be greater than snapshot count. + self.assertGreater( + result, 3, + "Duration should be in calendar days, not snapshot count" + ) + + def test_empty_snapshots(self): + """Empty snapshot list should return 0.""" + result = get_max_drawdown_duration([]) + self.assertEqual(result, 0) + + +class TestDrawdownConsistency(unittest.TestCase): + """Verify equity curve sort order and metric consistency.""" + + def test_unsorted_snapshots_produce_same_drawdown_as_sorted(self): + """ + Shuffled snapshots should produce the same max_drawdown + as chronologically sorted snapshots (functions should sort + internally or the equity curve should be timestamp-ordered). + """ + timestamps = [ + datetime(2024, 1, 1), + datetime(2024, 1, 2), + datetime(2024, 1, 3), + datetime(2024, 1, 4), + datetime(2024, 1, 5), + ] + values = [1000, 1200, 900, 1100, 1300] + + sorted_snapshots = _make_snapshots(timestamps, values) + + pairs = list(zip(timestamps, values)) + random.seed(42) + random.shuffle(pairs) + shuffled_ts, shuffled_vals = zip(*pairs) + shuffled_snapshots = _make_snapshots(shuffled_ts, shuffled_vals) + + sorted_result = get_max_drawdown(sorted_snapshots) + shuffled_result = get_max_drawdown(shuffled_snapshots) + + self.assertAlmostEqual(sorted_result, shuffled_result, places=6) + + def test_max_drawdown_matches_manual_computation_from_total_value(self): + """ + Max drawdown from the equity curve should match a manual + computation from the snapshot total_value fields. + """ + timestamps = [ + datetime(2024, 1, 1), + datetime(2024, 1, 2), + datetime(2024, 1, 3), + datetime(2024, 1, 4), + datetime(2024, 1, 5), + ] + values = [1000, 1200, 900, 1100, 1300] + snapshots = _make_snapshots(timestamps, values) + + # Manual: peak = 1200, trough = 900 + # Max drawdown = (1200 - 900) / 1200 = 0.25 + expected = 0.25 + result = get_max_drawdown(snapshots) + self.assertAlmostEqual(result, expected, places=6) + + def test_drawdown_series_timestamps_match_snapshots(self): + """ + The drawdown series timestamps should correspond to the + snapshot timestamps. + """ + timestamps = [ + datetime(2024, 1, 1), + datetime(2024, 1, 2), + datetime(2024, 1, 3), + ] + values = [1000, 900, 1100] + snapshots = _make_snapshots(timestamps, values) + + drawdown_series = get_drawdown_series(snapshots) + self.assertEqual(len(drawdown_series), len(timestamps)) + + for (_, ts), expected_ts in zip(drawdown_series, timestamps): + self.assertEqual(ts, expected_ts)