Skip to content

Commit 202ec52

Browse files
committed
Improve legend for the reliability diagram
1 parent ee43715 commit 202ec52

1 file changed

Lines changed: 18 additions & 9 deletions

File tree

content/python_files/feature_engineering.py

Lines changed: 18 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -639,6 +639,8 @@ def plot_reliability_diagram(cv_predictions, n_bins=10):
639639
all_loads = pl.concat(all_loads["load_mw", "predicted_load_mw"])
640640
min_load, max_load = all_loads.min(), all_loads.max()
641641
scale = altair.Scale(domain=[min_load, max_load])
642+
643+
# Create the perfect line
642644
chart = (
643645
altair.Chart(
644646
pl.DataFrame(
@@ -660,7 +662,14 @@ def plot_reliability_diagram(cv_predictions, n_bins=10):
660662
),
661663
)
662664
)
665+
666+
# Add lines for each CV fold with date labels
663667
for i, cv_predictions_i in enumerate(cv_predictions):
668+
# Get date range for this CV fold
669+
min_date = cv_predictions_i["prediction_time"].min().strftime("%Y-%m-%d")
670+
max_date = cv_predictions_i["prediction_time"].max().strftime("%Y-%m-%d")
671+
fold_label = f"#{i+1} - {min_date} to {max_date}"
672+
664673
mean_per_bins = (
665674
cv_predictions_i.group_by(
666675
pl.col("predicted_load_mw").qcut(np.linspace(0, 1, n_bins))
@@ -672,16 +681,23 @@ def plot_reliability_diagram(cv_predictions, n_bins=10):
672681
]
673682
)
674683
.sort("predicted_load_mw")
684+
.with_columns(pl.lit(fold_label).alias("fold"))
675685
)
686+
676687
chart += (
677688
altair.Chart(mean_per_bins)
678689
.mark_line(tooltip=True, point=True, opacity=0.8)
679690
.encode(
680691
x=altair.X("mean_predicted_load_mw:Q", scale=scale),
681692
y=altair.Y("mean_load_mw:Q", scale=scale),
693+
color=altair.Color(
694+
"fold:N",
695+
legend=altair.Legend(title=None),
696+
),
697+
detail= altair.Detail("fold:N")
682698
)
683699
)
684-
return chart
700+
return chart.resolve_scale(color="independent")
685701

686702

687703
plot_reliability_diagram(cv_predictions).interactive().properties(
@@ -699,7 +715,7 @@ def plot_residuals_by_hour(cv_predictions):
699715
# Get date range for this CV fold
700716
min_date = cv_prediction["prediction_time"].min().strftime("%Y-%m-%d")
701717
max_date = cv_prediction["prediction_time"].max().strftime("%Y-%m-%d")
702-
fold_label = f"#{i+1} ({min_date} to {max_date})"
718+
fold_label = f"#{i+1} - {min_date} to {max_date}"
703719

704720
residuals_by_hour_detailed = cv_prediction.with_columns(
705721
[
@@ -730,7 +746,6 @@ def plot_residuals_by_hour(cv_predictions):
730746
x=altair.X("hour_of_day:O", title="Hour of day"),
731747
y=altair.Y("q25_residual:Q"),
732748
y2=altair.Y2("q75_residual:Q"),
733-
tooltip=["hour_of_day:O", "fold:N", "q25_residual:Q", "q75_residual:Q"],
734749
)
735750
)
736751

@@ -808,12 +823,6 @@ def plot_residuals_by_month(cv_predictions):
808823
x=altair.X("month_of_year:O", title="Month of year"),
809824
y=altair.Y("q25_residual:Q"),
810825
y2=altair.Y2("q75_residual:Q"),
811-
tooltip=[
812-
"month_of_year:O",
813-
"fold:N",
814-
"q25_residual:Q",
815-
"q75_residual:Q",
816-
],
817826
)
818827
)
819828

0 commit comments

Comments
 (0)