@@ -639,6 +639,8 @@ def plot_reliability_diagram(cv_predictions, n_bins=10):
639639 all_loads = pl .concat (all_loads ["load_mw" , "predicted_load_mw" ])
640640 min_load , max_load = all_loads .min (), all_loads .max ()
641641 scale = altair .Scale (domain = [min_load , max_load ])
642+
643+ # Create the perfect line
642644 chart = (
643645 altair .Chart (
644646 pl .DataFrame (
@@ -660,7 +662,14 @@ def plot_reliability_diagram(cv_predictions, n_bins=10):
660662 ),
661663 )
662664 )
665+
666+ # Add lines for each CV fold with date labels
663667 for i , cv_predictions_i in enumerate (cv_predictions ):
668+ # Get date range for this CV fold
669+ min_date = cv_predictions_i ["prediction_time" ].min ().strftime ("%Y-%m-%d" )
670+ max_date = cv_predictions_i ["prediction_time" ].max ().strftime ("%Y-%m-%d" )
671+ fold_label = f"#{ i + 1 } - { min_date } to { max_date } "
672+
664673 mean_per_bins = (
665674 cv_predictions_i .group_by (
666675 pl .col ("predicted_load_mw" ).qcut (np .linspace (0 , 1 , n_bins ))
@@ -672,16 +681,23 @@ def plot_reliability_diagram(cv_predictions, n_bins=10):
672681 ]
673682 )
674683 .sort ("predicted_load_mw" )
684+ .with_columns (pl .lit (fold_label ).alias ("fold" ))
675685 )
686+
676687 chart += (
677688 altair .Chart (mean_per_bins )
678689 .mark_line (tooltip = True , point = True , opacity = 0.8 )
679690 .encode (
680691 x = altair .X ("mean_predicted_load_mw:Q" , scale = scale ),
681692 y = altair .Y ("mean_load_mw:Q" , scale = scale ),
693+ color = altair .Color (
694+ "fold:N" ,
695+ legend = altair .Legend (title = None ),
696+ ),
697+ detail = altair .Detail ("fold:N" )
682698 )
683699 )
684- return chart
700+ return chart . resolve_scale ( color = "independent" )
685701
686702
687703plot_reliability_diagram (cv_predictions ).interactive ().properties (
@@ -699,7 +715,7 @@ def plot_residuals_by_hour(cv_predictions):
699715 # Get date range for this CV fold
700716 min_date = cv_prediction ["prediction_time" ].min ().strftime ("%Y-%m-%d" )
701717 max_date = cv_prediction ["prediction_time" ].max ().strftime ("%Y-%m-%d" )
702- fold_label = f"#{ i + 1 } ( { min_date } to { max_date } ) "
718+ fold_label = f"#{ i + 1 } - { min_date } to { max_date } "
703719
704720 residuals_by_hour_detailed = cv_prediction .with_columns (
705721 [
@@ -730,7 +746,6 @@ def plot_residuals_by_hour(cv_predictions):
730746 x = altair .X ("hour_of_day:O" , title = "Hour of day" ),
731747 y = altair .Y ("q25_residual:Q" ),
732748 y2 = altair .Y2 ("q75_residual:Q" ),
733- tooltip = ["hour_of_day:O" , "fold:N" , "q25_residual:Q" , "q75_residual:Q" ],
734749 )
735750 )
736751
@@ -808,12 +823,6 @@ def plot_residuals_by_month(cv_predictions):
808823 x = altair .X ("month_of_year:O" , title = "Month of year" ),
809824 y = altair .Y ("q25_residual:Q" ),
810825 y2 = altair .Y2 ("q75_residual:Q" ),
811- tooltip = [
812- "month_of_year:O" ,
813- "fold:N" ,
814- "q25_residual:Q" ,
815- "q75_residual:Q" ,
816- ],
817826 )
818827 )
819828
0 commit comments