@@ -950,6 +950,88 @@ def plot_reliability_diagram(cv_predictions, n_bins=10):
950950 title = "Reliability diagram from cross-validation predictions"
951951)
952952
953+ # %%
def plot_residuals_vs_predicted(cv_predictions):
    """Plot residuals vs predicted values as a scatter plot for all CV folds.

    Each fold is drawn as its own scatter layer (colored per fold) and a
    dashed black horizontal line at ``residual == 0`` marks perfect
    predictions.

    Parameters
    ----------
    cv_predictions : iterable of polars DataFrames
        One frame per cross-validation fold. Each frame must provide the
        columns ``prediction_time`` (datetime-like, since ``strftime`` is
        called on its min/max), ``load_mw`` and ``predicted_load_mw``.

    Returns
    -------
    altair.LayerChart
        The fold scatter plots layered with the zero-residual reference
        line, with independently resolved color scales.

    Raises
    ------
    ValueError
        If ``cv_predictions`` contains no folds.
    """
    # Materialize once: the folds are traversed twice below (once to build the
    # per-fold charts, once to compute the global x-range), which would
    # silently exhaust a generator on the first pass.
    cv_predictions = list(cv_predictions)
    if not cv_predictions:
        raise ValueError(
            "cv_predictions must contain at least one fold of predictions"
        )

    fold_charts = []
    for fold_number, cv_prediction in enumerate(cv_predictions, start=1):
        # Label the fold with its position and the date range it covers.
        min_date = cv_prediction["prediction_time"].min().strftime("%Y-%m-%d")
        max_date = cv_prediction["prediction_time"].max().strftime("%Y-%m-%d")
        fold_label = f"#{fold_number} - {min_date} to {max_date}"

        # Residual is defined here as predicted minus observed load.
        residuals_data = cv_prediction.with_columns(
            (pl.col("predicted_load_mw") - pl.col("load_mw")).alias("residual"),
            pl.lit(fold_label).alias("fold"),
        )

        fold_charts.append(
            altair.Chart(residuals_data)
            .mark_circle(opacity=0.6, size=20)
            .encode(
                x=altair.X(
                    "predicted_load_mw:Q",
                    title="Predicted Load (MW)",
                    scale=altair.Scale(zero=False),
                ),
                y=altair.Y("residual:Q", title="Residual (MW)"),
                color=altair.Color("fold:N", legend=None),
                tooltip=[
                    "prediction_time:T",
                    "load_mw:Q",
                    "predicted_load_mw:Q",
                    "residual:Q",
                    "fold:N",
                ],
            )
        )

    # The reference line must span the predicted values of *all* folds.
    all_predictions = pl.concat(
        [cv_pred["predicted_load_mw"] for cv_pred in cv_predictions]
    )
    min_pred, max_pred = all_predictions.min(), all_predictions.max()

    # Dashed horizontal line at residual == 0 (a perfect prediction).
    perfect_line = (
        altair.Chart(
            pl.DataFrame(
                {
                    "predicted_load_mw": [min_pred, max_pred],
                    "perfect_residual": [0, 0],
                    "label": ["Perfect"] * 2,
                }
            )
        )
        .mark_line(strokeDash=[5, 5], opacity=0.8, color="black")
        .encode(
            x=altair.X("predicted_load_mw:Q", title="Predicted Load (MW)"),
            y=altair.Y("perfect_residual:Q", title="Residual (MW)"),
            color=altair.Color(
                "label:N",
                scale=altair.Scale(range=["black"]),
                legend=None,
            ),
        )
    )

    # Independent color scales keep the line's single-color "label" scale from
    # being merged with the per-fold "fold" color scale.
    return altair.layer(*fold_charts, perfect_line).resolve_scale(
        color="independent"
    )
1029+
1030+
# Build the residuals-vs-predicted chart from the cross-validation
# predictions, enable Altair's interactive mode on it and attach a title.
# NOTE(review): `cv_predictions` is defined outside this excerpt; as the last
# expression of the cell the chart is presumably what the notebook displays.
plot_residuals_vs_predicted(cv_predictions).interactive().properties(
    title="Residuals vs Predicted Values from cross-validation predictions"
)
1034+
9531035
9541036# %%
9551037def plot_binned_residuals (cv_predictions , by = "hour" ):
0 commit comments