@@ -206,21 +206,29 @@ def plot_reliability_diagram(
206206 all_loads = pl .concat (all_loads ["load_mw" , "predicted_load_mw" ])
207207 min_load , max_load = all_loads .min (), all_loads .max ()
208208 scale = altair .Scale (domain = [min_load , max_load ])
209+ if kind == "mean" :
210+ y_name = "mean_load_mw"
211+ agg_expr = pl .col ("load_mw" ).mean ()
212+ elif kind == "quantile" :
213+ y_name = "quantile_of_load_mw"
214+ agg_expr = pl .col ("load_mw" ).quantile (quantile_level )
215+ else :
216+ raise ValueError (f"Unknown kind: { kind } . Use 'mean' or 'quantile'." )
209217
210218 chart = (
211219 altair .Chart (
212220 pl .DataFrame (
213221 {
214222 "mean_predicted_load_mw" : [min_load , max_load ],
215- "mean_load_mw" : [min_load , max_load ],
223+ y_name : [min_load , max_load ],
216224 "label" : ["Perfect" ] * 2 ,
217225 }
218226 )
219227 )
220228 .mark_line (tooltip = True , opacity = 0.8 , strokeDash = [5 , 5 ])
221229 .encode (
222230 x = altair .X ("mean_predicted_load_mw:Q" , scale = scale ),
223- y = altair .Y ("mean_load_mw :Q" , scale = scale ),
231+ y = altair .Y (f" { y_name } :Q" , scale = scale ),
224232 color = altair .Color (
225233 "label:N" ,
226234 scale = altair .Scale (range = ["black" ]),
@@ -234,15 +242,6 @@ def plot_reliability_diagram(
234242 max_date = cv_predictions_i ["prediction_time" ].max ().strftime ("%Y-%m-%d" )
235243 fold_label = f"#{ fold_idx } - { min_date } to { max_date } "
236244
237- if kind == "mean" :
238- y_name = "mean_load_mw"
239- agg_expr = pl .col ("load_mw" ).mean ()
240- elif kind == "quantile" :
241- y_name = "quantile_of_load_mw"
242- agg_expr = pl .col ("load_mw" ).quantile (quantile_level )
243- else :
244- raise ValueError (f"Unknown kind: { kind } . Use 'mean' or 'quantile'." )
245-
246245 mean_per_bins = (
247246 cv_predictions_i .group_by (
248247 pl .col ("predicted_load_mw" ).qcut (np .linspace (0 , 1 , n_bins ))
0 commit comments