[FEAT] Add partition_by support for lag transforms (with global_/groupby fixes)#636
simonez-tuidi wants to merge 8 commits into
Conversation
nasaul left a comment
Code Review — PR #636: [FEAT] Add partition_by support for lag transforms (with global_/groupby fixes)
Overview
Adds a partition_by argument to lag transforms allowing a single transform to compute statistics over arbitrary partition keys (e.g., promo, brand) rather than strictly per-series windows. Supports three modes: local (per-series), groupby (grouped series), and global_ (all series). Also fixes a pre-existing correctness bug where bounded rolling windows in groupby/global_ mode collapsed multi-series timestamps via sum, producing wrong results. Fix switches to RANGE semantics for non-local modes.
Correctness
- The RANGE-semantics fix (keeping individual rows per observation instead of timestamp-level aggregation) is sound and well-motivated. However, there are two legitimate interpretations of "group rolling mean across series" and the PR only supports one — worth discussing (see Discussion section below).
- `local` mode behavior is unchanged — safe.
- `_RollingBase._compute_bucket_feature` correctly handles `min_samples`, `NaN` padding, and tied-timestamp broadcasting via the `np.unique` inverse index.
- `RollingMean._compute_bucket_feature`'s O(m log m) cumsum path is mathematically correct and consistent with the O(n×w) fallback.
- `RollingStd._window_stat` uses `ddof=1` — verified consistent with coreforecast's implementation.
- Backup/restore in `_backup` now includes `_partition_states` and `target_transforms` — both necessary for correctness.
Issues
1. assert in production code — core.py:484 (high)
```python
context = self._get_partition_context(step_df)
assert context is not None
```
`assert` statements are silently removed with `python -O`. Consider replacing with a proper guard:
```python
if context is None:
    raise RuntimeError(
        "_get_partition_context returned None despite non-empty _partition_states. "
        "This is a bug — please report it."
    )
```
2. Silent wrong results for `_Seasonal_RollingBase` + `partition_by` in non-local mode (high)
Seasonal rolling and `ExponentiallyWeightedMean` fall back to the position-based path when used with `partition_by` in `groupby`/`global_` mode. For bounded seasonal transforms this is semantically incorrect (the same timestamp-collapse issue the PR fixes for `_RollingBase`). Currently this produces wrong results with no indication to the user. Should emit a `UserWarning` at `fit_transform` time, or raise `NotImplementedError` until RANGE semantics are implemented for these classes.
3. Dead attribute _current_step_x (medium)
Set in `_get_features_for_next_step` (line 1584) and reset in `_predict_setup` (line 1640), but never read anywhere. Remove it unless it's intentionally exposed for external introspection (in which case, document it).
4. Redundant getattr defensive checks (medium)
`_partition_states` is unconditionally initialized in `__init__`, so `getattr(self, "_partition_states", {})` (lines 419, 479, 510, 1613) is misleading — it implies the attribute might not exist. Replace with direct `self._partition_states` access.
5. deepcopy on every predict timestep in _compute_partition_features (medium)
```python
fresh_tfm = copy.deepcopy(tfm)._set_core_tfm(tfm._get_configured_lag())
fresh_tfm.transform(core_ga)
updates[name] = fresh_tfm.update(core_ga)
```
A full `deepcopy` of the transform is made on every recursive predict step. The existing `_group_states` predict path avoids this pattern. For models with many partition transforms, this may noticeably affect predict latency.
6. Explicit pandas/polars split paths instead of narwhals (high)
The new partition code introduces `isinstance(data, pd.DataFrame)` branches throughout — `_add_bucket_id`, `_lookup_bucket_ids`, `_ensure_partition_bucket_ids`, `_get_partition_context`, the `_bucket_pos` computation, and the join-back in `_transform` all have dual implementations. The existing codebase abstracts these operations via `utilsforecast.processing` (ufp) or narwhals; the new code largely bypasses that. This doubles the maintenance surface for every operation and makes the new code inconsistent with the rest of the codebase. The partition-specific helpers should be rewritten using narwhals or routed through ufp where equivalent operations already exist.
7. Polars path casts partition cols to pl.String on every lookup (low)
In `_lookup_bucket_ids` (lines 167–173), all partition key columns are cast to `pl.String` on every call to avoid categorical cache mismatches. For integer partition columns this is an avoidable cast on the hot predict path. A targeted fix (only cast when the dtype differs) would be cleaner.
8. _has_aggregated_partition_tfms re-iterates all transforms on every call (low)
Called from `_check_aligned_ends`, `predict`, and `update`. Since transform configs are fixed post-fit, consider caching this as a bool at the end of `_fit`.
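A sketch of the suggested caching, on a hypothetical stub (class and attribute names are illustrative, not mlforecast internals):

```python
# Sketch: compute a derived flag once at fit time instead of re-scanning
# all transforms on every predict/update call.
class ForecasterStub:
    def __init__(self, transforms):
        self.transforms = transforms
        self._has_aggregated_partition_tfms = None  # filled in _fit

    def _fit(self):
        # Transform configs are fixed after fit, so this scan runs once.
        self._has_aggregated_partition_tfms = any(
            getattr(t, "partition_by", None) is not None
            and getattr(t, "mode", "local") != "local"
            for t in self.transforms
        )

    def predict(self):
        # Hot path reads a cached bool instead of iterating all transforms.
        return self._has_aggregated_partition_tfms
```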
9. target_transforms backup is unrelated to this PR (low)
The `target_transforms` deepcopy added to `_backup` is a genuine correctness fix but touches non-partition code paths. It changes backup/restore behavior for all callers. Worth calling out explicitly in the PR description.
Tests
- Coverage of fit/predict/update/cross-validation across pandas + polars is solid.
- Hand-computed reference values with detailed inline comments make the tests easy to audit.
Issues
- `test_partition_combine_lag_transform_predict`: only checks `col in features.columns` and `not features[col].isna().all()`. No numerical verification — this is a smoke test, not a correctness test. Should verify specific values similar to the other predict tests.
- Duplicate helpers: `_make_partition_series` and `_make_partition_future` are duplicated between `test_auto.py` and `test_forecast.py`. Move to `conftest.py`.
- Test comment mismatch in `test_global_partition_lag_transform`: comments describe position-based sequential logic (`pos=3: rolling_mean([2, 20])`) but `RollingMean` overrides `_compute_bucket_feature` with the RANGE path. Results happen to agree here, but the comments are misleading about the code path.
- No test for `_Seasonal_RollingBase` + `partition_by` non-local mode: either a correctness test or an explicit test verifying a warning/error is raised would prevent future confusion.
- `test_forecast.py` integration tests only assert shape and not-NA: both `test_partition_by_cross_validation_refit_false` and `test_transfer_learning_with_partition_by_group_transform` never verify predicted values. A stronger pattern: construct data where `y_T` equals the cross-series partition mean at `T-1`, fit a `LinearRegression(fit_intercept=False)` with `lags=[]` and only a `RollingMean(window=1, lag=1, global_=True, partition_by=[...])` transform, and assert predictions match targets within tolerance. If the feature flows through the pipeline correctly the model learns coefficient=1 and predictions reconstruct the target exactly. `RollingMean(window=1)` is the right choice here — `ExpandingMean` would require constructing `y_T` as the full prior expanding mean, which is harder to set analytically.
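The construction suggested above can be sketched outside mlforecast: the `RollingMean(window=1, lag=1, global_=True)` feature is hand-rolled as the cross-series mean at `T-1`, and the model is reduced to a no-intercept least-squares fit. This is a sketch of the test idea, not pipeline code:

```python
import numpy as np

# Two synthetic series; the feature a RollingMean(window=1, lag=1,
# global_=True) transform would emit is the cross-series mean at T-1.
rng = np.random.default_rng(0)
y_a = rng.uniform(1, 10, size=8)
y_b = rng.uniform(1, 10, size=8)

# Hand-rolled feature: mean over both series at the previous timestamp.
feat = (y_a[:-1] + y_b[:-1]) / 2

# Construct targets so y_T equals the feature exactly for both series.
target_a = feat.copy()
target_b = feat.copy()

# No-intercept linear fit; the learned coefficient should be ~1 and the
# predictions should reconstruct the targets exactly.
X = np.concatenate([feat, feat]).reshape(-1, 1)
y = np.concatenate([target_a, target_b])
coef, *_ = np.linalg.lstsq(X, y, rcond=None)
preds = X[:, 0] * coef[0]

assert np.allclose(coef[0], 1.0)
assert np.allclose(preds, y)
```

In the real test the feature would come out of the fitted pipeline, so the assertions also verify that the partition feature flows through fit and predict intact.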
Minor / Style
- `_normalize_groupby` → `_normalize_columns` rename is a clean, non-breaking refactor.
- `_dedupe_preserve_order` utility is clear and correctly placed at module level.
- `Combine.update_samples` property was moved (not removed) — the reorganization is fine.
- `noqa: ARG001/ARG005` additions in `test_auto.py` are correct lint cleanups.
- `_get_partition_tfms` mode dispatch logic (global > group > local) is easy to follow.
Discussion
RANGE semantics vs. aggregate-then-roll
The PR fixes non-local partition rolling by switching to RANGE semantics — at time T, collect all individual observations from matching series where ts ∈ [T - lag - window + 1, T - lag] and compute the statistic over those raw values. This is one valid interpretation, but there is a second equally valid one:
Aggregate-then-roll: first collapse all series at each timestamp to a single value (e.g., sum or mean), then roll over those per-timestamp aggregates. So if b=20 and c=15 are both at ds=2, they collapse to 17.5 (mean) or 35 (sum), and the window rolls over that aggregate.
Both semantics are useful depending on the problem:
- RANGE is natural when you want the statistic over the raw distribution of individual observations — e.g., estimating a population mean from individual transactions.
- Aggregate-then-roll is natural when you want to roll over a group-level signal — e.g., total promo sales per period, then a rolling average of those period totals.
The original (buggy) code happened to implement a broken version of aggregate-then-roll (sum collapse). Users who were relying on that structure — even if the values were wrong — now get silently different results with no migration path.
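To make the distinction concrete, a small numpy sketch (values illustrative, drawn from the examples in this thread) computing the same bounded rolling mean under each interpretation:

```python
import numpy as np

# Toy bucket: values 20 and 15 share ds=2, as in the PR's example.
ts = np.array([1, 2, 2])
vals = np.array([2.0, 20.0, 15.0])

# Statistic: rolling mean with window=2, lag=1, evaluated at T=3,
# i.e. the window covers ts in [1, 2].
in_window = (ts >= 1) & (ts <= 2)

# RANGE semantics (the PR's fix): statistic over every raw observation.
range_mean = vals[in_window].mean()  # mean(2, 20, 15) ~ 12.33

# Aggregate-then-roll with mean: collapse the tie to 17.5 first, then roll.
uniq = np.unique(ts)
per_ts_mean = np.array([vals[ts == t].mean() for t in uniq])  # [2.0, 17.5]
agg_mean_roll = per_ts_mean.mean()  # 9.75

# The pre-fix buggy behavior: collapse the tie with sum (20 + 15 -> 35),
# then roll over the synthetic per-timestamp values.
per_ts_sum = np.array([vals[ts == t].sum() for t in uniq])  # [2.0, 35.0]
buggy_sum_roll = per_ts_sum.mean()  # 18.5
```

All three interpretations give different values on the same data, which is why silently choosing one matters.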
A suggested addition would be an agg parameter on partition transforms:
```python
RollingMean(
    window=2,
    groupby=["brand"],
    partition_by=["promo"],
    agg=None,  # None = RANGE semantics (current behavior, default)
    # "sum" or "mean" = collapse same-timestamp obs before rolling
)
```
`agg=None` preserves the current fix. `agg="sum"` or `agg="mean"` enables the aggregate-then-roll path with correct semantics. This would give users explicit control over which interpretation they want rather than having one silently chosen for them.
Summary
The correctness fix (RANGE semantics) is solid and the feature design is well-structured. The test suite is thorough with detailed hand-verified expectations. Main items before merge:
| Priority | Issue |
|---|---|
| High | Replace `assert context is not None` with `raise` |
| High | Warn or raise on `_Seasonal_RollingBase` / `ExponentiallyWeightedMean` + `partition_by` in non-local mode |
| High | Rewrite partition helpers using narwhals / ufp instead of explicit pandas/polars splits |
| Medium | Remove dead `_current_step_x` attribute |
| Medium | Replace redundant `getattr(self, "_partition_states", {})` with `self._partition_states` |
| Medium | Investigate `deepcopy` cost per predict step in `_compute_partition_features` |
| Low | Move shared test helpers to `conftest.py` |
| Low | Strengthen `test_partition_combine_lag_transform_predict` with numerical assertions |
| Low | Fix test comments in `test_global_partition_lag_transform` to describe the RANGE path |
Thanks for this @simonez-tuidi, this is great work and a great contribution. I think we should address the discussion section first before implementing any changes.
---
Hi @nasaul, thanks for your thorough review. I'll address the issues you raised in separate commits; issue 5 might require quite a bit of reworking to allow incremental updates, which would improve latency but might be non-trivial. Regarding the discussion points, I would note that the review assumes that the previous implementation was part of the "established" code, instead of a first iteration not yet merged. I believe the previous logic was fundamentally flawed and would not serve much purpose in real-world applications. Let's take the example below, where two series
In the case of a sum aggregation, we would expect the result to be the sum of all the values that belong to the partition, in this case 80. In the case of a mean aggregation, the same logic would need to apply, as it is simply the previous sum divided by the number of observations, i.e. 16. If we first collapse the values at the timestamp level and then aggregate again, we are essentially introducing a double aggregation: first taking the mean at each timestamp, then taking the mean of all the aggregated values. I've just realized that this double-aggregation logic is what is currently implemented.

Now, back to the issue at hand for this PR. My argument is that these transformations should follow the same logic as SQL queries, and this is the equivalence I have in mind:
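A numeric sketch of the argument above; the timestamps and values here are invented, chosen only so the totals match the numbers in the comment (sum 80, flat mean 16):

```python
import numpy as np

# Hypothetical observations for one partition (illustrative values only).
ts = np.array([1, 1, 2, 2, 3])
vals = np.array([10.0, 30.0, 20.0, 15.0, 5.0])

# Partition-level sum: sum over all observations in the partition.
partition_sum = vals.sum()  # 80.0

# Partition-level mean: the same sum divided by the observation count.
flat_mean = vals.mean()  # 16.0

# Double aggregation: first collapse each timestamp to its mean,
# then take the mean of those per-timestamp aggregates.
per_ts = np.array([vals[ts == t].mean() for t in np.unique(ts)])  # [20, 17.5, 5]
double_agg_mean = per_ts.mean()  # ~14.17, which differs from 16
```

The flat mean and the double-aggregated mean disagree whenever timestamps carry different observation counts, which is the core of the objection.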
In all of the above, the transformation finds all the relevant records that belong to the window and takes the mean. In the Nixtla-verse, all transforms are already essentially a window function with partition by, like in the first example where the partition is defined over the series id. I think this design makes it pretty intuitive to know what to expect from each transform in all of the cases above.

Note: unfortunately the other test introduced in that PR is pretty uninformative, as there is only one series per group and therefore the results coincide with the non-grouped example.
---
@nasaul & @janrth I've expanded on the points above in a new issue #640 - keen to get your thoughts on this. I also got fact-checked by AI - my interpretation that the existing
---
Major changes in #641 will influence how this will be implemented and might require a partial or complete rework -> converting to draft for now |
Summary
Adds a `partition_by` argument to lag transforms so a single transform can operate over partitioned observation buckets across `local`, `groupby`, and `global_` modes, following from #587. Supersedes the WIP work from @janrth in #609 (branch `janrth:feature/partition_by_window_functions`) and fixes correctness bugs in the non-local modes that produced wrong results for bounded rolling windows.
Relationship to janrth's PR
This PR includes janrth's commits as the foundation (`e62aa03`, `6a26ef5`, `d22290a`, and main merges). On top, it adds:
- `0b735a8` — re-implementation of `groupby` and `global_` modes to fix incorrect rolling-window outputs when multiple series share a timestamp.
- `4591401` — merge of `main` (which now has `drop_auxiliary_columns` and `_initialize_lag_transform_states`) with conflict resolution and lint cleanup.

The intent is to land the feature in one shot rather than landing janrth's PR and following up with a fix.
Motivation
`partition_by` lets a lag transform aggregate over arbitrary partition keys (promo flag, store cluster, etc.) instead of being restricted to per-series windows. The existing implementation correctly handled `local` mode but miscomputed bounded rolling windows in `groupby` and `global_` mode whenever two series shared a timestamp inside a bucket — the pre-fix code aggregated the target by timestamp via `group_by_agg(... "sum")`, collapsing the per-observation cardinality the rolling window depends on.
What `partition_by` does
A lag transform with `partition_by=...` defines a bucket key. Three modes control how series are folded into buckets:
- `local`: `(id, *partition_cols)`
- `groupby`: `(*group_cols, *partition_cols)`
- `global_`: `(*partition_cols,)`

Bucket state persists across fit / predict / cross-validation / refit, and is used uniformly with the existing local / global / group-aggregated paths.
What was broken in `groupby`/`global_`
Pre-fix, `_build_partition_bucket_df` aggregated the target by `(key_cols, time_col)` with `sum`. A rolling window then ran over one synthetic value per timestamp, so e.g. two series in the same bucket at `ds=2` with values `20` and `15` were collapsed to a single `35`. Bounded rolling stats (`RollingMean(window=2)`, `RollingStd`, `RollingMin`, `RollingMax`, `RollingQuantile`) returned values that did not match the SQL `RANGE BETWEEN ... PRECEDING` semantics users would expect. `local` mode was unaffected because exactly one series maps to each bucket.
Fix approach
Keep individual rows; switch to RANGE semantics for non-local modes.
In `mlforecast/core.py`:
- `_build_partition_bucket_df` no longer aggregates for `groupby`/`global_`. It keeps one row per original observation, sorted by `key_cols + [time_col, id_col]`, and adds a sequential `_bucket_pos` per bucket so the existing `process_df` machinery can run.
- New helpers (`_get_partition_key_cols`, `_add_bucket_id`, `_lookup_bucket_ids`, `_get_partition_context`, `_ensure_partition_bucket_ids`, `_compute_partition_features`, `_update_partition_states`) maintain bucket-id ↔ row mappings across fit / predict / update, including dynamic allocation of new bucket ids for unseen partition keys at predict time and Polars categorical-cast safety on join-back.
- `join_cols` switches from `key_cols + [time_col]` (local) to `[id_col, time_col]` (non-local) so each original series row recovers its own feature value.
In `mlforecast/lag_transforms.py`:
- New hook `_BaseLagTransform._compute_bucket_feature(bid_arr, ts_arr, y_arr)` returning `Optional[np.ndarray]`. Default `None` falls back to the position-based `GroupedArray` transform, which is correct for unbounded (expanding) transforms because position-based expanding over timestamp-sorted observations equals timestamp-based expanding.
- `_RollingBase._compute_bucket_feature` overrides with a RANGE-based loop: for each row at `(bucket_id, T)`, collect observations with `ts ∈ [T - lag - window_size + 1, T - lag]`, apply `_window_stat(vals)`, and broadcast back to tied timestamps via the `np.unique` inverse. Subclasses (`RollingStd`, `RollingMin`, `RollingMax`, `RollingQuantile`) implement `_window_stat`.
- `RollingMean._compute_bucket_feature` overrides again with an O(m log m) cumulative-sum / `np.searchsorted` fast path (m = unique timestamps in bucket).
- `_Seasonal_RollingBase` and `ExponentiallyWeightedMean` deliberately keep the default fallback for now — semantics for those under partition modes are out of scope for this PR.
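As a rough reference for the RANGE-based loop described above, a simplified single-bucket sketch (not the PR's code; the real implementation iterates per bucket and adds the cumsum fast path for `RollingMean`):

```python
import numpy as np

# Simplified RANGE rolling statistic for one bucket: for each unique
# timestamp T, gather raw observations with ts in
# [T - lag - window + 1, T - lag], apply the statistic, and broadcast the
# result back to every row tied at T via np.unique's inverse index.
def range_rolling(ts, vals, window, lag, stat=np.mean, min_samples=1):
    uniq, inv = np.unique(ts, return_inverse=True)
    out_per_ts = np.full(uniq.shape, np.nan)
    for i, t in enumerate(uniq):
        lo, hi = t - lag - window + 1, t - lag
        mask = (ts >= lo) & (ts <= hi)
        if mask.sum() >= min_samples:
            out_per_ts[i] = stat(vals[mask])
    # Rows sharing a timestamp receive the same feature value.
    return out_per_ts[inv]

ts = np.array([1, 2, 2, 3])
vals = np.array([2.0, 20.0, 15.0, 7.0])
out = range_rolling(ts, vals, window=2, lag=1)
# At T=3 the window is ts in [1, 2]: mean(2, 20, 15) ~ 12.33
```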
Diff scope (vs `upstream/main`)
Tests
`tests/test_core.py` adds multi-series partition fixtures (`_make_partition_df(include_brand=True)` and `_make_partition_future_df(include_brand=True)`, introducing series `c` sharing a brand with `b`) and exercises:
- `test_group_partition_lag_transform` and `_predict` — bounded rolling on multi-series buckets, including expected NaNs where `min_samples` is not satisfied.
- `test_global_partition_lag_transform` and `_predict` — same-timestamp tie-breaking against the global bucket.
- `test_aggregated_partition_lag_transform_update` — refit / update path with new and existing buckets.
Backward compatibility
- New opt-in `partition_by` parameter.
- `local` mode behavior is unchanged.
- Pre-fix `groupby`/`global_` outputs for bounded rolling windows were incorrect; numeric outputs change accordingly. Any user who pinned expectations on those values needs to re-baseline.
Verification
- `uv run pytest tests/test_core.py -k partition` — partition unit tests.
- `uv run pytest tests/test_forecast.py tests/test_auto.py` — surrounding regressions touched by the merge with `main`.
- `uv run ruff check mlforecast tests` — lint clean.
- Verified `RANGE BETWEEN` semantics on the multi-series fixture by comparing `RollingMean(window=2)` outputs to a hand-computed reference in `tests/test_core.py::test_group_partition_lag_transform`.