Skip to content

Commit 8ad46ba

Browse files
committed
feat: add context.record() for tracking custom variables during backtests
- Add record(**kwargs) method to Context for event-driven backtests - Add generate_recorded_values() method to TradingStrategy for vectorized backtests - Add recorded_values field to BacktestRun with serialization support - Wire recorded values into EventBacktestService and VectorBacktestService - Add 8 unit tests for record functionality Closes #455
1 parent 127c3b7 commit 8ad46ba

6 files changed

Lines changed: 312 additions & 0 deletions

File tree

investing_algorithm_framework/app/context.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(
5353
self._blotter = None
5454
self._fx_rate_provider = None
5555
self._base_currency = None
56+
self._recorded_values = {} # key -> list of (datetime, value)
5657

5758
def _validate_target_symbol(self, target_symbol, market=None):
5859
"""
@@ -2334,3 +2335,64 @@ def get_transactions(self):
23342335
list[Transaction]: Recorded transactions.
23352336
"""
23362337
return self._blotter.get_transactions()
2338+
2339+
def record(self, **kwargs):
2340+
"""
2341+
Record arbitrary key-value pairs at the current backtest timestamp.
2342+
2343+
This method allows you to store any custom indicator, metric, or
2344+
variable during a backtest. Each key creates a time series of
2345+
values that can be retrieved after the backtest completes via
2346+
``BacktestRun.recorded_values``.
2347+
2348+
The values are stored as a list of ``(datetime, value)`` tuples
2349+
per key, allowing you to track any indicator over time.
2350+
2351+
This method only records during backtesting. In live mode it is
2352+
a no-op.
2353+
2354+
Args:
2355+
**kwargs: Arbitrary key-value pairs to record. Keys are
2356+
strings, values can be any type (float, int, str,
2357+
dict, list, etc.).
2358+
2359+
Example::
2360+
2361+
def on_run(self, context, data):
2362+
context.record(
2363+
rsi=compute_rsi(data),
2364+
sma_20=compute_sma(data, 20),
2365+
signal_strength=0.85,
2366+
)
2367+
"""
2368+
is_backtest = self.configuration_service.config.get(
2369+
BACKTESTING_FLAG, False
2370+
)
2371+
2372+
if not is_backtest:
2373+
return
2374+
2375+
current_datetime = self.configuration_service.config.get(
2376+
INDEX_DATETIME
2377+
)
2378+
2379+
for key, value in kwargs.items():
2380+
if key not in self._recorded_values:
2381+
self._recorded_values[key] = []
2382+
self._recorded_values[key].append((current_datetime, value))
2383+
2384+
def get_recorded_values(self):
2385+
"""
2386+
Get all recorded values from the context.
2387+
2388+
Returns:
2389+
dict: A dictionary mapping keys to lists of
2390+
``(datetime, value)`` tuples.
2391+
"""
2392+
return self._recorded_values
2393+
2394+
def clear_recorded_values(self):
2395+
"""
2396+
Clear all recorded values from the context.
2397+
"""
2398+
self._recorded_values = {}

investing_algorithm_framework/app/strategy.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,32 @@ def generate_scale_out_signals(
850850
"""
851851
return None
852852

853+
def generate_recorded_values(
854+
self, data: Dict[str, Any]
855+
) -> Union[Dict[str, pd.Series], None]:
856+
"""
857+
Optional method to generate recorded values for vectorized
858+
backtesting. Override this to record arbitrary indicators,
859+
metrics, or variables as time series during a vectorized
860+
backtest.
861+
862+
Each key in the returned dict becomes a recorded variable
863+
with the Series index as timestamps and the Series values
864+
as the recorded data.
865+
866+
This is the vectorized equivalent of calling
867+
``context.record()`` in event-driven backtests.
868+
869+
Args:
870+
data (Dict[str, Any]): The market data for the strategy.
871+
872+
Returns:
873+
Dict[str, Series] | None: A dictionary where keys are
874+
variable names and values are pandas Series with the
875+
recorded values. Return None to not record anything.
876+
"""
877+
return None
878+
853879
def on_trade_closed(self, context: Context, trade: Trade):
854880
pass
855881

investing_algorithm_framework/domain/backtesting/backtest_run.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ class BacktestRun:
112112
metadata: Dict[str, str] = field(default_factory=dict)
113113
signals: Dict[str, Dict[str, Any]] = field(default_factory=dict)
114114
signal_events: List[Dict[str, Any]] = field(default_factory=list)
115+
recorded_values: Dict[str, List] = field(default_factory=dict)
115116

116117
def to_dict(self) -> dict:
117118
"""
@@ -167,6 +168,16 @@ def ensure_iso(value):
167168
"date": ensure_iso(evt["date"])
168169
} for evt in self.signal_events
169170
],
171+
"recorded_values": {
172+
key: [
173+
{
174+
"datetime": ensure_iso(entry[0]),
175+
"value": entry[1]
176+
}
177+
for entry in entries
178+
]
179+
for key, entries in self.recorded_values.items()
180+
},
170181
}
171182

172183
@staticmethod
@@ -330,10 +341,26 @@ def open(directory_path: Union[str, Path]) -> 'BacktestRun':
330341
pass
331342
signal_events.append(parsed)
332343

344+
# Parse recorded_values
345+
raw_recorded = data.pop("recorded_values", {})
346+
recorded_values = {}
347+
for key, entries in raw_recorded.items():
348+
parsed_entries = []
349+
for entry in entries:
350+
dt = entry.get("datetime")
351+
if isinstance(dt, str):
352+
try:
353+
dt = datetime.fromisoformat(dt)
354+
except (ValueError, TypeError):
355+
pass
356+
parsed_entries.append((dt, entry.get("value")))
357+
recorded_values[key] = parsed_entries
358+
333359
return BacktestRun(
334360
backtest_metrics=backtest_metrics,
335361
signals=signals,
336362
signal_events=signal_events,
363+
recorded_values=recorded_values,
337364
**data
338365
)
339366

investing_algorithm_framework/infrastructure/services/backtesting/event_backtest_service.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ def run(
112112
backtest_date_range=backtest_date_range,
113113
number_of_runs=event_loop_service.total_number_of_runs,
114114
risk_free_rate=risk_free_rate,
115+
recorded_values=event_loop_service.context
116+
.get_recorded_values(),
115117
)
116118

117119
def generate_schedule(
@@ -175,6 +177,7 @@ def _create_backtest_run(
175177
backtest_date_range: BacktestDateRange,
176178
number_of_runs: int,
177179
risk_free_rate: float,
180+
recorded_values: dict = None,
178181
) -> BacktestRun:
179182
"""
180183
Create a BacktestRun from the current state after event loop execution.
@@ -184,6 +187,7 @@ def _create_backtest_run(
184187
backtest_date_range: The date range of the backtest.
185188
number_of_runs: Total number of strategy executions.
186189
risk_free_rate: Risk-free rate for metrics calculation.
190+
recorded_values: Optional dict of recorded values from context.
187191
188192
Returns:
189193
BacktestRun: The completed backtest run with metrics.
@@ -215,6 +219,7 @@ def _create_backtest_run(
215219
positions=self._position_repository.get_all(
216220
{"portfolio": portfolio.id}
217221
),
222+
recorded_values=recorded_values or {},
218223
)
219224

220225
# Calculate and add metrics

investing_algorithm_framework/infrastructure/services/backtesting/vector_backtest_service.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ def run(
7575
scale_in_signals = strategy.generate_scale_in_signals(data)
7676
scale_out_signals = strategy.generate_scale_out_signals(data)
7777

78+
# Generate optional recorded values
79+
raw_recorded = strategy.generate_recorded_values(data)
80+
7881
if scale_in_signals is None:
7982
scale_in_signals = buy_signals
8083

@@ -798,6 +801,7 @@ def _partial_close(sym, sym_data, price, date, sell_pct):
798801
symbols=list(buy_signals.keys()),
799802
signals=raw_signals,
800803
signal_events=signal_events,
804+
recorded_values=self._convert_recorded_values(raw_recorded),
801805
)
802806

803807
# Create backtest metrics
@@ -806,6 +810,34 @@ def _partial_close(sym, sym_data, price, date, sell_pct):
806810
)
807811
return run
808812

813+
@staticmethod
814+
def _convert_recorded_values(raw_recorded):
815+
"""
816+
Convert recorded values from pandas Series to list-of-tuples format.
817+
818+
Args:
819+
raw_recorded: Dict[str, pd.Series] or None from
820+
strategy.generate_recorded_values().
821+
822+
Returns:
823+
Dict[str, List[Tuple[datetime, Any]]]: Converted values.
824+
"""
825+
if raw_recorded is None:
826+
return {}
827+
828+
recorded_values = {}
829+
for key, series in raw_recorded.items():
830+
entries = []
831+
for ts, val in series.items():
832+
dt = ts
833+
if isinstance(dt, pd.Timestamp):
834+
dt = dt.to_pydatetime()
835+
if hasattr(dt, 'tzinfo') and dt.tzinfo is None:
836+
dt = dt.replace(tzinfo=timezone.utc)
837+
entries.append((dt, val))
838+
recorded_values[key] = entries
839+
return recorded_values
840+
809841
@staticmethod
810842
def get_most_granular_ohlcv_data_source(data_sources):
811843
"""

0 commit comments

Comments
 (0)