Commit 042c8ab

feat(pipeline): risk-neutrality primitives (#504)
Adds first-class support for sector neutrality, beta neutralisation and
multi-factor risk models via four new pipeline primitives plus a groups=
parameter on the existing cross-sectional transforms.

New factors (exported at the top level):

* StaticPerSymbol(mapping, default=None)
  Broadcast a per-symbol dict (e.g. sector / market-cap) into a
  cross-sectional column. window=1, no inputs.
* CrossSectionalMean(base, mask=None)
  Equal-weight cross-sectional mean per bar (the canonical 'market'
  series for beta neutralisation).
* RollingBeta(target, market, window>=2)
  Time-series OLS beta via cov/var with a rolling window; returns null
  when var(market) == 0.
* Neutralize(target, exposures=[...], mask=None, add_intercept=True)
  Per-bar OLS residualisation. Bars that are rank-deficient
  (n_surviving <= n_params) yield null residuals. Residuals are
  orthogonal to each exposure within each bar by construction.

Existing transforms:

* Factor.zscore(mask=None, groups=None)
* Factor.demean(mask=None, groups=None)

groups= accepts a dict[symbol, key] (broadcast via StaticPerSymbol) or
any Factor that emits a per-symbol category; stats are computed within
each (datetime, group) instead of across the whole bar, i.e.
sector-relative z-scores.

Tests: 14 new tests in tests/domain/pipeline/test_risk_neutrality.py
covering broadcast, sector demean, market mean, beta of self == 1,
zero-variance null, OLS orthogonality (residuals.sum() == 0 and
residuals dot each exposure == 0), market-factor beta-strip, mask
filtering and rank-deficient null behaviour. Full pipeline suite remains
green (61 tests); full regression: 1346 passed.
1 parent 6a46b3a commit 042c8ab
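
To make the `RollingBeta` description in the commit message concrete, here is a minimal pure-Python sketch of a rolling cov/var beta. It is not the framework's implementation: the function name `rolling_beta`, the list-based inputs and the choice to emit `None` during warm-up are assumptions made for illustration. Only the cov/var formula, the `window >= 2` constraint and the null on zero market variance come from the commit message.

```python
from typing import List, Optional


def rolling_beta(
    target: List[float], market: List[float], window: int
) -> List[Optional[float]]:
    """Rolling OLS beta of `target` on `market`, computed as cov / var.

    Hypothetical stand-in for illustration, not the framework's RollingBeta.
    """
    if window < 2:
        raise ValueError("window must be >= 2")
    if len(target) != len(market):
        raise ValueError("series must have equal length")

    out: List[Optional[float]] = []
    for end in range(1, len(target) + 1):
        if end < window:
            out.append(None)  # assumption: no value until the window fills
            continue
        t = target[end - window:end]
        m = market[end - window:end]
        mean_t = sum(t) / window
        mean_m = sum(m) / window
        # The 1/(n-1) factors cancel in cov/var, so raw sums are enough.
        cov = sum((a - mean_t) * (b - mean_m) for a, b in zip(t, m))
        var = sum((b - mean_m) * (b - mean_m) for b in m)
        out.append(cov / var if var > 0 else None)  # flat market -> null
    return out


market = [0.01, -0.02, 0.015, 0.00, 0.03]
print(rolling_beta(market, market, window=3))                    # beta of self
print(rolling_beta([2 * x for x in market], market, window=3))   # 2x leverage
print(rolling_beta(market, [0.0] * 5, window=3))                 # flat market
```

The first call mirrors the "beta of self == 1" test named in the commit message, and the last one mirrors the zero-variance null case.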

8 files changed

Lines changed: 861 additions & 20 deletions

docusaurus/docs/Advanced Concepts/pipelines.md

Lines changed: 53 additions & 0 deletions
@@ -108,6 +108,59 @@ symbol survives the mask), `zscore` returns `null` rather than
 `inf`/`NaN`. Masked-out symbols are excluded from the bar's
 statistic *and* from the bar's output.

+### Risk neutrality
+
+When you want a factor's signal to be independent of structural
+exposures (sector, beta to the market, multi-factor risk model),
+use the built-in risk-neutrality primitives. They cover three
+common cases:
+
+**Sector neutrality** — z-score or demean *within* each sector
+instead of across the whole universe by passing `groups=`. The
+mapping can be a `dict[symbol, sector]` or any `Factor` that
+emits a per-symbol category:
+
+```python
+SECTORS = {"AAPL": "Tech", "MSFT": "Tech", "JPM": "Fin", ...}
+
+class SectorNeutralMomentum(Pipeline):
+    momentum = Returns(window=60)
+    signal = momentum.zscore(groups=SECTORS)  # z-score within sector
+```
+
+**Beta neutralisation** — strip a factor's exposure to the market
+(or any other reference series) using `RollingBeta` and
+`Neutralize`:
+
+```python
+from investing_algorithm_framework import (
+    Returns, RollingBeta, CrossSectionalMean, Neutralize,
+)
+
+class BetaNeutralAlpha(Pipeline):
+    r = Returns(window=1)
+    market = CrossSectionalMean(r)  # equal-weight market
+    beta = RollingBeta(r, market, window=60)
+    alpha = Neutralize(r, exposures=[beta])  # market-neutral residual
+```
+
+**Multi-factor risk model** — pass several exposures to
+`Neutralize` and the residual is orthogonal to all of them at
+each bar (per-bar OLS):
+
+```python
+class FactorNeutralAlpha(Pipeline):
+    r = Returns(window=1)
+    size = StaticPerSymbol(MARKET_CAPS)  # cross-sectional size
+    val = BookToPrice()
+    mom = Returns(window=252)
+    residual = Neutralize(r, exposures=[size, val, mom])
+```
+
+Bars where the system is rank-deficient (more exposures than
+surviving symbols) yield `null` residuals so they're skipped
+downstream rather than producing `NaN`.
+
 ### Factor algebra

 Factors compose via the standard arithmetic operators. The framework
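
The per-bar OLS residualisation that the documentation change above attributes to `Neutralize` can be sketched for a single bar as follows. This is a rough illustration under stated assumptions, not the framework's code: `neutralize_bar` and `_solve` are hypothetical names, the inputs are assumed to be already restricted to the surviving (unmasked, non-null) symbols, and the normal equations are solved with a small Gaussian elimination so the sketch needs only the standard library. The null rule follows the commit message (`n_surviving <= n_params`); treating collinear exposures as null too is an extra assumption.

```python
from typing import List, Optional, Sequence


def _solve(a: List[List[float]], b: List[float]) -> Optional[List[float]]:
    """Solve a @ x = b by Gaussian elimination with partial pivoting.

    Returns None when a pivot is numerically zero (singular system).
    """
    n = len(b)
    m = [row[:] + [rhs] for row, rhs in zip(a, b)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        if abs(m[pivot][col]) < 1e-12:  # heuristic threshold (assumption)
            return None
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            f = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= f * m[col][c]
    x = [0.0] * n
    for i in range(n - 1, -1, -1):  # back substitution
        s = sum(m[i][j] * x[j] for j in range(i + 1, n))
        x[i] = (m[i][n] - s) / m[i][i]
    return x


def neutralize_bar(
    target: Sequence[float],
    exposures: Sequence[Sequence[float]],
    add_intercept: bool = True,
) -> List[Optional[float]]:
    """OLS residuals of `target` on `exposures` for one bar."""
    n = len(target)
    cols = [list(e) for e in exposures]
    if add_intercept:
        cols.insert(0, [1.0] * n)
    k = len(cols)
    if n <= k:  # rank-deficient bar: n_surviving <= n_params
        return [None] * n
    # Normal equations: (X'X) beta = X'y
    xtx = [[sum(p * q for p, q in zip(ci, cj)) for cj in cols] for ci in cols]
    xty = [sum(p * y for p, y in zip(ci, target)) for ci in cols]
    beta = _solve(xtx, xty)
    if beta is None:  # collinear exposures (assumption: also null)
        return [None] * n
    return [
        y - sum(b * c[i] for b, c in zip(beta, cols))
        for i, y in enumerate(target)
    ]


r = [0.02, -0.01, 0.03, 0.00, -0.02]
size = [1.0, 2.0, 3.0, 4.0, 5.0]
value = [0.5, -0.3, 0.1, 0.9, -0.4]
res = neutralize_bar(r, [size, value])
print(abs(sum(res)) < 1e-9)  # intercept makes residuals sum to ~0
print(all(abs(sum(e * x for e, x in zip(res, col))) < 1e-9
          for col in (size, value)))  # orthogonal to each exposure
print(neutralize_bar(r[:3], [size[:3], value[:3]]))  # 3 symbols, 3 params
```

With three symbols and three parameters (intercept plus two exposures) the fit would be exact, so the sketch returns nulls instead of meaningless zero residuals.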

investing_algorithm_framework/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -32,7 +32,9 @@
     FillModel, FullFill, VolumeBasedFill, \
     FXRateProvider, StaticFXRateProvider, \
     Pipeline, Factor, CustomFactor, Filter, \
-    Returns, AverageDollarVolume, AverageTradedValue, SMA, RSI, Volatility
+    Returns, AverageDollarVolume, AverageTradedValue, SMA, RSI, \
+    Volatility, StaticPerSymbol, CrossSectionalMean, RollingBeta, \
+    Neutralize
 from .infrastructure import AzureBlobStorageStateHandler, \
     CSVOHLCVDataProvider, CSVTickerDataProvider, CSVURLDataProvider, \
     JSONURLDataProvider, ParquetURLDataProvider, \
@@ -269,6 +271,10 @@
     "SMA",
     "RSI",
     "Volatility",
+    "StaticPerSymbol",
+    "CrossSectionalMean",
+    "RollingBeta",
+    "Neutralize",
     "load_ipython_extension",
     "get_cv_consistency",
     "get_normalized_stability",

investing_algorithm_framework/domain/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -57,6 +57,10 @@
     SMA,
     RSI,
     Volatility,
+    StaticPerSymbol,
+    CrossSectionalMean,
+    RollingBeta,
+    Neutralize,
 )

 __all__ = [
@@ -181,6 +185,10 @@
     "SMA",
     "RSI",
     "Volatility",
+    "StaticPerSymbol",
+    "CrossSectionalMean",
+    "RollingBeta",
+    "Neutralize",
     "Blotter",
     "DefaultBlotter",
     "SimulationBlotter",

investing_algorithm_framework/domain/pipeline/__init__.py

Lines changed: 12 additions & 4 deletions
@@ -10,11 +10,15 @@
 from .filter import Filter
 from .pipeline import Pipeline
 from .factors import (
-    Returns,
     AverageDollarVolume,
     AverageTradedValue,
-    SMA,
+    CrossSectionalMean,
+    Neutralize,
+    Returns,
+    RollingBeta,
     RSI,
+    SMA,
+    StaticPerSymbol,
     Volatility,
 )

@@ -23,10 +27,14 @@
     "Factor",
     "CustomFactor",
     "Filter",
-    "Returns",
     "AverageDollarVolume",
     "AverageTradedValue",
-    "SMA",
+    "CrossSectionalMean",
+    "Neutralize",
+    "Returns",
+    "RollingBeta",
     "RSI",
+    "SMA",
+    "StaticPerSymbol",
     "Volatility",
 ]

investing_algorithm_framework/domain/pipeline/factor.py

Lines changed: 80 additions & 9 deletions
@@ -127,23 +127,44 @@ def bottom(self, n: int) -> "Filter":
     # ------------------------------------------------------------------ #
     # Cross-sectional transforms (Phase 2 / #502)
     # ------------------------------------------------------------------ #
-    def zscore(self, mask: Optional["Filter"] = None) -> "Factor":
+    def zscore(
+        self,
+        mask: Optional["Filter"] = None,
+        groups=None,
+    ) -> "Factor":
         """Cross-sectional z-score within each timestamp.

         Returns ``(x - mean) / std`` computed over the symbols at each
         bar. With ``mask``, symbols outside the mask are excluded from
         the mean/std and receive ``null`` in the output.
+
+        ``groups`` enables **group-relative** (e.g. sector-neutral)
+        normalisation: the statistic is computed within each
+        ``(datetime, group)`` cell instead of across all symbols. It
+        accepts:
+
+        - a ``dict[str, Any]`` mapping ``symbol`` → group label —
+          internally wrapped in :class:`StaticPerSymbol`,
+        - a :class:`Factor` returning a categorical value per row
+          (e.g. a slow-moving fundamental bucket).
         """
-        return _Zscore(self, mask=mask)
+        return _Zscore(self, mask=mask, groups=groups)

-    def demean(self, mask: Optional["Filter"] = None) -> "Factor":
+    def demean(
+        self,
+        mask: Optional["Filter"] = None,
+        groups=None,
+    ) -> "Factor":
         """Cross-sectional mean removal within each timestamp.

         Returns ``x - mean(x)`` computed over the symbols at each bar.
         With ``mask``, symbols outside the mask are excluded from the
         mean and receive ``null`` in the output.
+
+        ``groups`` (same shape as in :meth:`zscore`) enables
+        group-relative demeaning — e.g. sector neutrality.
         """
-        return _Demean(self, mask=mask)
+        return _Demean(self, mask=mask, groups=groups)

     def winsorize(
         self,
@@ -272,6 +293,28 @@ def _coerce_operand(operand) -> "Factor":
     )


+def _coerce_groups(groups) -> Optional["Factor"]:
+    """Normalise the ``groups`` argument of cross-sectional transforms
+    into a :class:`Factor` (or ``None``).
+
+    Accepts ``None``, a ``dict[symbol, group]`` mapping (auto-wrapped
+    in :class:`StaticPerSymbol`), or any pre-existing :class:`Factor`.
+    """
+    if groups is None:
+        return None
+    if isinstance(groups, Factor):
+        return groups
+    if isinstance(groups, dict):
+        # Local import to avoid an import cycle at module load: the
+        # built-in factors module imports from this file.
+        from .factors.builtin import StaticPerSymbol
+        return StaticPerSymbol(groups)
+    raise TypeError(
+        f"Unsupported type for `groups`: {type(groups).__name__}. "
+        f"Expected None, dict[str, Any], or Factor."
+    )
+
+
 class _Constant(Factor):
     """A panel-aligned constant series. Window is 1 (no warmup needed)."""

@@ -371,22 +414,36 @@ class _CrossSectionalTransform(Factor):
     Polars expression for the (possibly mask-nulled) factor values and
     returns the transformed expression. The base class handles mask
     application and per-``datetime`` grouping.
+
+    When ``groups`` is provided, statistics are computed within each
+    ``(datetime, group)`` cell instead of across all symbols at a
+    bar — enabling sector-neutral or otherwise group-relative
+    transforms. ``groups`` may be a ``dict[symbol, group]`` (wrapped
+    in :class:`StaticPerSymbol` automatically) or any :class:`Factor`
+    returning a categorical value per row.
     """

     def __init__(
         self,
         base: Factor,
         mask: Optional["Filter"] = None,
+        groups=None,
     ) -> None:
         super().__init__(window=base.required_window())
         self._base = base
         self._mask = mask
+        self._groups = _coerce_groups(groups)
         cols = list(base.required_columns())
         if mask is not None:
             for c in mask.required_columns():
                 if c not in cols:
                     cols.append(c)
             self.window = max(self.window, mask.required_window())
+        if self._groups is not None:
+            for c in self._groups.required_columns():
+                if c not in cols:
+                    cols.append(c)
+            self.window = max(self.window, self._groups.required_window())
         self.inputs = cols

     def required_columns(self) -> List[str]:
@@ -398,6 +455,15 @@ def required_window(self) -> int:
     def _transform_expr(self) -> pl.Expr:
         raise NotImplementedError  # pragma: no cover

+    def _group_keys(self) -> List[str]:
+        """Return the columns to group by for the cross-sectional
+        statistic. ``["datetime"]`` for the standard case,
+        ``["datetime", "__group__"]`` when ``groups`` is set.
+        """
+        if self._groups is None:
+            return ["datetime"]
+        return ["datetime", "__group__"]
+
     def compute_panel(self, panel: pl.DataFrame) -> pl.Series:
         values = self._base.evaluate(panel)
         df = panel.select(["datetime", "symbol"]).with_columns(
@@ -411,17 +477,21 @@ def compute_panel(self, panel: pl.DataFrame) -> pl.Series:
                 .otherwise(None)
                 .alias("__x__")
             )
+        if self._groups is not None:
+            group_values = self._groups.evaluate(panel)
+            df = df.with_columns(group_values.alias("__group__"))
         df = df.with_columns(self._transform_expr().alias("__out__"))
         return df["__out__"]


 class _Zscore(_CrossSectionalTransform):
-    """Cross-sectional z-score per bar."""
+    """Cross-sectional z-score per bar (optionally per group)."""

     def _transform_expr(self) -> pl.Expr:
         x = pl.col("__x__")
-        mean = x.mean().over("datetime")
-        std = x.std().over("datetime")
+        keys = self._group_keys()
+        mean = x.mean().over(keys)
+        std = x.std().over(keys)
         # If std is 0 or null, returning null is the safe choice (it
         # signals "no dispersion" rather than producing inf/NaN that
         # poisons downstream rolling stats).
@@ -433,11 +503,12 @@ def _transform_expr(self) -> pl.Expr:


 class _Demean(_CrossSectionalTransform):
-    """Cross-sectional mean removal per bar."""
+    """Cross-sectional mean removal per bar (optionally per group)."""

     def _transform_expr(self) -> pl.Expr:
         x = pl.col("__x__")
-        return x - x.mean().over("datetime")
+        keys = self._group_keys()
+        return x - x.mean().over(keys)


 class _Winsorize(_CrossSectionalTransform):
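
The effect of switching the grouping keys from `"datetime"` to `["datetime", "__group__"]` in the diff above can be illustrated without Polars. The sketch below is a plain-Python approximation, not the framework's code path: `grouped_zscore` and the `(datetime, symbol, value)` row layout are hypothetical, the sample standard deviation (ddof=1) is used, and symbols missing from the mapping are assumed to fall into a shared `None` group. The null result for zero or undefined dispersion follows the comment in `_Zscore`.

```python
from collections import defaultdict
from statistics import mean, stdev
from typing import Dict, Hashable, List, Optional, Tuple

Row = Tuple[str, str, Optional[float]]  # (datetime, symbol, value)


def grouped_zscore(
    rows: List[Row], groups: Optional[Dict[str, Hashable]] = None
) -> List[Optional[float]]:
    """Z-score each value within its (datetime, group) cell.

    With groups=None every symbol at a datetime shares one cell.
    """

    def key(dt: str, sym: str) -> tuple:
        return (dt,) if groups is None else (dt, groups.get(sym))

    # Collect the non-null values of each cell.
    cells: Dict[tuple, List[float]] = defaultdict(list)
    for dt, sym, x in rows:
        if x is not None:
            cells[key(dt, sym)].append(x)

    out: List[Optional[float]] = []
    for dt, sym, x in rows:
        members = cells[key(dt, sym)]
        if x is None or len(members) < 2:
            out.append(None)  # std undefined for fewer than two values
            continue
        sd = stdev(members)  # sample standard deviation (ddof=1)
        out.append((x - mean(members)) / sd if sd > 0 else None)
    return out


rows = [
    ("2024-01-02", "AAPL", 1.0),
    ("2024-01-02", "MSFT", 3.0),
    ("2024-01-02", "JPM", 10.0),
    ("2024-01-02", "GS", 14.0),
]
sectors = {"AAPL": "Tech", "MSFT": "Tech", "JPM": "Fin", "GS": "Fin"}
print(grouped_zscore(rows))           # scored against the whole bar
print(grouped_zscore(rows, sectors))  # scored within each sector
```

Across the whole bar the two financials dominate the ranking; within sectors each pair is scored only against its peer, which is the sector-relative behaviour the `groups=` parameter is meant to provide.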
Lines changed: 13 additions & 5 deletions
@@ -1,18 +1,26 @@
-"""Built-in factors shipped with the Pipeline API (Phase 1)."""
+"""Built-in factors shipped with the Pipeline API."""
 from .builtin import (
-    Returns,
     AverageDollarVolume,
     AverageTradedValue,
-    SMA,
+    CrossSectionalMean,
+    Neutralize,
+    Returns,
+    RollingBeta,
     RSI,
+    SMA,
+    StaticPerSymbol,
     Volatility,
 )

 __all__ = [
-    "Returns",
     "AverageDollarVolume",
     "AverageTradedValue",
-    "SMA",
+    "CrossSectionalMean",
+    "Neutralize",
+    "Returns",
+    "RollingBeta",
     "RSI",
+    "SMA",
+    "StaticPerSymbol",
     "Volatility",
 ]
