Skip to content

Commit d7f0e7c

Browse files
committed
Hide the code of the collect_cv_predictions hack in the helpers
1 parent d59c10a commit d7f0e7c

2 files changed

Lines changed: 37 additions & 36 deletions

File tree

content/python_files/feature_engineering.py

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
plot_residuals_vs_predicted,
6868
plot_binned_residuals,
6969
plot_horizon_forecast,
70+
collect_cv_predictions,
7071
)
7172

7273
# Ignore warnings from pkg_resources triggered by Python 3.13's multiprocessing.
@@ -773,42 +774,6 @@ def build_targets(prediction_time, electricity, horizons):
773774
# We further analyze our cross-validated model by collecting the predictions on each
774775
# split.
775776

776-
# %%
777-
def collect_cv_predictions(
    pipelines,
    cv_splitter,
    predictions,
    prediction_time,
):
    """Collect, fold by fold, the held-out targets and pipeline predictions.

    Parameters
    ----------
    pipelines : iterable
        Fitted objects exposing ``predict``; they are paired, in order, with
        the splits produced by ``cv_splitter``.
    cv_splitter : object
        Object exposing ``split(X)`` that yields ``(train_idx, test_idx)``
        pairs (presumably a scikit-learn cross-validation splitter).
    predictions : object
        Object whose ``.skb.get_data()`` and ``.skb.train_test_split(...)``
        are used to obtain the test inputs and targets of each fold
        (presumably a skrub expression -- confirm against the caller).
    prediction_time : object
        Object whose ``.skb.eval()`` result is split by ``cv_splitter`` and
        indexed with the test indices of each fold.

    Returns
    -------
    list of pl.DataFrame
        One frame per (split, pipeline) pair with the columns
        ``"prediction_time"``, ``"load_mw"`` and ``"predicted_load_mw"``.
    """
    # Evaluate once: the value does not change between folds.
    times = prediction_time.skb.eval()

    # Materialize the folds a single time so that the indices used for the
    # "prediction_time" column are exactly the ones used to split the data,
    # even if the splitter is not deterministic across calls to `split`.
    folds = list(cv_splitter.split(times))

    def splitter(X, y, train_idx, test_idx):
        """Workaround to transform a scikit-learn splitter into a function understood
        by `skrub.train_test_split`."""
        return X[train_idx], X[test_idx], y[train_idx], y[test_idx]

    results = []
    for (train_idx, test_idx), pipeline in zip(folds, pipelines):
        # Extra keyword arguments are forwarded to `splitter`, so each call
        # receives the indices of its own fold explicitly instead of pulling
        # them from a shared, stateful generator.
        split = predictions.skb.train_test_split(
            predictions.skb.get_data(),
            splitter=splitter,
            train_idx=train_idx,
            test_idx=test_idx,
        )
        results.append(
            pl.DataFrame(
                {
                    "prediction_time": times[test_idx],
                    "load_mw": split["y_test"],
                    "predicted_load_mw": pipeline.predict(split["test"]),
                }
            )
        )
    return results
811-
812777

813778
# %%
814779
hgbr_cv_predictions = collect_cv_predictions(

content/python_files/tutorial_helpers.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,3 +689,39 @@ def binned_coverage(y_true_folds, y_quantile_low, y_quantile_high, n_bins=10):
689689
)
690690

691691
return pd.DataFrame(results)
692+
693+
694+
def collect_cv_predictions(
    pipelines,
    cv_splitter,
    predictions,
    prediction_time,
):
    """Collect, fold by fold, the held-out targets and pipeline predictions.

    Parameters
    ----------
    pipelines : iterable
        Fitted objects exposing ``predict``; they are paired, in order, with
        the splits produced by ``cv_splitter``.
    cv_splitter : object
        Object exposing ``split(X)`` that yields ``(train_idx, test_idx)``
        pairs (presumably a scikit-learn cross-validation splitter).
    predictions : object
        Object whose ``.skb.get_data()`` and ``.skb.train_test_split(...)``
        are used to obtain the test inputs and targets of each fold
        (presumably a skrub expression -- confirm against the caller).
    prediction_time : object
        Object whose ``.skb.eval()`` result is split by ``cv_splitter`` and
        indexed with the test indices of each fold.

    Returns
    -------
    list of pl.DataFrame
        One frame per (split, pipeline) pair with the columns
        ``"prediction_time"``, ``"load_mw"`` and ``"predicted_load_mw"``.
    """
    # Evaluate once: the value does not change between folds.
    times = prediction_time.skb.eval()

    # Materialize the folds a single time so that the indices used for the
    # "prediction_time" column are exactly the ones used to split the data,
    # even if the splitter is not deterministic across calls to `split`.
    folds = list(cv_splitter.split(times))

    def splitter(X, y, train_idx, test_idx):
        """Workaround to transform a scikit-learn splitter into a function understood
        by `skrub.train_test_split`."""
        return X[train_idx], X[test_idx], y[train_idx], y[test_idx]

    results = []
    for (train_idx, test_idx), pipeline in zip(folds, pipelines):
        # Extra keyword arguments are forwarded to `splitter`, so each call
        # receives the indices of its own fold explicitly instead of pulling
        # them from a shared, stateful generator.
        split = predictions.skb.train_test_split(
            predictions.skb.get_data(),
            splitter=splitter,
            train_idx=train_idx,
            test_idx=test_idx,
        )
        results.append(
            pl.DataFrame(
                {
                    "prediction_time": times[test_idx],
                    "load_mw": split["y_test"],
                    "predicted_load_mw": pipeline.predict(split["test"]),
                }
            )
        )
    return results

0 commit comments

Comments
 (0)