Skip to content

Commit 9d80424

Browse files
committed
Fix ylabel of quantile regression reliability diagrams
1 parent ee838a2 commit 9d80424

1 file changed

Lines changed: 10 additions & 11 deletions

File tree

content/python_files/tutorial_helpers.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -206,21 +206,29 @@ def plot_reliability_diagram(
206206
all_loads = pl.concat(all_loads["load_mw", "predicted_load_mw"])
207207
min_load, max_load = all_loads.min(), all_loads.max()
208208
scale = altair.Scale(domain=[min_load, max_load])
209+
if kind == "mean":
210+
y_name = "mean_load_mw"
211+
agg_expr = pl.col("load_mw").mean()
212+
elif kind == "quantile":
213+
y_name = "quantile_of_load_mw"
214+
agg_expr = pl.col("load_mw").quantile(quantile_level)
215+
else:
216+
raise ValueError(f"Unknown kind: {kind}. Use 'mean' or 'quantile'.")
209217

210218
chart = (
211219
altair.Chart(
212220
pl.DataFrame(
213221
{
214222
"mean_predicted_load_mw": [min_load, max_load],
215-
"mean_load_mw": [min_load, max_load],
223+
y_name: [min_load, max_load],
216224
"label": ["Perfect"] * 2,
217225
}
218226
)
219227
)
220228
.mark_line(tooltip=True, opacity=0.8, strokeDash=[5, 5])
221229
.encode(
222230
x=altair.X("mean_predicted_load_mw:Q", scale=scale),
223-
y=altair.Y("mean_load_mw:Q", scale=scale),
231+
y=altair.Y(f"{y_name}:Q", scale=scale),
224232
color=altair.Color(
225233
"label:N",
226234
scale=altair.Scale(range=["black"]),
@@ -234,15 +242,6 @@ def plot_reliability_diagram(
234242
max_date = cv_predictions_i["prediction_time"].max().strftime("%Y-%m-%d")
235243
fold_label = f"#{fold_idx} - {min_date} to {max_date}"
236244

237-
if kind == "mean":
238-
y_name = "mean_load_mw"
239-
agg_expr = pl.col("load_mw").mean()
240-
elif kind == "quantile":
241-
y_name = "quantile_of_load_mw"
242-
agg_expr = pl.col("load_mw").quantile(quantile_level)
243-
else:
244-
raise ValueError(f"Unknown kind: {kind}. Use 'mean' or 'quantile'.")
245-
246245
mean_per_bins = (
247246
cv_predictions_i.group_by(
248247
pl.col("predicted_load_mw").qcut(np.linspace(0, 1, n_bins))

0 commit comments

Comments
 (0)