Skip to content

Commit 3f74764

Browse files
committed
Refactor binned residual plots to remove redundancy
1 parent 996bdef commit 3f74764

1 file changed

Lines changed: 31 additions & 88 deletions

File tree

content/python_files/feature_engineering.py

Lines changed: 31 additions & 88 deletions
Original file line number | Diff line number | Diff line change
@@ -952,8 +952,20 @@ def plot_reliability_diagram(cv_predictions, n_bins=10):
952952

953953

954954
# %%
955-
def plot_residuals_by_hour(cv_predictions):
956-
"""Plot the average residuals per hour of the day, one line and IQR band per CV fold."""
955+
def plot_binned_residuals(cv_predictions, by="hour"):
956+
"""Plot the average residuals binned by time period, one line and IQR band per CV fold."""
957+
# Configure binning based on the 'by' parameter
958+
if by == "hour":
959+
time_column = "hour_of_day"
960+
time_extractor = pl.col("prediction_time").dt.hour().alias(time_column)
961+
x_title = "Hour of day"
962+
elif by == "month":
963+
time_column = "month_of_year"
964+
time_extractor = pl.col("prediction_time").dt.month().alias(time_column)
965+
x_title = "Month of year"
966+
else:
967+
raise ValueError(f"Unsupported binning method: {by}. Use 'hour' or 'month'.")
968+
957969
all_iqr_bands = []
958970
all_mean_lines = []
959971

@@ -963,24 +975,25 @@ def plot_residuals_by_hour(cv_predictions):
963975
max_date = cv_prediction["prediction_time"].max().strftime("%Y-%m-%d")
964976
fold_label = f"#{i+1} - {min_date} to {max_date}"
965977

966-
residuals_by_hour_detailed = cv_prediction.with_columns(
978+
# Create residuals and time binning columns
979+
residuals_detailed = cv_prediction.with_columns(
967980
[
968981
(pl.col("predicted_load_mw") - pl.col("load_mw")).alias("residual"),
969-
pl.col("prediction_time").dt.hour().alias("hour_of_day"),
982+
time_extractor,
970983
]
971984
)
972985

973986
# Calculate statistics for this CV fold
974987
residuals_stats = (
975-
residuals_by_hour_detailed.group_by("hour_of_day")
988+
residuals_detailed.group_by(time_column)
976989
.agg(
977990
[
978-
pl.col("residual").mean().alias("mean_residual"),
979-
pl.col("residual").quantile(0.25).alias("q25_residual"),
980-
pl.col("residual").quantile(0.75).alias("q75_residual"),
991+
pl.col("residual").mean().round(1).alias("mean_residual"),
992+
pl.col("residual").quantile(0.25).round(1).alias("q25_residual"),
993+
pl.col("residual").quantile(0.75).round(1).alias("q75_residual"),
981994
]
982995
)
983-
.sort("hour_of_day")
996+
.sort(time_column)
984997
.with_columns(pl.lit(fold_label).alias("fold"))
985998
)
986999

@@ -989,7 +1002,7 @@ def plot_residuals_by_hour(cv_predictions):
9891002
altair.Chart(residuals_stats)
9901003
.mark_area(opacity=0.15)
9911004
.encode(
992-
x=altair.X("hour_of_day:O", title="Hour of day"),
1005+
x=altair.X(f"{time_column}:O", title=x_title),
9931006
y=altair.Y("q25_residual:Q"),
9941007
y2=altair.Y2("q75_residual:Q"),
9951008
)
@@ -1000,9 +1013,10 @@ def plot_residuals_by_hour(cv_predictions):
10001013
altair.Chart(residuals_stats)
10011014
.mark_line(tooltip=True, point=True, opacity=0.8)
10021015
.encode(
1003-
x=altair.X("hour_of_day:O", title="Hour of day"),
1016+
x=altair.X(f"{time_column}:O", title=x_title),
10041017
y=altair.Y("mean_residual:Q", title="Mean residual (MW)"),
1005-
color=altair.Color("fold:N", legend=altair.Legend(title="CV Fold")),
1018+
color=altair.Color("fold:N", legend=None),
1019+
detail="fold:N",
10061020
)
10071021
)
10081022

@@ -1020,89 +1034,18 @@ def plot_residuals_by_hour(cv_predictions):
10201034
combined_lines += line
10211035

10221036
# Layer the IQR bands behind the mean lines
1023-
return (combined_iqr + combined_lines).resolve_scale(color="shared")
1037+
return (combined_iqr + combined_lines).resolve_scale(color="independent")
10241038

10251039

1026-
plot_residuals_by_hour(cv_predictions).interactive().properties(
1040+
plot_binned_residuals(cv_predictions, by="hour").interactive().properties(
10271041
title="Residuals by hour of the day from cross-validation predictions"
10281042
)
10291043

10301044

10311045
# %%
1032-
def plot_residuals_by_month(cv_predictions):
1033-
"""Plot the average residuals per month of the year, one line and IQR band per CV fold."""
1034-
all_iqr_bands = []
1035-
all_mean_lines = []
1036-
1037-
for i, cv_prediction in enumerate(cv_predictions):
1038-
# Get date range for this CV fold
1039-
min_date = cv_prediction["prediction_time"].min().strftime("%Y-%m-%d")
1040-
max_date = cv_prediction["prediction_time"].max().strftime("%Y-%m-%d")
1041-
fold_label = f"#{i+1} - {min_date} to {max_date}"
1042-
1043-
residuals_by_month_detailed = cv_prediction.with_columns(
1044-
[
1045-
(pl.col("predicted_load_mw") - pl.col("load_mw")).alias("residual"),
1046-
pl.col("prediction_time").dt.month().alias("month_of_year"),
1047-
]
1048-
)
1049-
1050-
# Calculate statistics for this CV fold
1051-
residuals_stats = (
1052-
residuals_by_month_detailed.group_by("month_of_year")
1053-
.agg(
1054-
[
1055-
pl.col("residual").mean().alias("mean_residual"),
1056-
pl.col("residual").quantile(0.25).alias("q25_residual"),
1057-
pl.col("residual").quantile(0.75).alias("q75_residual"),
1058-
]
1059-
)
1060-
.sort("month_of_year")
1061-
.with_columns(pl.lit(fold_label).alias("fold"))
1062-
)
1063-
1064-
# Create IQR band for this CV fold
1065-
iqr_band = (
1066-
altair.Chart(residuals_stats)
1067-
.mark_area(opacity=0.15)
1068-
.encode(
1069-
x=altair.X("month_of_year:O", title="Month of year"),
1070-
y=altair.Y("q25_residual:Q"),
1071-
y2=altair.Y2("q75_residual:Q"),
1072-
)
1073-
)
1074-
1075-
# Create mean line for this CV fold
1076-
mean_line = (
1077-
altair.Chart(residuals_stats)
1078-
.mark_line(tooltip=True, point=True, opacity=0.8)
1079-
.encode(
1080-
x=altair.X("month_of_year:O", title="Month of year"),
1081-
y=altair.Y("mean_residual:Q", title="Mean residual (MW)"),
1082-
color=altair.Color("fold:N", legend=altair.Legend(title="CV Fold")),
1083-
)
1084-
)
1085-
1086-
all_iqr_bands.append(iqr_band)
1087-
all_mean_lines.append(mean_line)
1088-
1089-
# Combine all IQR bands
1090-
combined_iqr = all_iqr_bands[0]
1091-
for band in all_iqr_bands[1:]:
1092-
combined_iqr += band
1093-
1094-
# Combine all mean lines
1095-
combined_lines = all_mean_lines[0]
1096-
for line in all_mean_lines[1:]:
1097-
combined_lines += line
1098-
1099-
# Layer the IQR bands behind the mean lines
1100-
return (combined_iqr + combined_lines).properties(
1101-
title="Residuals by month of the year from cross-validation predictions"
1102-
)
1103-
1104-
1105-
plot_residuals_by_month(cv_predictions).interactive()
1046+
plot_binned_residuals(cv_predictions, by="month").interactive().properties(
1047+
title="Residuals by month of the year from cross-validation predictions"
1048+
)
11061049

11071050
# %%
11081051
ts_cv_2 = TimeSeriesSplit(

0 commit comments

Comments
 (0)