Skip to content

Commit d734a86

Browse files
committed
update lorenz curve
1 parent 1a46e93 commit d734a86

1 file changed

Lines changed: 72 additions & 51 deletions

File tree

content/python_files/feature_engineering.py

Lines changed: 72 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,12 @@ def splitter(X, y, index_generator):
629629

630630
def lorenz_curve(observed_value, predicted_value, n_samples=1_000):
631631
"""Compute the Lorenz curve for a given true and predicted values."""
632+
633+
def gini_index(cum_proportion_population, cum_proportion_y_true):
634+
from sklearn.metrics import auc
635+
636+
return 1 - 2 * auc(cum_proportion_population, cum_proportion_y_true)
637+
632638
observed_value = np.asarray(observed_value)
633639
predicted_value = np.asarray(predicted_value)
634640

@@ -642,98 +648,113 @@ def lorenz_curve(observed_value, predicted_value, n_samples=1_000):
642648
cum_proportion_y_true = np.cumsum(observed_value_sorted)
643649
cum_proportion_y_true /= cum_proportion_y_true[-1]
644650

651+
gini_model = gini_index(cum_proportion_population, cum_proportion_y_true)
652+
645653
cum_proportion_population_interpolated = np.linspace(0, 1, n_samples)
646654
cum_proportion_y_true_interpolated = np.interp(
647655
cum_proportion_population_interpolated,
648656
cum_proportion_population,
649657
cum_proportion_y_true,
650658
)
651659

652-
return cum_proportion_population_interpolated, cum_proportion_y_true_interpolated
660+
return pl.DataFrame(
661+
{
662+
"cum_population": cum_proportion_population_interpolated,
663+
"cum_observed": cum_proportion_y_true_interpolated,
664+
}
665+
).with_columns(
666+
pl.lit(gini_model).alias("gini_index"),
667+
)
653668

654669

655-
def plot_lorenz_curve(observed_value, predicted_value, n_samples=1_000):
670+
def plot_lorenz_curve(cv_predictions, n_samples=1_000):
656671
"""Plot the Lorenz curve for a given true and predicted values."""
657672

658-
def gini_index(cum_proportion_population, cum_proportion_y_true):
659-
from sklearn.metrics import auc
673+
results = []
674+
for cv_idx, predictions in enumerate(cv_predictions):
675+
results.append(
676+
lorenz_curve(
677+
observed_value=predictions["load_mw"],
678+
predicted_value=predictions["predicted_load_mw"],
679+
n_samples=n_samples,
680+
).with_columns(
681+
pl.lit(cv_idx).alias("cv_idx"),
682+
pl.lit("model").alias("model"),
683+
)
684+
)
660685

661-
return 1 - 2 * auc(cum_proportion_population, cum_proportion_y_true)
686+
results.append(
687+
lorenz_curve(
688+
observed_value=predictions["load_mw"],
689+
predicted_value=predictions["load_mw"],
690+
n_samples=n_samples,
691+
).with_columns(
692+
pl.lit(cv_idx).alias("cv_idx"),
693+
pl.lit("oracle").alias("model"),
694+
)
695+
)
662696

663-
cum_population_model, cum_observed_model = lorenz_curve(
664-
observed_value, predicted_value, n_samples
665-
)
666-
gini_model = gini_index(cum_population_model, cum_observed_model)
697+
results = pl.concat(results)
667698

668-
cum_population_oracle, cum_observed_oracle = lorenz_curve(
669-
observed_value, observed_value, n_samples
699+
gini_stats = results.group_by("model").agg(
700+
[
701+
pl.col("gini_index")
702+
.mean()
703+
.map_elements(lambda x: f"{x:.4f}", return_dtype=pl.String)
704+
.alias("gini_mean"),
705+
pl.col("gini_index")
706+
.std()
707+
.map_elements(lambda x: f"{x:.4f}", return_dtype=pl.String)
708+
.alias("gini_std_dev"),
709+
]
670710
)
671-
gini_oracle = gini_index(cum_population_oracle, cum_observed_oracle)
672711

673-
model_chart = (
674-
altair.Chart(
675-
pl.DataFrame(
676-
{
677-
"cum_population": cum_population_model,
678-
"cum_observed": cum_observed_model,
679-
"model": f"Model (Gini index: {gini_model:.4f})",
680-
}
681-
)
682-
)
683-
.mark_line(strokeDash=[6, 3, 2, 3], tooltip=True)
684-
.encode(
685-
x=altair.X(
686-
"cum_population:Q",
687-
title="Fraction of observations sorted by predicted label",
688-
),
689-
y=altair.Y("cum_observed:Q", title="Cumulative observed load proportion"),
690-
color=altair.Color("model:N", legend=altair.Legend(title="Models")),
712+
results = results.join(gini_stats, on="model").with_columns(
713+
pl.format("{} ({} +/- {})", "model", "gini_mean", "gini_std_dev").alias(
714+
"model_label"
691715
)
692716
)
693717

694-
oracle_chart = (
718+
diagonal_chart = (
695719
altair.Chart(
696720
pl.DataFrame(
697721
{
698-
"cum_population": cum_population_oracle,
699-
"cum_observed": cum_observed_oracle,
700-
"model": f"Oracle (Gini index: {gini_oracle:.4f})",
722+
"cum_population": [0, 1],
723+
"cum_observed": [0, 1],
724+
"model_label": "Non-informative model",
701725
}
702726
)
703727
)
704-
.mark_line(strokeDash=[4, 4], tooltip=True)
728+
.mark_line(strokeDash=[4, 4], opacity=0.5, tooltip=True)
705729
.encode(
706730
x=altair.X(
707731
"cum_population:Q",
708732
title="Fraction of observations sorted by predicted label",
709733
),
710734
y=altair.Y("cum_observed:Q", title="Cumulative observed load proportion"),
711-
color=altair.Color("model:N", legend=altair.Legend(title="Models")),
735+
color=altair.Color("model_label:N", legend=altair.Legend(title="Models")),
712736
)
713737
)
714738

715-
diagonal_chart = (
716-
altair.Chart(
717-
pl.DataFrame(
718-
{
719-
"cum_population": [0, 1],
720-
"cum_observed": [0, 1],
721-
"model": "Non-informative model"
722-
}
723-
)
724-
)
725-
.mark_line(strokeDash=[4, 4], opacity=0.5, tooltip=True)
739+
model_chart = (
740+
altair.Chart(results)
741+
.mark_line(opacity=0.3, tooltip=True)
726742
.encode(
727743
x=altair.X(
728744
"cum_population:Q",
729-
title="Fraction of observations sorted by predicted label"
745+
title="Fraction of observations sorted by predicted label",
730746
),
731747
y=altair.Y("cum_observed:Q", title="Cumulative observed load proportion"),
732-
color=altair.Color("model:N", legend=altair.Legend(title="Models"))
748+
color=altair.Color("model_label:N", legend=altair.Legend(title="Models")),
749+
detail="cv_idx:N",
733750
)
734751
)
735752

736-
return model_chart + oracle_chart + diagonal_chart
753+
return model_chart + diagonal_chart
754+
755+
756+
plot_lorenz_curve(cv_predictions, n_samples=500).interactive()
757+
737758

738759
# %%
739760
def plot_reliability_diagram(cv_predictions, n_bins=10):
@@ -1181,4 +1202,4 @@ def scoring(regressor, X, y):
11811202

11821203
display(chart)
11831204

1184-
# %%
1205+
# %%

0 commit comments

Comments
 (0)