
[FEAT] Add partition_by support for lag transforms (with global_/groupby fixes)#636

Draft
simonez-tuidi wants to merge 8 commits into Nixtla:main from simonez-tuidi:feature/partition_by_window_fixes

Conversation

@simonez-tuidi
Contributor

@simonez-tuidi simonez-tuidi commented Apr 30, 2026

Summary

Adds a partition_by argument to lag transforms so a single transform can
operate over partitioned observation buckets across local, groupby, and
global_ modes, following from #587. It supersedes the WIP work from @janrth in #609 (branch
janrth:feature/partition_by_window_functions) and fixes correctness bugs in the
non-local modes that produced wrong results for bounded rolling windows.

Relationship to janrth's PR

This PR includes janrth's commits as the foundation
(e62aa03, 6a26ef5, d22290a and main merges). On top, it adds:

  • 0b735a8 — re-implementation of groupby and global_ modes to fix
    incorrect rolling-window outputs when multiple series share a timestamp.
  • 4591401 — merge of main (which now has drop_auxiliary_columns and
    _initialize_lag_transform_states) with conflict resolution and lint
    cleanup.

The intent is to land the feature in one shot rather than landing janrth's
PR and following up with a fix.

Motivation

partition_by lets a lag transform aggregate over arbitrary partition keys
(promo flag, store cluster, etc.) instead of being restricted to per-series
windows. The existing implementation correctly handled local mode but
miscomputed bounded rolling windows in groupby and global_ mode whenever
two series shared a timestamp inside a bucket — the pre-fix code aggregated
the target by timestamp via group_by_agg(... "sum"), collapsing the
per-observation cardinality the rolling window depends on.

What partition_by does

A lag transform with partition_by=... defines a bucket key. Three modes
control how series are folded into buckets:

Mode    | Bucket key                     | Rows per bucket
local   | (id, *partition_cols)          | one series, one row per ds
groupby | (*group_cols, *partition_cols) | many series, possibly multiple rows per ds
global_ | (*partition_cols,)             | all matching series, possibly multiple rows per ds

Bucket state persists across fit / predict / cross-validation / refit, and is
used uniformly with the existing local / global / group-aggregated paths.
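The bucket keys in the table above can be sketched concretely. A minimal pandas illustration (column names are only examples, not the PR's fixtures):

```python
import pandas as pd

# Tiny frame with one grouping column (brand) and one partition column (promo).
df = pd.DataFrame({
    "unique_id": ["a", "a", "b", "b"],
    "brand":     ["x", "x", "x", "x"],
    "promo":     [0, 1, 0, 1],
})

# Bucket keys per mode, mirroring the table above.
keys = {
    "local":   ["unique_id", "promo"],   # (id, *partition_cols)
    "groupby": ["brand", "promo"],       # (*group_cols, *partition_cols)
    "global_": ["promo"],                # (*partition_cols,)
}
n_buckets = {mode: df.groupby(cols).ngroups for mode, cols in keys.items()}
print(n_buckets)  # {'local': 4, 'groupby': 2, 'global_': 2}
```

Note how groupby and global_ fold multiple series into each bucket, which is exactly why a bucket can hold several rows per ds.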

What was broken in groupby / global_

Pre-fix _build_partition_bucket_df aggregated the target by
(key_cols, time_col) with sum. A rolling window then ran over one
synthetic value per timestamp, so e.g. two series in the same bucket at
ds=2 with values 20 and 15 were collapsed to a single 35. Bounded
rolling stats (RollingMean(window=2), RollingStd, RollingMin,
RollingMax, RollingQuantile) returned values that did not match the
SQL RANGE BETWEEN ... PRECEDING semantics users would expect.

local mode was unaffected because exactly one series maps to each bucket.
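To make the collapse concrete, here is a small pandas sketch. It follows the 20/15 example above with assumed values at ds=1, and ignores the lag offset for brevity:

```python
import pandas as pd

# Two series in one bucket; both have an observation at ds=2 (values 20 and 15).
obs = pd.DataFrame({
    "unique_id": ["b", "c", "b", "c"],
    "ds":        [1, 1, 2, 2],
    "y":         [10.0, 12.0, 20.0, 15.0],
})

# Pre-fix behavior: collapse same-timestamp observations with sum, then roll.
collapsed = obs.groupby("ds")["y"].sum()   # ds=1 -> 22.0, ds=2 -> 35.0
buggy_mean = collapsed.mean()              # mean([22, 35]) = 28.5

# RANGE semantics: the window covers all raw observations with ds in [1, 2].
range_mean = obs.loc[obs["ds"].between(1, 2), "y"].mean()  # 57 / 4 = 14.25
```

The sum collapse both inflates the values and halves the observation count, so every bounded statistic downstream is computed over the wrong sample.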

Fix approach

Keep individual rows; switch to RANGE semantics for non-local modes.

In mlforecast/core.py:

  • _build_partition_bucket_df no longer aggregates for groupby / global_.
    It keeps one row per original observation, sorted by
    key_cols + [time_col, id_col], and adds a sequential _bucket_pos
    per bucket so the existing process_df machinery can run.
  • New helpers (_get_partition_key_cols, _add_bucket_id,
    _lookup_bucket_ids, _get_partition_context,
    _ensure_partition_bucket_ids, _compute_partition_features,
    _update_partition_states) maintain bucket-id ↔ row mappings across
    fit / predict / update, including dynamic allocation of new bucket ids
    for unseen partition keys at predict time and Polars categorical-cast
    safety on join-back.
  • join_cols switches from key_cols + [time_col] (local) to
    [id_col, time_col] (non-local) so each original series row recovers
    its own feature value.

In mlforecast/lag_transforms.py:

  • New hook _BaseLagTransform._compute_bucket_feature(bid_arr, ts_arr, y_arr)
    returning Optional[np.ndarray]. Default None falls back to the
    position-based GroupedArray transform, which is correct for unbounded
    (expanding) transforms because position-based expanding over
    timestamp-sorted observations equals timestamp-based expanding.
  • _RollingBase._compute_bucket_feature overrides with a RANGE-based loop:
    for each row at (bucket_id, T) collect observations with
    ts ∈ [T - lag - window_size + 1, T - lag], apply _window_stat(vals),
    broadcast back to tied timestamps via np.unique inverse. Subclasses
    (RollingStd, RollingMin, RollingMax, RollingQuantile) implement
    _window_stat.
  • RollingMean._compute_bucket_feature overrides again with an
    O(m log m) cumulative-sum / np.searchsorted fast path
    (m = unique timestamps in bucket).
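A self-contained sketch of such a cumulative-sum / searchsorted RANGE rolling mean for one bucket (function name and exact NaN conventions are illustrative, not the PR's actual code):

```python
import numpy as np

def range_rolling_mean(ts, y, lag, window, min_samples=1):
    """RANGE-based rolling mean over one bucket: at time T, average all
    observations with ts in [T - lag - window + 1, T - lag].
    O(m log m) in m = unique timestamps, via prefix sums + searchsorted."""
    order = np.argsort(ts, kind="stable")
    ts_s, y_s = ts[order], y[order]
    uniq, inv = np.unique(ts_s, return_inverse=True)
    # Per-unique-timestamp sums and counts, then prefix sums with a leading 0.
    sums = np.zeros(uniq.size)
    np.add.at(sums, inv, y_s)
    counts = np.bincount(inv, minlength=uniq.size).astype(float)
    csum = np.concatenate([[0.0], np.cumsum(sums)])
    ccnt = np.concatenate([[0.0], np.cumsum(counts)])
    # Window bounds per unique timestamp, located by binary search.
    lo = np.searchsorted(uniq, uniq - lag - window + 1, side="left")
    hi = np.searchsorted(uniq, uniq - lag, side="right")
    n = ccnt[hi] - ccnt[lo]
    with np.errstate(invalid="ignore", divide="ignore"):
        mean = np.where(n >= min_samples, (csum[hi] - csum[lo]) / n, np.nan)
    # Broadcast back to tied timestamps, then restore the caller's row order.
    res = np.empty(ts.size)
    res[order] = mean[inv]
    return res

r = range_rolling_mean(np.array([1, 1, 2, 2]),
                       np.array([10.0, 12.0, 20.0, 15.0]),
                       lag=1, window=2)
# r -> [nan, nan, 11.0, 11.0]: at ds=2 the window is ts in [0, 1],
# i.e. both ds=1 observations, mean([10, 12]) = 11.
```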

_Seasonal_RollingBase and ExponentiallyWeightedMean deliberately keep the
default fallback for now — semantics for those under partition modes are
out of scope for this PR.

Diff scope (vs upstream/main)

mlforecast/core.py           |  +497
mlforecast/lag_transforms.py |  +236
tests/test_core.py           |  +568
tests/test_forecast.py       |   +78
tests/test_auto.py           |   +66

Tests

tests/test_core.py adds multi-series partition fixtures
(_make_partition_df(include_brand=True) and
_make_partition_future_df(include_brand=True), introducing series c
sharing a brand with b) and exercises:

  • test_group_partition_lag_transform and _predict — bounded rolling on
    multi-series buckets, including expected NaNs where min_samples is not
    satisfied.
  • test_global_partition_lag_transform and _predict — same-timestamp
    tie-breaking against the global bucket.
  • test_aggregated_partition_lag_transform_update — refit / update path
    with new and existing buckets.

Backward compatibility

  • No public API changes beyond the existing partition_by parameter.
  • local mode behavior is unchanged.
  • Pre-fix groupby / global_ outputs for bounded rolling windows were
    incorrect; numeric outputs change accordingly. Any user who pinned
    expectations on those values needs to re-baseline.

Verification

  1. uv run pytest tests/test_core.py -k partition — partition unit tests.
  2. uv run pytest tests/test_forecast.py tests/test_auto.py — surrounding
    regressions touched by the merge with main.
  3. uv run ruff check mlforecast tests — lint clean.
  4. Spot-check parity with SQL RANGE BETWEEN semantics on the multi-series
    fixture by comparing RollingMean(window=2) outputs to a hand-computed
    reference in tests/test_core.py::test_group_partition_lag_transform.

@codspeed-hq

codspeed-hq Bot commented Apr 30, 2026

Merging this PR will not alter performance

✅ 12 untouched benchmarks


Comparing simonez-tuidi:feature/partition_by_window_fixes (4591401) with main (825b7c9)


Contributor

@nasaul nasaul left a comment


Code Review — PR #636: [FEAT] Add partition_by support for lag transforms (with global_/groupby fixes)

Overview

Adds a partition_by argument to lag transforms allowing a single transform to compute statistics over arbitrary partition keys (e.g., promo, brand) rather than strictly per-series windows. Supports three modes: local (per-series), groupby (grouped series), and global_ (all series). Also fixes a pre-existing correctness bug where bounded rolling windows in groupby/global_ mode collapsed multi-series timestamps via sum, producing wrong results. Fix switches to RANGE semantics for non-local modes.


Correctness

  • The RANGE-semantics fix (keeping individual rows per observation instead of timestamp-level aggregation) is sound and well-motivated. However, there are two legitimate interpretations of "group rolling mean across series" and the PR only supports one — worth discussing (see Discussion section below).
  • local mode behavior is unchanged — safe.
  • _RollingBase._compute_bucket_feature correctly handles min_samples, NaN padding, and tied-timestamp broadcasting via np.unique inverse index.
  • RollingMean._compute_bucket_feature O(m log m) cumsum path is mathematically correct and consistent with the O(n×w) fallback.
  • RollingStd._window_stat uses ddof=1 — verified consistent with coreforecast's implementation.
  • Backup/restore in _backup now includes _partition_states and target_transforms — both necessary for correctness.

Issues

1. assert in production code — core.py:484 (high)

context = self._get_partition_context(step_df)
assert context is not None

assert statements are silently removed with python -O. Consider replacing with a proper guard:

if context is None:
    raise RuntimeError(
        "_get_partition_context returned None despite non-empty _partition_states. "
        "This is a bug — please report it."
    )

2. Silent wrong results for _Seasonal_RollingBase + partition_by in non-local mode (high)

Seasonal rolling and ExponentiallyWeightedMean fall back to the position-based path when used with partition_by in group/global_ mode. For bounded seasonal transforms this is semantically incorrect (same timestamp-collapse issue the PR fixes for _RollingBase). Currently produces wrong results with no indication to the user. Should emit a UserWarning at fit_transform time, or raise NotImplementedError until RANGE semantics are implemented for these classes.

3. Dead attribute _current_step_x (medium)

Set in _get_features_for_next_step (line 1584) and reset in _predict_setup (line 1640), but never read anywhere. Remove it unless it's intentionally exposed for external introspection (in which case, document it).

4. Redundant getattr defensive checks (medium)

_partition_states is unconditionally initialized in __init__, so getattr(self, "_partition_states", {}) (lines 419, 479, 510, 1613) is misleading — it implies the attribute might not exist. Replace with direct self._partition_states access.

5. deepcopy on every predict timestep in _compute_partition_features (medium)

fresh_tfm = copy.deepcopy(tfm)._set_core_tfm(tfm._get_configured_lag())
fresh_tfm.transform(core_ga)
updates[name] = fresh_tfm.update(core_ga)

A full deepcopy of the transform is made on every recursive predict step. The existing _group_states predict path avoids this pattern. For models with many partition transforms, this may noticeably affect predict latency.

6. Explicit pandas/polars split paths instead of narwhals (high)

The new partition code introduces isinstance(data, pd.DataFrame) branches throughout — _add_bucket_id, _lookup_bucket_ids, _ensure_partition_bucket_ids, _get_partition_context, _bucket_pos computation, and the join-back in _transform all have dual implementations. The existing codebase abstracts these operations via utilsforecast.processing (ufp) or narwhals; the new code largely bypasses that. This doubles the maintenance surface for every operation and makes the new code inconsistent with the rest of the codebase. The partition-specific helpers should be rewritten using narwhals or routed through ufp where equivalent operations already exist.

7. Polars path casts partition cols to pl.String on every lookup (low)

In _lookup_bucket_ids (lines 167–173), all partition key columns are cast to pl.String on every call to avoid categorical cache mismatches. For integer partition columns this is an avoidable cast on the hot predict path. A targeted fix (only cast when dtype differs) would be cleaner.

8. _has_aggregated_partition_tfms re-iterates all transforms on every call (low)

Called from _check_aligned_ends, predict, and update. Since transform configs are fixed post-fit, consider caching this as a bool at the end of _fit.

9. target_transforms backup is unrelated to this PR (low)

The target_transforms deepcopy added to _backup is a genuine correctness fix but touches non-partition code paths. It changes backup/restore behavior for all callers. Worth calling out explicitly in the PR description.


Tests

  • Coverage of fit/predict/update/cross-validation across pandas + polars is solid.
  • Hand-computed reference values with detailed inline comments make the tests easy to audit.

Issues

  • test_partition_combine_lag_transform_predict: Only checks col in features.columns and not features[col].isna().all(). No numerical verification — this is a smoke test, not a correctness test. Should verify specific values similar to the other predict tests.

  • Duplicate helpers: _make_partition_series and _make_partition_future are duplicated between test_auto.py and test_forecast.py. Move to conftest.py.

  • Test comment mismatch in test_global_partition_lag_transform: Comments describe position-based sequential logic (pos=3: rolling_mean([2, 20])) but RollingMean overrides _compute_bucket_feature with the RANGE path. Results happen to agree here, but the comments are misleading about the code path.

  • No test for _Seasonal_RollingBase + partition_by non-local mode: Either a correctness test or an explicit test verifying a warning/error is raised would prevent future confusion.

  • test_forecast.py integration tests only assert shape and not-NA: Both test_partition_by_cross_validation_refit_false and test_transfer_learning_with_partition_by_group_transform never verify predicted values. A stronger pattern: construct data where y_T equals the cross-series partition mean at T-1, fit a LinearRegression(fit_intercept=False) with lags=[] and only a RollingMean(window=1, lag=1, global_=True, partition_by=[...]) transform, and assert predictions match targets within tolerance. If the feature flows through the pipeline correctly the model learns coefficient=1 and predictions reconstruct the target exactly. RollingMean(window=1) is the right choice here — ExpandingMean would require constructing y_T as the full prior expanding mean which is harder to set analytically.
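The identity behind this test idea can be checked outside mlforecast. Below is a standalone sketch with a synthetic feature standing in for the rolling-mean column; everything here is illustrative, not the proposed test itself:

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)

# Stand-in for the RollingMean(window=1, lag=1, ...) feature: the
# cross-series partition mean at T-1.
partition_mean_prev = rng.uniform(10, 20, size=100)
# Construct the target so that y_T equals that feature exactly.
y = partition_mean_prev.copy()

model = LinearRegression(fit_intercept=False)
model.fit(partition_mean_prev.reshape(-1, 1), y)
pred = model.predict(partition_mean_prev.reshape(-1, 1))

# The model learns coefficient 1 and reconstructs the target exactly,
# so any pipeline bug that distorts the feature breaks the assertion.
assert np.isclose(model.coef_[0], 1.0)
assert np.allclose(pred, y)
```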


Minor / Style

  • _normalize_groupby_normalize_columns rename is a clean, non-breaking refactor.
  • _dedupe_preserve_order utility is clear and correctly placed at module level.
  • Combine.update_samples property was moved (not removed) — the reorganization is fine.
  • noqa: ARG001/ARG005 additions in test_auto.py are correct lint cleanups.
  • _get_partition_tfms mode dispatch logic (global > group > local) is easy to follow.

Discussion

RANGE semantics vs. aggregate-then-roll

The PR fixes non-local partition rolling by switching to RANGE semantics — at time T, collect all individual observations from matching series where ts ∈ [T - lag - window + 1, T - lag] and compute the statistic over those raw values. This is one valid interpretation, but there is a second equally valid one:

Aggregate-then-roll: first collapse all series at each timestamp to a single value (e.g., sum or mean), then roll over those per-timestamp aggregates. So if b=20 and c=15 are both at ds=2, they collapse to 17.5 (mean) or 35 (sum), and the window rolls over that aggregate.

Both semantics are useful depending on the problem:

  • RANGE is natural when you want the statistic over the raw distribution of individual observations — e.g., estimating a population mean from individual transactions.
  • Aggregate-then-roll is natural when you want to roll over a group-level signal — e.g., total promo sales per period, then a rolling average of those period totals.
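A tiny numeric illustration of how the two interpretations diverge when timestamps are unbalanced (values are made up):

```python
import numpy as np

# One bucket, one observation at ds=1 and two at ds=2.
ts = np.array([1, 2, 2])
y = np.array([10.0, 20.0, 15.0])

# RANGE semantics: statistic over the raw observations in the window.
range_mean = y.mean()                                    # 45 / 3 = 15.0

# Aggregate-then-roll (mean collapse): one value per timestamp, then roll.
per_ts = np.array([y[ts == t].mean() for t in (1, 2)])   # [10.0, 17.5]
agg_then_roll = per_ts.mean()                            # 27.5 / 2 = 13.75
```

With equal observation counts per timestamp the two coincide; with unequal counts they do not, which is why the choice of semantics matters.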

The original (buggy) code happened to implement a broken version of aggregate-then-roll (sum collapse). Users who were relying on that structure — even if the values were wrong — now get silently different results with no migration path.

A suggested addition would be an agg parameter on partition transforms:

RollingMean(
    window=2,
    groupby=["brand"],
    partition_by=["promo"],
    agg=None,    # None = RANGE semantics (current behavior, default)
                 # "sum" or "mean" = collapse same-timestamp obs before rolling
)

agg=None preserves the current fix. agg="sum" or agg="mean" enables the aggregate-then-roll path with correct semantics. This would give users explicit control over which interpretation they want rather than having one silently chosen for them.


Summary

The correctness fix (RANGE semantics) is solid and the feature design is well-structured. The test suite is thorough with detailed hand-verified expectations. Main items before merge:

Priority | Issue
High     | Replace assert context is not None with raise
High     | Warn or raise on _Seasonal_RollingBase / ExponentiallyWeightedMean + partition_by in non-local mode
High     | Rewrite partition helpers using narwhals / ufp instead of explicit pandas/polars splits
Medium   | Remove dead _current_step_x attribute
Medium   | Replace redundant getattr(self, "_partition_states", {}) with self._partition_states
Medium   | Investigate deepcopy cost per predict step in _compute_partition_features
Low      | Move shared test helpers to conftest.py
Low      | Strengthen test_partition_combine_lag_transform_predict with numerical assertions
Low      | Fix test comments in test_global_partition_lag_transform to describe the RANGE path

@nasaul
Contributor

nasaul commented May 5, 2026

Thanks for this, @simonez-tuidi; this is great work and a great contribution. I think we should address the discussion section first, before implementing any changes.

@simonez-tuidi
Contributor Author

simonez-tuidi commented May 6, 2026

Hi @nasaul, thanks for your thorough review.

I'll address the issues you raised in separate commits. Issue 5 might require quite a bit of reworking to allow incremental updates, which would improve latency but might be non-trivial.

Regarding the discussion points, I would note that the review treats the previous implementation as "established" code rather than a first iteration that was never merged. I believe the previous logic was fundamentally flawed and would not serve much purpose in real-world applications. Take the example below, where two series a and b show their y-values where they belong to the partition and are blank where they don't.

ts | a  | b
1  | 10 | --
2  | 15 | --
3  | 12 | 20
4  | -- | 23

In the case of a Sum aggregation, we would expect the result to be the sum of all the values that belong to the partition, in this case 80. For a mean aggregation the same logic applies: it is simply the previous sum divided by the number of observations, i.e. 16. If we first collapse the values at the timestamp level and then aggregate again, we introduce a double aggregation: first taking the mean at each timestamp, then taking the mean of those per-timestamp aggregates.
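A quick check of the numbers above:

```python
import numpy as np

# y-values from the table above: series a contributes 10, 15, 12 (ts 1-3)
# and series b contributes 20, 23 (ts 3-4); all belong to the partition.
y = np.array([10.0, 15.0, 12.0, 20.0, 23.0])
print(y.sum())   # 80.0
print(y.mean())  # 16.0
```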

I've just realized that this logic is what is currently implemented in the groupby-only transformations as introduced in #551; for example, this test checks the RollingMean in global mode with a two-timestamp window lagged by 1. The expected results are nan, nan, 16.5, 27.5: for the first timestamp we leave null as there are no previous observations, and for the second one also null as there is only one previous timestamp (potentially this could still be allowed, since we did not specify a min_samples).

The third result is the sum of the per-timestamp means, i.e. (1+10)/2 + (2+20)/2 = 5.5 + 11 = 16.5. The same applies to the next test result. I would argue that the correct expected result would be avg([1, 2, 10, 20]) = 8.25, but perhaps the interpretation here is that RollingMean with groupby applies the mean to all values in the group, so that we obtain a single aggregated timeseries with one value per timestamp, and those values are then summed. This is essentially what @nasaul describes as adding an agg=[mean, sum, None] parameter, in which case the current implementation above defaults to agg=sum.

However, I'm not sure how this behavior would have meaningful applications: using increasingly larger windows makes the resulting mean proportional to the window size, which defeats the point of taking the mean in the first place. It would also make the results of ExpandingMean useless in groupby mode, since each timestamp would have a completely different order of magnitude as the window size increases.

Now, back to the issue at hand for this PR: unfortunately, working out how to define the agg parameter is not so obvious. In the groupby-only mode, RollingMean(groupby=["brand"], agg="mean") produces the correct expected result, since aggregating per timestamp first and then across timestamps commutes when every timestamp holds the same number of observations. In the partition_by case, though, aggregating per timestamp first would not return the same result as aggregating across all observations regardless of timestamp, because not all timestamps necessarily contain the same number of observations, as in the example above.

My argument is that these transformations should follow the same logic as SQL queries, and this is the equivalence I have in mind:

  • RollingMean at the unique series level (default): AVG(y) OVER(PARTITION BY unique_id RANGE BETWEEN 1 PRECEDING ...)
  • RollingMean with groupby=["brand"]: AVG(y) OVER(PARTITION BY brand RANGE BETWEEN 1 PRECEDING ...)
  • RollingMean with partition_by=["promo"]: AVG(y) OVER(PARTITION BY unique_id, promo RANGE BETWEEN 1 PRECEDING ...)
  • RollingMean with partition_by=["promo"] and groupby=["brand"]: AVG(y) OVER(PARTITION BY brand, promo RANGE BETWEEN 1 PRECEDING ...)

In all of the above, the transformation finds all the relevant records that belong to the window and takes the mean. In the Nixtla-verse, all transforms are already essentially a window function with partition by, like in the first example where the partition is defined over unique_id - which is the default behavior of every lag-transform and can be rewritten as a simple SQL query. When applying a groupby transform we're essentially expanding it to all series that share the same static value for that group so instead of PARTITION BY unique_id we use PARTITION BY brand. Then again, when applying a partition_by transform we're adding an extra level to the partition within that series defined by unique_id - such as promo, so it becomes PARTITION BY unique_id, promo. Then finally, if we partition by promo and group by brand, we are taking all the records that belong to the same brand, and selecting those with the same value of promo as the current observation we're working on, so it's PARTITION BY brand, promo.

I think this design makes it pretty intuitive to know what to expect from each transform in all of the cases above.
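A brute-force pandas reference for this SQL equivalence (hypothetical helper, O(n^2), for illustration only):

```python
import pandas as pd

def sql_range_mean(df, partition_cols, lag, window):
    """Pandas equivalent of
    AVG(y) OVER (PARTITION BY <cols> RANGE BETWEEN ... PRECEDING):
    for each row, average all observations in its partition whose ds
    falls in [ds - lag - window + 1, ds - lag]."""
    out = []
    for _, row in df.iterrows():
        part = df
        for c in partition_cols:
            part = part[part[c] == row[c]]
        lo, hi = row["ds"] - lag - window + 1, row["ds"] - lag
        vals = part.loc[part["ds"].between(lo, hi), "y"]
        out.append(vals.mean() if len(vals) else float("nan"))
    return pd.Series(out, index=df.index)
```

Passing ["promo"], ["brand"], ["unique_id", "promo"], or ["brand", "promo"] as partition_cols reproduces the four cases listed above.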

The current groupby path can be kept alongside the behavior I'm describing by introducing the agg flag proposed above, which would improve on the current state IMO without breaking changes. Unfortunately, doing so would then conflict with the design for the partition_by transforms, as it would make all transforms first collapse at the timestamp level and then aggregate, which would return incorrect results.

Note: unfortunately the other test introduced in that PR is pretty uninformative, as there is only one series per group and therefore the results coincide with the non-grouped example.

@simonez-tuidi
Contributor Author

simonez-tuidi commented May 6, 2026

@nasaul & @janrth I've expanded on the points above in a new issue #640 - keen to get your thoughts on this

I also got fact-checked by AI: my interpretation that the existing groupby implementation was aggregating and then summing was incorrect. Instead, the global_ and groupby transforms first sum all values across series at each timestamp and then aggregate, which might be easier to work around.

@simonez-tuidi simonez-tuidi marked this pull request as draft May 8, 2026 13:28
@simonez-tuidi
Contributor Author

Major changes in #641 will influence how this will be implemented and might require a partial or complete rework -> converting to draft for now
