@@ -629,6 +629,12 @@ def splitter(X, y, index_generator):
629629
630630def lorenz_curve (observed_value , predicted_value , n_samples = 1_000 ):
631631 """Compute the Lorenz curve for a given true and predicted values."""
632+
633+ def gini_index (cum_proportion_population , cum_proportion_y_true ):
634+ from sklearn .metrics import auc
635+
636+ return 1 - 2 * auc (cum_proportion_population , cum_proportion_y_true )
637+
632638 observed_value = np .asarray (observed_value )
633639 predicted_value = np .asarray (predicted_value )
634640
@@ -642,98 +648,113 @@ def lorenz_curve(observed_value, predicted_value, n_samples=1_000):
642648 cum_proportion_y_true = np .cumsum (observed_value_sorted )
643649 cum_proportion_y_true /= cum_proportion_y_true [- 1 ]
644650
651+ gini_model = gini_index (cum_proportion_population , cum_proportion_y_true )
652+
645653 cum_proportion_population_interpolated = np .linspace (0 , 1 , n_samples )
646654 cum_proportion_y_true_interpolated = np .interp (
647655 cum_proportion_population_interpolated ,
648656 cum_proportion_population ,
649657 cum_proportion_y_true ,
650658 )
651659
652- return cum_proportion_population_interpolated , cum_proportion_y_true_interpolated
660+ return pl .DataFrame (
661+ {
662+ "cum_population" : cum_proportion_population_interpolated ,
663+ "cum_observed" : cum_proportion_y_true_interpolated ,
664+ }
665+ ).with_columns (
666+ pl .lit (gini_model ).alias ("gini_index" ),
667+ )
653668
654669
def plot_lorenz_curve(cv_predictions, n_samples=1_000):
    """Plot the Lorenz curves of the model and of an oracle for each CV fold.

    For every fold, two curves are computed with ``lorenz_curve``: one ranking
    the observations by the model predictions ("model") and one ranking them
    by the observed values themselves ("oracle"). All folds are drawn as
    semi-transparent lines, and the legend reports the mean and standard
    deviation of the Gini index across folds for each of the two rankings.
    A dashed diagonal is added as the non-informative reference.

    Parameters
    ----------
    cv_predictions : iterable
        One entry per cross-validation fold. Each entry is indexed with
        ``"load_mw"`` (observed values) and ``"predicted_load_mw"`` (model
        predictions).
    n_samples : int, default=1_000
        Number of interpolation points per curve, forwarded to
        ``lorenz_curve``.

    Returns
    -------
    altair chart
        The per-fold curves layered with the diagonal reference line.

    Raises
    ------
    ValueError
        If ``cv_predictions`` contains no fold.
    """
    curves = []
    for cv_idx, predictions in enumerate(cv_predictions):
        observed = predictions["load_mw"]
        # The oracle ranks observations by their true value, which gives the
        # best curve any model could reach on this fold.
        rankings = (
            ("model", predictions["predicted_load_mw"]),
            ("oracle", observed),
        )
        for model_name, ranking in rankings:
            curves.append(
                lorenz_curve(
                    observed_value=observed,
                    predicted_value=ranking,
                    n_samples=n_samples,
                ).with_columns(
                    pl.lit(cv_idx).alias("cv_idx"),
                    pl.lit(model_name).alias("model"),
                )
            )

    if not curves:
        raise ValueError("cv_predictions must contain at least one fold.")

    results = pl.concat(curves)

    def _format_4f(value):
        return f"{value:.4f}"

    # ``gini_index`` is a per-fold constant repeated on every row of a curve.
    # Keep a single row per (model, fold) before aggregating: computing the
    # standard deviation on the repeated rows would use a denominator of
    # n_folds * n_samples - 1 instead of n_folds - 1 and understate the
    # fold-to-fold spread.
    gini_stats = (
        results.select("model", "cv_idx", "gini_index")
        .unique(subset=["model", "cv_idx"])
        .group_by("model")
        .agg(
            pl.col("gini_index").mean().alias("gini_mean"),
            pl.col("gini_index").std().alias("gini_std_dev"),
        )
        .with_columns(
            pl.col("gini_mean").map_elements(_format_4f, return_dtype=pl.String),
            # With a single fold the sample standard deviation is null; use a
            # placeholder so the formatted legend label does not become null.
            pl.col("gini_std_dev")
            .map_elements(_format_4f, return_dtype=pl.String)
            .fill_null("n/a"),
        )
    )

    results = results.join(gini_stats, on="model").with_columns(
        pl.format("{} ({} +/- {})", "model", "gini_mean", "gini_std_dev").alias(
            "model_label"
        )
    )

    # Both layers share the same axes and colour legend.
    x_encoding = altair.X(
        "cum_population:Q",
        title="Fraction of observations sorted by predicted label",
    )
    y_encoding = altair.Y(
        "cum_observed:Q", title="Cumulative observed load proportion"
    )
    color_encoding = altair.Color(
        "model_label:N", legend=altair.Legend(title="Models")
    )

    diagonal_chart = (
        altair.Chart(
            pl.DataFrame(
                {
                    "cum_population": [0, 1],
                    "cum_observed": [0, 1],
                    "model_label": "Non-informative model",
                }
            )
        )
        .mark_line(strokeDash=[4, 4], opacity=0.5, tooltip=True)
        .encode(x=x_encoding, y=y_encoding, color=color_encoding)
    )

    model_chart = (
        altair.Chart(results)
        .mark_line(opacity=0.3, tooltip=True)
        .encode(
            x=x_encoding,
            y=y_encoding,
            color=color_encoding,
            # One line per fold within each colour group.
            detail="cv_idx:N",
        )
    )

    return model_chart + diagonal_chart
754+
755+
756+ plot_lorenz_curve (cv_predictions , n_samples = 500 ).interactive ()
757+
737758
738759# %%
739760def plot_reliability_diagram (cv_predictions , n_bins = 10 ):
@@ -1181,4 +1202,4 @@ def scoring(regressor, X, y):
11811202
11821203 display (chart )
11831204
1184- # %%
1205+ # %%
0 commit comments