Skip to content

Commit 4be221c

Browse files
authored
Merge pull request #103 from m9o8/ENH-spatio-temporal-splits
Feature: Flexible Spatio-Temporal Splits and Dynamic Multi-Group Visualization
2 parents a497ad4 + b8ed502 commit 4be221c

6 files changed

Lines changed: 728 additions & 935 deletions

File tree

README.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,30 @@ for train_idx, test_idx in splits:
4848
print("Test:"); display(panel_data.loc[test_idx])
4949
```
5050

51+
### Spatio-Temporal Cross-Validation
52+
53+
panelsplit also supports combined spatio-temporal holdouts: pass entity groupings (e.g., states or cities) so that whole groups are held out together, which prevents cluster-level leakage. Each split then validates on unseen time periods *and* unseen groups at the same time:
54+
55+
```python
56+
from sklearn.model_selection import StratifiedGroupKFold
57+
58+
# Create spatial splits that evaluate cluster-level combinations robustly:
59+
panel_split = PanelSplit(
60+
periods=panel_data.year,
61+
n_splits=2,
62+
groups=panel_data["country_id"],
63+
group_splitter=StratifiedGroupKFold(n_splits=3) # Use any valid Scikit-Learn group methodology!
64+
)
65+
66+
# You can also pass arbitrarily nested multi-column groups!
67+
# PanelSplit will internally flatten them into a single composite group identifier for KFold slicing.
68+
# e.g., groups = panel_data[["country_id", "city_id"]]
69+
70+
# X and y are forwarded to the group splitter when split() is called, so StratifiedGroupKFold can stratify on y.
71+
splits = panel_split.split(X=panel_data, y=panel_data["y"])
72+
# Yields 6 train/test splits in total (2 temporal splits x 3 stratified group folds).
73+
```
74+
5175
For more examples and detailed usage instructions, refer to the [examples](examples) directory in this repository. Also feel free to check out [an introductory article on panelsplit](https://towardsdatascience.com/how-to-cross-validate-your-panel-data-in-python-9ad981ddd043).
5276

5377
## Background

examples/An introduction to PanelSplit.ipynb

Lines changed: 349 additions & 898 deletions
Large diffs are not rendered by default.

panelsplit/cross_validation.py

Lines changed: 83 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
import warnings
2-
from typing import Optional, Union, TYPE_CHECKING, Any
3-
from numpy.typing import NDArray
4-
from narwhals.typing import IntoDataFrame, IntoSeries
5-
from .utils.typing import ArrayLike, CVIndices
2+
from typing import TYPE_CHECKING, Any, Optional, Union
63

74
import narwhals as nw
85
import numpy as np
9-
from sklearn.model_selection import TimeSeriesSplit
6+
from narwhals.typing import IntoDataFrame, IntoSeries
7+
from numpy.typing import NDArray
8+
from sklearn.model_selection import GroupKFold, TimeSeriesSplit
109

10+
from .utils.typing import ArrayLike, CVIndices
1111
from .utils.validation import (
1212
_safe_indexing,
1313
_to_numpy_array,
14+
check_groups,
1415
check_labels,
1516
check_periods,
1617
get_index_or_col_from_df,
@@ -64,6 +65,11 @@ class PanelSplit:
6465
include_train_in_test : bool
6566
Whether to include all training sets in their respective test sets. If set to
6667
True, overrides ``include_first_train_in_test``. Default is False.
68+
groups : Optional[Any]
69+
A 1D/2D array or DataFrame of spatial groupings/IDs for implementing spatio-temporal holdouts.
70+
If provided, tests will simultaneously cross-validate over spatial nested structures using GroupKFold. Default is None.
71+
group_splitter : Optional[Any]
72+
A scikit-learn compatible splitter (e.g., `StratifiedGroupKFold(n_splits=3)`) used to build spatial splits natively. Default is `GroupKFold(n_splits=2)`.
6773
6874
Attributes
6975
----------
@@ -101,6 +107,8 @@ def __init__(
101107
max_train_size: Optional[int] = None,
102108
include_first_train_in_test: bool = False,
103109
include_train_in_test: bool = False,
110+
groups: Optional[Any] = None,
111+
group_splitter: Optional[Any] = None,
104112
) -> None:
105113
periods = check_periods(periods)
106114

@@ -130,11 +138,42 @@ def __init__(
130138
self._include_first_train_in_test = include_first_train_in_test
131139
else:
132140
self._include_first_train_in_test = True
141+
133142
self._u_periods_cv = self._split_unique_periods(indices, unique_periods_array)
134143
self._periods = _to_numpy_array(periods)
135144
self._snapshots = _to_numpy_array(snapshots) if snapshots is not None else None
145+
146+
self._groups = check_groups(groups) if groups is not None else None
147+
148+
if self._groups is not None:
149+
if group_splitter is None:
150+
self._group_splitter = GroupKFold(n_splits=2)
151+
else:
152+
self._group_splitter = group_splitter
153+
if len(self._groups) != len(self._periods):
154+
raise ValueError(
155+
f"groups size ({len(self._groups)}) does not match periods size ({len(self._periods)})"
156+
)
157+
else:
158+
self._group_splitter = None
159+
136160
self.n_splits = n_splits
137-
self.train_test_splits = self._gen_splits()
161+
if self._groups is not None:
162+
self.n_splits = n_splits * self._group_splitter.get_n_splits() # type: ignore[union-attr]
163+
164+
self._temporal_splits = self._gen_splits()
165+
166+
self.train_test_splits = self._temporal_splits
167+
if self._groups is not None:
168+
try:
169+
self.train_test_splits = self._compute_spatio_temporal_splits(
170+
X=None, y=None
171+
)
172+
except Exception as e:
173+
warnings.warn(
174+
f"Could not cleanly pre-generate spatial splits in __init__: {e}. Passing X and y to split() natively at runtime."
175+
)
176+
self.train_test_splits = []
138177

139178
def _split_unique_periods(self, indices: Any, unique_periods: NDArray) -> CVIndices:
140179
"""
@@ -200,6 +239,30 @@ def _gen_splits(self) -> CVIndices:
200239

201240
return train_test_splits
202241

242+
def _compute_spatio_temporal_splits(
    self,
    X: Optional[ArrayLike] = None,
    y: Optional[ArrayLike] = None,
) -> CVIndices:
    """
    Cross every temporal split with every spatial (group) fold.

    Parameters
    ----------
    X : Optional[ArrayLike]
        Feature data forwarded to the group splitter. When None, a zero array with
        one entry per observation is passed in its place.
    y : Optional[ArrayLike]
        Target data forwarded unchanged to the group splitter.

    Returns
    -------
    CVIndices
        A list of ``(train_indices, test_indices)`` tuples. Temporal splits form the
        outer loop and spatial folds the inner loop, so the folds belonging to
        temporal split ``t`` are stored contiguously.
    """
    # Nothing to intersect with; also avoids invoking the splitter needlessly.
    if len(self._temporal_splits) == 0:
        return []

    placeholder_X = np.zeros(len(self._periods)) if X is None else X

    # The spatial folds do not depend on the temporal split, so build them once.
    # Re-running the splitter per temporal split repeats the work and, for a
    # shuffling splitter without a fixed seed, would give every temporal cut a
    # different set of spatial folds.
    spatial_folds = list(
        self._group_splitter.split(placeholder_X, y, groups=self._groups)
    )

    # assume_unique=True: index arrays are expected to contain no duplicates
    # (NOTE(review): confirm for the temporal index arrays built by _gen_splits).
    return [
        (
            np.intersect1d(train_indices, sp_train, assume_unique=True),
            np.intersect1d(test_indices, sp_test, assume_unique=True),
        )
        for train_indices, test_indices in self._temporal_splits
        for sp_train, sp_test in spatial_folds
    ]
265+
203266
def split(
204267
self,
205268
X: Optional[ArrayLike] = None,
@@ -212,9 +275,9 @@ def split(
212275
Parameters
213276
----------
214277
X : Optional[ArrayLike]
215-
Ignored; included for compatibility.
278+
Data matrix to base splits on (used by logic like StratifiedGroupKFold).
216279
y : Optional[ArrayLike]
217-
Ignored; included for compatibility.
280+
Target variables to base splits on (used by logic like StratifiedGroupKFold).
218281
groups : Optional[np.ndarray]
219282
Ignored; included for compatibility.
220283
@@ -241,7 +304,18 @@ def split(
241304
... print("Train:", train, "Test:", test)
242305
Train: [0 1] Test: [2]
243306
"""
244-
return self.train_test_splits
307+
if self._groups is None:
308+
return self.train_test_splits # type: ignore[return-value]
309+
310+
if X is not None or y is not None:
311+
self.train_test_splits = self._compute_spatio_temporal_splits(X=X, y=y)
312+
313+
if not self.train_test_splits:
314+
raise ValueError(
315+
"train_test_splits is uncomputed. Your selected group_splitter requires passing X and y explicitly to the .split() method to calculate strata boundaries."
316+
)
317+
318+
return self.train_test_splits # type: ignore[return-value]
245319

246320
def get_n_splits(
247321
self,

panelsplit/plot.py

Lines changed: 121 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,49 @@
1+
from typing import Optional, Tuple, Union
2+
13
import matplotlib.pyplot as plt
4+
import numpy as np
5+
26
from .cross_validation import PanelSplit
3-
from typing import Tuple, Optional
7+
from .utils.typing import ArrayLike
48

59

610
def plot_splits(
7-
panel_split: PanelSplit, show: bool = True
8-
) -> Optional[Tuple[plt.Figure, plt.Axes]]:
11+
panel_split: PanelSplit,
12+
X: Optional[ArrayLike] = None,
13+
y: Optional[ArrayLike] = None,
14+
n_groups: int = 2,
15+
show: bool = True,
16+
) -> Optional[Tuple[plt.Figure, Union[plt.Axes, np.ndarray]]]:
917
"""
1018
Visualize time series cross-validation splits using a scatter plot.
1119
1220
Each split is plotted on a separate horizontal line: blue markers represent training indices
1321
and red markers represent test indices.
1422
23+
If the PanelSplit instance uses groups for spatio-temporal holdouts, this will create
24+
`n_groups` subplots, each visualizing the train/test periods for an individual, randomly
25+
sampled or sequential subset of the unique groups.
26+
1527
Parameters
1628
----------
1729
panel_split : PanelSplit
1830
An instance of PanelSplit containing the cross-validation splits.
19-
It must have an attribute `_u_periods_cv`, which should be an iterable of tuples,
20-
each in the form `(train_index, test_index)`. Both `train_index` and `test_index`
21-
are array-like collections of period indices.
31+
X : ArrayLike, optional
32+
Features dataset. Needed if the group_splitter relies on X for boundary calculation.
33+
y : ArrayLike, optional
34+
Target dataset. Needed if the group_splitter relies on y for boundary calculation.
35+
n_groups : int, default=2
36+
The number of subgroups to plot side-by-side if groups are used.
2237
show : bool, default=True
2338
If True, the plot is immediately displayed using `plt.show()`.
24-
If False, the function returns the matplotlib Figure and Axes objects for further customization.
39+
If False, the function returns the matplotlib Figure and Axes objects.
2540
2641
Returns
2742
-------
28-
Optional[Tuple[plt.Figure, plt.Axes]]
43+
Optional[Tuple[plt.Figure, Union[plt.Axes, np.ndarray]]]
2944
If `show` is False, returns a tuple `(fig, ax)` where `fig` is the matplotlib Figure
30-
and `ax` is the Axes object. If `show` is True, the plot is displayed and the function returns None.
45+
and `ax` is the Axes object (or an array of Axes objects if groups are used).
46+
If `show` is True, the plot is displayed and the function returns None.
3147
3248
Examples
3349
--------
@@ -36,16 +52,15 @@ def plot_splits(
3652
>>> import matplotlib.pyplot as plt
3753
>>> periods = np.array([1, 2, 3, 4, 5, 6])
3854
>>> ps = PanelSplit(periods, n_splits=3)
39-
>>> # Display the plot immediately
4055
>>> plot_splits(ps)
41-
>>> # Or return the Figure and Axes for customization
42-
>>> fig, ax = plot_splits(ps, show=False)
43-
>>> ax.set_title("A custom plot of cross-validation splits")
44-
>>> plt.show()
4556
"""
57+
58+
if panel_split._groups is not None:
59+
return _plot_group_subplots(panel_split, X=X, y=y, n_groups=n_groups, show=show)
60+
4661
split_output = panel_split._u_periods_cv
4762
splits = len(split_output)
48-
fig, ax = plt.subplots()
63+
fig, ax = plt.subplots(figsize=(10, 5))
4964

5065
for i, (train_index, test_index) in enumerate(split_output):
5166
ax.scatter(train_index, [i] * len(train_index), color="blue", marker=".", s=50)
@@ -54,13 +69,100 @@ def plot_splits(
5469
ax.set_xlabel("Periods")
5570
ax.set_ylabel("Split")
5671
ax.set_title("Cross-validation splits")
57-
ax.set_yticks(range(splits)) # Set the number of ticks on the y-axis
58-
ax.set_yticklabels(
59-
[f"{i}" for i in range(splits)]
60-
) # Set custom labels for the y-axis
72+
ax.set_yticks(range(splits))
73+
ax.set_yticklabels([f"{i}" for i in range(splits)])
6174

6275
if show:
6376
plt.show()
6477
return None
6578
else:
6679
return fig, ax
80+
81+
82+
def _plot_group_subplots(
    panel_split: PanelSplit,
    X: Optional[ArrayLike],
    y: Optional[ArrayLike],
    n_groups: int,
    show: bool,
) -> Optional[Tuple[plt.Figure, Union[plt.Axes, np.ndarray]]]:
    """
    Plot train/test periods per group, one subplot per selected group.

    The first ``n_groups`` values of the sorted unique group labels are plotted.
    Within each subplot, every temporal split occupies one horizontal row; points
    from all spatial folds of that temporal split are drawn on the same row
    (blue for training periods, red for test periods).

    Parameters
    ----------
    panel_split : PanelSplit
        Splitter whose ``_groups`` attribute is set.
    X : Optional[ArrayLike]
        Forwarded to ``panel_split.split``.
    y : Optional[ArrayLike]
        Forwarded to ``panel_split.split``.
    n_groups : int
        Maximum number of groups (subplots) to draw. Must be at least 1.
    show : bool
        If True, display the figure and return None; otherwise return it.

    Returns
    -------
    Optional[Tuple[plt.Figure, Union[plt.Axes, np.ndarray]]]
        ``None`` when ``show`` is True. Otherwise ``(fig, axes)``, where ``axes`` is a
        single Axes if exactly one group was plotted and an array of Axes otherwise.

    Raises
    ------
    ValueError
        If ``n_groups`` is smaller than 1, if there are no groups to plot, or if
        ``panel_split.split`` cannot produce splits from the given ``X`` and ``y``.
    """
    # A negative value would silently drop groups via slicing and zero would
    # fail inside matplotlib with an unrelated message, so reject both here.
    if n_groups < 1:
        raise ValueError(f"n_groups must be a positive integer, got {n_groups}")

    # NOTE(review): assumes _groups holds one label per observation - confirm
    # against check_groups.
    group_labels = np.asarray(panel_split._groups)
    unique_groups = np.unique(group_labels)
    selected_groups = unique_groups[:n_groups]
    actual_groups = len(selected_groups)
    if actual_groups == 0:
        raise ValueError("panel_split has no groups to plot")

    # squeeze=False always yields a 2D array of Axes, so the single-group case
    # needs no special handling until the return statement.
    fig, axes_grid = plt.subplots(
        ncols=actual_groups,
        sharey=True,
        sharex=True,
        figsize=(min(16, 6 * actual_groups), 6),
        squeeze=False,
    )
    axes = axes_grid[0]

    splits = panel_split.split(X=X, y=y)
    n_total_splits = len(splits)
    n_temporal_splits = len(getattr(panel_split, "_temporal_splits", splits))
    # Splits are ordered temporal-major, so integer division by the number of
    # spatial folds recovers the temporal split index. Clamp to 1 to avoid a
    # division by zero.
    n_spatial_splits = (
        max(1, n_total_splits // n_temporal_splits) if n_temporal_splits > 0 else 1
    )

    for ax, group in zip(axes, selected_groups):
        group_indices = np.flatnonzero(group_labels == group)

        for i, (train_indices, test_indices) in enumerate(splits):
            temporal_idx = i // n_spatial_splits

            group_train_indices = np.intersect1d(
                train_indices, group_indices, assume_unique=True
            )
            group_test_indices = np.intersect1d(
                test_indices, group_indices, assume_unique=True
            )
            train_periods = panel_split._periods[group_train_indices]
            test_periods = panel_split._periods[group_test_indices]

            for periods, color in ((train_periods, "blue"), (test_periods, "red")):
                if len(periods) > 0:
                    ax.scatter(
                        periods,
                        [temporal_idx] * len(periods),
                        color=color,
                        marker=".",
                        s=50,
                    )

        ax.set_xlabel("Periods")
        ax.set_title(f"Cross-validation splits: {group}")
        ax.set_yticks(range(n_temporal_splits))
        ax.set_yticklabels([f"{j}" for j in range(n_temporal_splits)])

    # The y-axis is shared, so only the left-most subplot is labelled.
    axes[0].set_ylabel("Split")

    remaining_groups = len(unique_groups) - actual_groups
    if remaining_groups > 0:
        fig.suptitle(
            f"Cross-validation splits: {remaining_groups} other groups not plotted"
        )

    # Lay out this figure explicitly rather than whichever figure is current.
    fig.tight_layout()
    if show:
        plt.show()
        return None
    return fig, (axes if actual_groups > 1 else axes[0])

0 commit comments

Comments
 (0)