Skip to content

Commit b36fc77

Browse files
committed
Add residuals vs predicted
1 parent 1ac72ac commit b36fc77

1 file changed

Lines changed: 82 additions & 0 deletions

File tree

content/python_files/feature_engineering.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -950,6 +950,88 @@ def plot_reliability_diagram(cv_predictions, n_bins=10):
950950
title="Reliability diagram from cross-validation predictions"
951951
)
952952

953+
# %%
954+
def plot_residuals_vs_predicted(cv_predictions):
955+
"""Plot residuals vs predicted values scatter plot for all CV folds."""
956+
all_scatter_plots = []
957+
958+
for i, cv_prediction in enumerate(cv_predictions):
959+
# Get date range for this CV fold
960+
min_date = cv_prediction["prediction_time"].min().strftime("%Y-%m-%d")
961+
max_date = cv_prediction["prediction_time"].max().strftime("%Y-%m-%d")
962+
fold_label = f"#{i+1} - {min_date} to {max_date}"
963+
964+
# Calculate residuals
965+
residuals_data = cv_prediction.with_columns(
966+
[(pl.col("predicted_load_mw") - pl.col("load_mw")).alias("residual")]
967+
).with_columns([pl.lit(fold_label).alias("fold")])
968+
969+
# Create scatter plot for this CV fold
970+
scatter_plot = (
971+
altair.Chart(residuals_data)
972+
.mark_circle(opacity=0.6, size=20)
973+
.encode(
974+
x=altair.X(
975+
"predicted_load_mw:Q",
976+
title="Predicted Load (MW)",
977+
scale=altair.Scale(zero=False),
978+
),
979+
y=altair.Y("residual:Q", title="Residual (MW)"),
980+
color=altair.Color("fold:N", legend=None),
981+
tooltip=[
982+
"prediction_time:T",
983+
"load_mw:Q",
984+
"predicted_load_mw:Q",
985+
"residual:Q",
986+
"fold:N",
987+
],
988+
)
989+
)
990+
991+
all_scatter_plots.append(scatter_plot)
992+
993+
# Get the range of predicted values for the perfect line
994+
all_predictions = pl.concat(
995+
[cv_pred["predicted_load_mw"] for cv_pred in cv_predictions]
996+
)
997+
min_pred, max_pred = all_predictions.min(), all_predictions.max()
998+
999+
# Create perfect residuals line at y=0
1000+
perfect_line = (
1001+
altair.Chart(
1002+
pl.DataFrame(
1003+
{
1004+
"predicted_load_mw": [min_pred, max_pred],
1005+
"perfect_residual": [0, 0],
1006+
"label": ["Perfect"] * 2,
1007+
}
1008+
)
1009+
)
1010+
.mark_line(strokeDash=[5, 5], opacity=0.8, color="black")
1011+
.encode(
1012+
x=altair.X("predicted_load_mw:Q", title="Predicted Load (MW)"),
1013+
y=altair.Y("perfect_residual:Q", title="Residual (MW)"),
1014+
color=altair.Color(
1015+
"label:N",
1016+
scale=altair.Scale(range=["black"]),
1017+
legend=None,
1018+
),
1019+
)
1020+
)
1021+
1022+
# Combine all scatter plots
1023+
combined_scatter = all_scatter_plots[0]
1024+
for plot in all_scatter_plots[1:]:
1025+
combined_scatter += plot
1026+
1027+
# Layer the scatter plots with the perfect line
1028+
return (combined_scatter + perfect_line).resolve_scale(color="independent")
1029+
1030+
1031+
plot_residuals_vs_predicted(cv_predictions).interactive().properties(
1032+
title="Residuals vs Predicted Values from cross-validation predictions"
1033+
)
1034+
9531035

9541036
# %%
9551037
def plot_binned_residuals(cv_predictions, by="hour"):

0 commit comments

Comments
 (0)