@@ -127,23 +127,44 @@ def bottom(self, n: int) -> "Filter":
127127 # ------------------------------------------------------------------ #
128128 # Cross-sectional transforms (Phase 2 / #502)
129129 # ------------------------------------------------------------------ #
def zscore(
    self,
    mask: Optional["Filter"] = None,
    groups=None,
) -> "Factor":
    """Standardise this factor cross-sectionally at every timestamp.

    Each value becomes ``(x - mean) / std``, where the mean and the
    standard deviation are taken over the symbols present at the same
    bar.

    Args:
        mask: Optional filter. Symbols outside the mask do not
            contribute to the mean/std and receive ``null`` in the
            output.
        groups: Optional grouping for **group-relative** (e.g.
            sector-neutral) normalisation. When given, the statistic
            is computed within each ``(datetime, group)`` cell rather
            than across all symbols. Accepted shapes:

            - a ``dict[str, Any]`` mapping ``symbol`` → group label
              (wrapped internally in :class:`StaticPerSymbol`),
            - a :class:`Factor` returning a categorical value per row
              (e.g. a slow-moving fundamental bucket).

    Returns:
        A new :class:`Factor` holding the z-scored values.
    """
    return _Zscore(self, groups=groups, mask=mask)
138152
def demean(
    self,
    mask: Optional["Filter"] = None,
    groups=None,
) -> "Factor":
    """Subtract the cross-sectional mean from this factor at every timestamp.

    Each value becomes ``x - mean(x)``, where the mean is taken over
    the symbols present at the same bar.

    Args:
        mask: Optional filter. Symbols outside the mask do not
            contribute to the mean and receive ``null`` in the output.
        groups: Optional grouping, in the same shape accepted by
            :meth:`zscore` (a ``symbol`` → label ``dict`` or a
            :class:`Factor`). When given, the mean is removed per
            group — e.g. for sector neutrality.

    Returns:
        A new :class:`Factor` holding the demeaned values.
    """
    return _Demean(self, groups=groups, mask=mask)
147168
148169 def winsorize (
149170 self ,
@@ -272,6 +293,28 @@ def _coerce_operand(operand) -> "Factor":
272293 )
273294
274295
296+ def _coerce_groups (groups ) -> Optional ["Factor" ]:
297+ """Normalise the ``groups`` argument of cross-sectional transforms
298+ into a :class:`Factor` (or ``None``).
299+
300+ Accepts ``None``, a ``dict[symbol, group]`` mapping (auto-wrapped
301+ in :class:`StaticPerSymbol`), or any pre-existing :class:`Factor`.
302+ """
303+ if groups is None :
304+ return None
305+ if isinstance (groups , Factor ):
306+ return groups
307+ if isinstance (groups , dict ):
308+ # Local import to avoid an import cycle at module load: the
309+ # built-in factors module imports from this file.
310+ from .factors .builtin import StaticPerSymbol
311+ return StaticPerSymbol (groups )
312+ raise TypeError (
313+ f"Unsupported type for `groups`: { type (groups ).__name__ } . "
314+ f"Expected None, dict[str, Any], or Factor."
315+ )
316+
317+
275318class _Constant (Factor ):
276319 """A panel-aligned constant series. Window is 1 (no warmup needed)."""
277320
@@ -371,22 +414,36 @@ class _CrossSectionalTransform(Factor):
371414 Polars expression for the (possibly mask-nulled) factor values and
372415 returns the transformed expression. The base class handles mask
373416 application and per-``datetime`` grouping.
417+
418+ When ``groups`` is provided, statistics are computed within each
419+ ``(datetime, group)`` cell instead of across all symbols at a
420+ bar — enabling sector-neutral or otherwise group-relative
421+ transforms. ``groups`` may be a ``dict[symbol, group]`` (wrapped
422+ in :class:`StaticPerSymbol` automatically) or any :class:`Factor`
423+ returning a categorical value per row.
374424 """
375425
def __init__(
    self,
    base: Factor,
    mask: Optional["Filter"] = None,
    groups=None,
) -> None:
    """Wire up the base factor plus the optional mask and grouping.

    Args:
        base: Factor whose values are transformed.
        mask: Optional filter applied to the base values.
        groups: Optional grouping; normalised via ``_coerce_groups``
            (``None``, a ``dict``, or a :class:`Factor`).

    ``self.inputs`` ends up as the base's required columns followed by
    any extra columns the mask and then the grouping factor need, with
    no duplicates. ``self.window`` is raised to at least the required
    window of each of those dependencies.
    """
    super().__init__(window=base.required_window())
    self._base = base
    self._mask = mask
    self._groups = _coerce_groups(groups)

    columns = list(base.required_columns())
    # Mask first, then groups, so the column order is deterministic.
    for dependency in (mask, self._groups):
        if dependency is None:
            continue
        for name in dependency.required_columns():
            if name not in columns:
                columns.append(name)
        self.window = max(self.window, dependency.required_window())
    self.inputs = columns
391448
392449 def required_columns (self ) -> List [str ]:
@@ -398,6 +455,15 @@ def required_window(self) -> int:
def _transform_expr(self) -> pl.Expr:
    """Return the Polars expression implementing the transform.

    Abstract hook: concrete transforms override it. ``compute_panel``
    aliases the returned expression as ``"__out__"``; the subclasses
    in this module build it from the ``"__x__"`` column.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError  # pragma: no cover
400457
458+ def _group_keys (self ) -> List [str ]:
459+ """Return the columns to group by for the cross-sectional
460+ statistic. ``["datetime"]`` for the standard case,
461+ ``["datetime", "__group__"]`` when ``groups`` is set.
462+ """
463+ if self ._groups is None :
464+ return ["datetime" ]
465+ return ["datetime" , "__group__" ]
466+
401467 def compute_panel (self , panel : pl .DataFrame ) -> pl .Series :
402468 values = self ._base .evaluate (panel )
403469 df = panel .select (["datetime" , "symbol" ]).with_columns (
@@ -411,17 +477,21 @@ def compute_panel(self, panel: pl.DataFrame) -> pl.Series:
411477 .otherwise (None )
412478 .alias ("__x__" )
413479 )
480+ if self ._groups is not None :
481+ group_values = self ._groups .evaluate (panel )
482+ df = df .with_columns (group_values .alias ("__group__" ))
414483 df = df .with_columns (self ._transform_expr ().alias ("__out__" ))
415484 return df ["__out__" ]
416485
417486
418487class _Zscore (_CrossSectionalTransform ):
419- """Cross-sectional z-score per bar."""
488+ """Cross-sectional z-score per bar (optionally per group) ."""
420489
421490 def _transform_expr (self ) -> pl .Expr :
422491 x = pl .col ("__x__" )
423- mean = x .mean ().over ("datetime" )
424- std = x .std ().over ("datetime" )
492+ keys = self ._group_keys ()
493+ mean = x .mean ().over (keys )
494+ std = x .std ().over (keys )
425495 # If std is 0 or null, returning null is the safe choice (it
426496 # signals "no dispersion" rather than producing inf/NaN that
427497 # poisons downstream rolling stats).
@@ -433,11 +503,12 @@ def _transform_expr(self) -> pl.Expr:
433503
434504
class _Demean(_CrossSectionalTransform):
    """Remove the cross-sectional mean per bar, or per ``(bar, group)``
    cell when a grouping is configured."""

    def _transform_expr(self) -> pl.Expr:
        """Return ``x - mean(x)`` with the mean taken over the group keys."""
        values = pl.col("__x__")
        cell_mean = values.mean().over(self._group_keys())
        return values - cell_mean
441512
442513
443514class _Winsorize (_CrossSectionalTransform ):
0 commit comments