Skip to content

Commit 7a79bfe

Browse files
committed
Refactoring plot_horizon_forecast to handle preview and eval modes
1 parent 08fad5d commit 7a79bfe

2 files changed

Lines changed: 73 additions & 26 deletions

File tree

content/python_files/feature_engineering.py

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def iqr(col, *, window_size: int):
339339

340340
# %% [markdown]
341341
#
342-
# ## Remark lagged features engineering and system lag
342+
# ## Important remark about lagged features engineering and system lag
343343
#
344344
# When working with historical data, we often have access to all the past
345345
# measurements in the dataset. However, when we want to use the lagged features
@@ -475,7 +475,7 @@ def define_prediction_time_range(prediction_start_time, prediction_end_time):
475475

476476
prediction_time = define_prediction_time_range(
477477
prediction_start_time, prediction_end_time
478-
)
478+
).skb.subsample(n=1000, how="head")
479479
prediction_time
480480

481481

@@ -668,10 +668,10 @@ def build_targets(prediction_time, electricity, horizons):
668668
altair.Chart(
669669
pl.concat(
670670
[
671-
targets.skb.eval(),
671+
targets.skb.preview(),
672672
hgbr_predictions.rename(
673673
{target_column_name: predicted_target_column_name}
674-
).skb.eval(),
674+
).skb.preview(),
675675
],
676676
how="horizontal",
677677
).tail(24 * 7)
@@ -742,7 +742,14 @@ def build_targets(prediction_time, electricity, horizons):
742742

743743

744744
# %%
745-
def collect_cv_predictions(pipelines, cv_splitter, predictions, prediction_time):
745+
def collect_cv_predictions(
746+
pipelines,
747+
cv_splitter,
748+
predictions,
749+
prediction_time,
750+
method_name="predict",
751+
method_kwargs=None,
752+
):
746753
index_generator = cv_splitter.split(prediction_time.skb.eval())
747754

748755
def splitter(X, y, index_generator):
@@ -752,9 +759,13 @@ def splitter(X, y, index_generator):
752759
return X[train_idx], X[test_idx], y[train_idx], y[test_idx]
753760

754761
results = []
762+
if method_kwargs is None:
763+
method_kwargs = {}
764+
755765
for (_, test_idx), pipeline in zip(
756766
cv_splitter.split(prediction_time.skb.eval()), pipelines
757767
):
768+
method = getattr(pipeline, method_name)
758769
split = predictions.skb.train_test_split(
759770
predictions.skb.get_data(),
760771
splitter=splitter,
@@ -765,7 +776,7 @@ def splitter(X, y, index_generator):
765776
{
766777
"prediction_time": prediction_time.skb.eval()[test_idx],
767778
"load_mw": split["y_test"],
768-
"predicted_load_mw": pipeline.predict(split["test"]),
779+
"predicted_load_mw": method(split["test"], **method_kwargs),
769780
}
770781
)
771782
)
@@ -940,10 +951,10 @@ def splitter(X, y, index_generator):
940951
altair.Chart(
941952
pl.concat(
942953
[
943-
targets.skb.eval(),
954+
targets.skb.preview(),
944955
predictions_ridge.rename(
945956
{target_column_name: predicted_target_column_name}
946-
).skb.eval(),
957+
).skb.preview(),
947958
],
948959
how="horizontal",
949960
).tail(24 * 7)
@@ -1114,25 +1125,25 @@ def splitter(X, y, index_generator):
11141125
)
11151126

11161127
# %%
1117-
plot_at_time = datetime.datetime(2025, 5, 24, 0, 0, tzinfo=datetime.timezone.utc)
1128+
plot_at_time = datetime.datetime(2021, 4, 19, 0, 0, tzinfo=datetime.timezone.utc)
11181129
historical_timedelta = datetime.timedelta(hours=24 * 5)
11191130
plot_horizon_forecast(
11201131
targets,
11211132
named_predictions,
11221133
plot_at_time,
11231134
historical_timedelta,
11241135
target_column_name_pattern,
1125-
)
1136+
).skb.preview()
11261137

11271138
# %%
1128-
plot_at_time = datetime.datetime(2025, 5, 25, 0, 0, tzinfo=datetime.timezone.utc)
1139+
plot_at_time = datetime.datetime(2021, 4, 20, 0, 0, tzinfo=datetime.timezone.utc)
11291140
plot_horizon_forecast(
11301141
targets,
11311142
named_predictions,
11321143
plot_at_time,
11331144
historical_timedelta,
11341145
target_column_name_pattern,
1135-
)
1146+
).skb.preview()
11361147

11371148
# %%
11381149
from sklearn.metrics import r2_score
@@ -1208,6 +1219,7 @@ def scoring(regressor, X, y):
12081219
# TODO: Exercise using RandomForestRegressor
12091220
from sklearn.ensemble import RandomForestRegressor
12101221

1222+
12111223
multioutput_predictions_rf = features_with_dropped_cols.skb.apply(
12121224
RandomForestRegressor(min_samples_leaf=30, random_state=0, n_jobs=-1),
12131225
y=targets.skb.drop(cols=["prediction_time", "load_mw"]).skb.mark_as_y(),
@@ -1603,7 +1615,7 @@ def binned_coverage(y_true_folds, y_quantile_low, y_quantile_high, n_bins=10):
16031615

16041616
# %% [markdown]
16051617
#
1606-
# ## Reliability diagram for quantile regression
1618+
# ## Reliability diagrams for quantile regression
16071619

16081620
# %%
16091621
plot_reliability_diagram(
@@ -1674,6 +1686,7 @@ def fit(self, X, y):
16741686
strategy="quantile",
16751687
subsample=200_000,
16761688
encode="ordinal",
1689+
quantile_method="averaged_inverted_cdf",
16771690
random_state=random_state,
16781691
)
16791692

@@ -1716,3 +1729,39 @@ def predict_quantiles(self, X, quantiles=(0.05, 0.5, 0.95)):
17161729

17171730
def predict(self, X):
17181731
return self.predict_quantiles(X, self.quantile).ravel()
1732+
1733+
1734+
# %%
1735+
from sklearn.ensemble import HistGradientBoostingClassifier
1736+
from threadpoolctl import threadpool_limits
1737+
1738+
1739+
# with threadpool_limits(1):
1740+
if True:
1741+
predictions_bqr = features_with_dropped_cols.skb.apply(
1742+
BinnedQuantileRegressor(
1743+
RandomForestClassifier(
1744+
n_jobs=-1, n_estimators=200, min_samples_leaf=5, random_state=0
1745+
),
1746+
# HistGradientBoostingClassifier(random_state=0),
1747+
n_bins=30,
1748+
),
1749+
y=target,
1750+
)
1751+
1752+
# %%
1753+
predictions_bqr
1754+
1755+
# %%
1756+
cv_results_bqr = predictions_bqr.skb.cross_validate(
1757+
cv=ts_cv_5,
1758+
scoring={
1759+
"d2_pinball": make_scorer(d2_pinball_score, alpha=0.5),
1760+
"MAPE": make_scorer(mean_absolute_percentage_error),
1761+
},
1762+
return_pipeline=True,
1763+
verbose=1,
1764+
n_jobs=-1,
1765+
)
1766+
cv_results_bqr
1767+
# %%

content/python_files/tutorial_helpers.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
import polars as pl
5+
import polars.selectors as cs
56
import altair
67
import skrub
78

@@ -483,7 +484,7 @@ def plot_binned_residuals(cv_predictions, by="hour"):
483484
color="independent"
484485
)
485486

486-
487+
@skrub.deferred
487488
def plot_horizon_forecast(
488489
targets,
489490
named_predictions,
@@ -511,25 +512,23 @@ def plot_horizon_forecast(
511512
altair.Chart
512513
A chart with the true target and the forecast values for different horizons.
513514
"""
514-
merged_data = targets.skb.select(cols=["prediction_time", "load_mw"]).skb.concat(
515-
[named_predictions], axis=1
515+
merged_data = pl.concat(
516+
[targets.select(pl.col("prediction_time"), pl.col("load_mw")), named_predictions],
517+
how="horizontal",
516518
)
517519
start_time = plot_at_time - historical_timedelta
518520
end_time = plot_at_time + datetime.timedelta(
519-
hours=named_predictions.skb.eval().shape[1]
521+
hours=named_predictions.shape[1]
520522
)
521523
true_values_past = merged_data.filter(
522524
pl.col("prediction_time").is_between(start_time, plot_at_time, closed="both")
523525
).rename({"load_mw": "Past true load"})
524526
true_values_future = merged_data.filter(
525-
pl.col("prediction_time").is_between(plot_at_time, end_time, closed="both")
527+
pl.col("prediction_time").is_between(plot_at_time, end_time, closed="right")
526528
).rename({"load_mw": "Future true load"})
527529
predicted_record = (
528-
merged_data.skb.select(
529-
cols=skrub.selectors.filter_names(str.startswith, "predict")
530-
)
530+
merged_data.select(cs.starts_with("predict"))
531531
.row(by_predicate=pl.col("prediction_time") == plot_at_time, named=True)
532-
.skb.eval()
533532
)
534533
forecast_values = pl.DataFrame(
535534
{
@@ -541,15 +540,14 @@ def plot_horizon_forecast(
541540
}
542541
for horizon in range(1, len(predicted_record))
543542
)
544-
545543
true_values_past_chart = (
546-
altair.Chart(true_values_past.skb.eval())
544+
altair.Chart(true_values_past)
547545
.transform_fold(["Past true load"])
548546
.mark_line(tooltip=True)
549547
.encode(x="prediction_time:T", y="Past true load:Q", color="key:N")
550548
)
551549
true_values_future_chart = (
552-
altair.Chart(true_values_future.skb.eval())
550+
altair.Chart(true_values_future)
553551
.transform_fold(["Future true load"])
554552
.mark_line(tooltip=True)
555553
.encode(x="prediction_time:T", y="Future true load:Q", color="key:N")
@@ -562,4 +560,4 @@ def plot_horizon_forecast(
562560
)
563561
return (
564562
true_values_past_chart + true_values_future_chart + forecast_values_chart
565-
).interactive()
563+
).interactive()

0 commit comments

Comments
 (0)