@@ -124,6 +124,76 @@ def bottom(self, n: int) -> "Filter":
124124 from .filter import _BottomN
125125 return _BottomN (self , n )
126126
127+ # ------------------------------------------------------------------ #
128+ # Cross-sectional transforms (Phase 2 / #502)
129+ # ------------------------------------------------------------------ #
130+ def zscore (self , mask : Optional ["Filter" ] = None ) -> "Factor" :
131+ """Cross-sectional z-score within each timestamp.
132+
133+ Returns ``(x - mean) / std`` computed over the symbols at each
134+ bar. With ``mask``, symbols outside the mask are excluded from
135+ the mean/std and receive ``null`` in the output.
136+ """
137+ return _Zscore (self , mask = mask )
138+
139+ def demean (self , mask : Optional ["Filter" ] = None ) -> "Factor" :
140+ """Cross-sectional mean removal within each timestamp.
141+
142+ Returns ``x - mean(x)`` computed over the symbols at each bar.
143+ With ``mask``, symbols outside the mask are excluded from the
144+ mean and receive ``null`` in the output.
145+ """
146+ return _Demean (self , mask = mask )
147+
148+ def winsorize (
149+ self ,
150+ lower : float = 0.01 ,
151+ upper : float = 0.99 ,
152+ mask : Optional ["Filter" ] = None ,
153+ ) -> "Factor" :
154+ """Cross-sectional winsorisation within each timestamp.
155+
156+ Clips values below the ``lower`` quantile and above the
157+ ``upper`` quantile (computed per bar). Both bounds are in
158+ ``[0, 1]`` with ``lower < upper``.
159+ """
160+ if not (0.0 <= lower < upper <= 1.0 ):
161+ raise ValueError (
162+ f"winsorize requires 0 <= lower < upper <= 1, "
163+ f"got lower={ lower } , upper={ upper } "
164+ )
165+ return _Winsorize (self , lower = lower , upper = upper , mask = mask )
166+
167+ # ------------------------------------------------------------------ #
168+ # Arithmetic (Phase 2 / #502) — composes Factors into expression trees
169+ # ------------------------------------------------------------------ #
170+ def __neg__ (self ) -> "Factor" :
171+ return _UnaryOp (self , op = "neg" )
172+
173+ def __add__ (self , other ) -> "Factor" :
174+ return _BinaryOp (self , other , op = "add" )
175+
176+ def __radd__ (self , other ) -> "Factor" :
177+ return _BinaryOp (other , self , op = "add" )
178+
179+ def __sub__ (self , other ) -> "Factor" :
180+ return _BinaryOp (self , other , op = "sub" )
181+
182+ def __rsub__ (self , other ) -> "Factor" :
183+ return _BinaryOp (other , self , op = "sub" )
184+
185+ def __mul__ (self , other ) -> "Factor" :
186+ return _BinaryOp (self , other , op = "mul" )
187+
188+ def __rmul__ (self , other ) -> "Factor" :
189+ return _BinaryOp (other , self , op = "mul" )
190+
191+ def __truediv__ (self , other ) -> "Factor" :
192+ return _BinaryOp (self , other , op = "div" )
193+
194+ def __rtruediv__ (self , other ) -> "Factor" :
195+ return _BinaryOp (other , self , op = "div" )
196+
127197 # ------------------------------------------------------------------ #
128198 # Repr
129199 # ------------------------------------------------------------------ #
@@ -180,3 +250,218 @@ def compute_panel(self, panel: pl.DataFrame) -> pl.Series:
180250 .alias ("__rank__" )
181251 )
182252 return ranked ["__rank__" ]
253+
254+
255+ # --------------------------------------------------------------------- #
256+ # Phase 2 expression-tree wrappers (#502): arithmetic + cross-sectional
257+ # transforms. Each wrapper composes existing factors into a new factor
258+ # without losing the per-evaluation cache (they call ``evaluate`` on
259+ # their children, not ``compute_panel``).
260+ # --------------------------------------------------------------------- #
261+ def _coerce_operand (operand ) -> "Factor" :
262+ """Wrap a scalar operand in a :class:`_Constant` so binary ops
263+ can treat ``factor + 1`` and ``factor + other_factor`` uniformly.
264+ """
265+ if isinstance (operand , Factor ):
266+ return operand
267+ if isinstance (operand , (int , float )):
268+ return _Constant (float (operand ))
269+ raise TypeError (
270+ f"Unsupported operand type for Factor arithmetic: "
271+ f"{ type (operand ).__name__ } "
272+ )
273+
274+
275+ class _Constant (Factor ):
276+ """A panel-aligned constant series. Window is 1 (no warmup needed)."""
277+
278+ inputs : List [str ] = []
279+
280+ def __init__ (self , value : float ) -> None :
281+ super ().__init__ (window = 1 )
282+ self ._value = float (value )
283+
284+ def required_columns (self ) -> List [str ]:
285+ return []
286+
287+ def compute_panel (self , panel : pl .DataFrame ) -> pl .Series :
288+ return pl .Series (
289+ "__const__" , [self ._value ] * panel .height , dtype = pl .Float64
290+ )
291+
292+
293+ class _UnaryOp (Factor ):
294+ """Element-wise unary op (currently only ``neg``)."""
295+
296+ def __init__ (self , base : Factor , op : str ) -> None :
297+ super ().__init__ (window = base .required_window ())
298+ self ._base = base
299+ self ._op = op
300+ self .inputs = list (base .required_columns ())
301+
302+ def required_columns (self ) -> List [str ]:
303+ return list (self .inputs )
304+
305+ def required_window (self ) -> int :
306+ return int (self .window )
307+
308+ def compute_panel (self , panel : pl .DataFrame ) -> pl .Series :
309+ values = self ._base .evaluate (panel )
310+ if self ._op == "neg" :
311+ return (- values ).rename ("__unary__" )
312+ raise ValueError (f"Unknown unary op: { self ._op } " ) # pragma: no cover
313+
314+
315+ class _BinaryOp (Factor ):
316+ """Element-wise binary arithmetic between two ``Factor``s.
317+
318+ Either operand may be a scalar; it is auto-wrapped in
319+ :class:`_Constant`.
320+ """
321+
322+ def __init__ (self , left , right , op : str ) -> None :
323+ left_f = _coerce_operand (left )
324+ right_f = _coerce_operand (right )
325+ super ().__init__ (
326+ window = max (
327+ left_f .required_window (), right_f .required_window ()
328+ )
329+ )
330+ self ._left = left_f
331+ self ._right = right_f
332+ self ._op = op
333+ cols : List [str ] = list (left_f .required_columns ())
334+ for c in right_f .required_columns ():
335+ if c not in cols :
336+ cols .append (c )
337+ self .inputs = cols
338+
339+ def required_columns (self ) -> List [str ]:
340+ return list (self .inputs )
341+
342+ def required_window (self ) -> int :
343+ return int (self .window )
344+
345+ def compute_panel (self , panel : pl .DataFrame ) -> pl .Series :
346+ left = self ._left .evaluate (panel )
347+ right = self ._right .evaluate (panel )
348+ if self ._op == "add" :
349+ out = left + right
350+ elif self ._op == "sub" :
351+ out = left - right
352+ elif self ._op == "mul" :
353+ out = left * right
354+ elif self ._op == "div" :
355+ # Polars naturally yields nulls when the divisor is null;
356+ # division by zero produces inf which we leave as-is so
357+ # callers can decide what to do (e.g. ``zscore`` will
358+ # propagate inf and downstream filters can drop it).
359+ out = left / right
360+ else :
361+ raise ValueError ( # pragma: no cover
362+ f"Unknown binary op: { self ._op } "
363+ )
364+ return out .rename ("__binop__" )
365+
366+
367+ class _CrossSectionalTransform (Factor ):
368+ """Common base for per-bar transforms (zscore / demean / winsorize).
369+
370+ Subclasses implement :meth:`_transform_per_bar` which receives a
371+ Polars expression for the (possibly mask-nulled) factor values and
372+ returns the transformed expression. The base class handles mask
373+ application and per-``datetime`` grouping.
374+ """
375+
376+ def __init__ (
377+ self ,
378+ base : Factor ,
379+ mask : Optional ["Filter" ] = None ,
380+ ) -> None :
381+ super ().__init__ (window = base .required_window ())
382+ self ._base = base
383+ self ._mask = mask
384+ cols = list (base .required_columns ())
385+ if mask is not None :
386+ for c in mask .required_columns ():
387+ if c not in cols :
388+ cols .append (c )
389+ self .window = max (self .window , mask .required_window ())
390+ self .inputs = cols
391+
392+ def required_columns (self ) -> List [str ]:
393+ return list (self .inputs )
394+
395+ def required_window (self ) -> int :
396+ return int (self .window )
397+
398+ def _transform_expr (self ) -> pl .Expr :
399+ raise NotImplementedError # pragma: no cover
400+
401+ def compute_panel (self , panel : pl .DataFrame ) -> pl .Series :
402+ values = self ._base .evaluate (panel )
403+ df = panel .select (["datetime" , "symbol" ]).with_columns (
404+ values .alias ("__x__" )
405+ )
406+ if self ._mask is not None :
407+ mask_values = self ._mask .evaluate (panel )
408+ df = df .with_columns (
409+ pl .when (mask_values )
410+ .then (pl .col ("__x__" ))
411+ .otherwise (None )
412+ .alias ("__x__" )
413+ )
414+ df = df .with_columns (self ._transform_expr ().alias ("__out__" ))
415+ return df ["__out__" ]
416+
417+
418+ class _Zscore (_CrossSectionalTransform ):
419+ """Cross-sectional z-score per bar."""
420+
421+ def _transform_expr (self ) -> pl .Expr :
422+ x = pl .col ("__x__" )
423+ mean = x .mean ().over ("datetime" )
424+ std = x .std ().over ("datetime" )
425+ # If std is 0 or null, returning null is the safe choice (it
426+ # signals "no dispersion" rather than producing inf/NaN that
427+ # poisons downstream rolling stats).
428+ return (
429+ pl .when ((std == 0 ) | std .is_null ())
430+ .then (None )
431+ .otherwise ((x - mean ) / std )
432+ )
433+
434+
435+ class _Demean (_CrossSectionalTransform ):
436+ """Cross-sectional mean removal per bar."""
437+
438+ def _transform_expr (self ) -> pl .Expr :
439+ x = pl .col ("__x__" )
440+ return x - x .mean ().over ("datetime" )
441+
442+
443+ class _Winsorize (_CrossSectionalTransform ):
444+ """Cross-sectional clip-to-quantiles per bar."""
445+
446+ def __init__ (
447+ self ,
448+ base : Factor ,
449+ lower : float ,
450+ upper : float ,
451+ mask : Optional ["Filter" ] = None ,
452+ ) -> None :
453+ super ().__init__ (base = base , mask = mask )
454+ self ._lower = float (lower )
455+ self ._upper = float (upper )
456+
457+ def _transform_expr (self ) -> pl .Expr :
458+ x = pl .col ("__x__" )
459+ lo = x .quantile (self ._lower ).over ("datetime" )
460+ hi = x .quantile (self ._upper ).over ("datetime" )
461+ return (
462+ pl .when (x < lo )
463+ .then (lo )
464+ .when (x > hi )
465+ .then (hi )
466+ .otherwise (x )
467+ )
0 commit comments