Skip to content

Commit f4f44ea

Browse files
committed
tweaks
1 parent d734a86 commit f4f44ea

1 file changed

Lines changed: 5 additions & 6 deletions

File tree

content/python_files/feature_engineering.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -504,12 +504,13 @@ def build_features(
504504
"future_24h": s.glob("*_future_24h"),
505505
"non_paris_weather": s.glob("weather_*") & ~s.glob("weather_*_paris_*"),
506506
},
507-
name="dropped_features",
507+
name="dropped_cols",
508508
)
509509
)
510510
).skb.apply(
511511
HistGradientBoostingRegressor(
512512
random_state=0,
513+
loss=skrub.choose_from(["squared_error", "poisson", "gamma"], name="loss"),
513514
learning_rate=skrub.choose_float(
514515
0.01, 1, default=0.1, log=True, name="learning_rate"
515516
),
@@ -587,7 +588,6 @@ def build_features(
587588

588589
# %%
589590
def collect_cv_predictions(pipelines, cv_splitter, predictions, prediction_time):
590-
591591
index_generator = cv_splitter.split(prediction_time.skb.eval())
592592

593593
def splitter(X, y, index_generator):
@@ -597,7 +597,6 @@ def splitter(X, y, index_generator):
597597
return X[train_idx], X[test_idx], y[train_idx], y[test_idx]
598598

599599
results = []
600-
601600
for (_, test_idx), pipeline in zip(
602601
cv_splitter.split(prediction_time.skb.eval()), pipelines
603602
):
@@ -823,7 +822,7 @@ def plot_reliability_diagram(cv_predictions, n_bins=10):
823822
"fold:N",
824823
legend=altair.Legend(title=None),
825824
),
826-
detail= altair.Detail("fold:N")
825+
detail=altair.Detail("fold:N"),
827826
)
828827
)
829828
return chart.resolve_scale(color="independent")
@@ -848,7 +847,7 @@ def plot_residuals_by_hour(cv_predictions):
848847

849848
residuals_by_hour_detailed = cv_prediction.with_columns(
850849
[
851-
(pl.col("load_mw") - pl.col("predicted_load_mw")).alias("residual"),
850+
(pl.col("predicted_load_mw") - pl.col("load_mw")).alias("residual"),
852851
pl.col("prediction_time").dt.hour().alias("hour_of_day"),
853852
]
854853
)
@@ -925,7 +924,7 @@ def plot_residuals_by_month(cv_predictions):
925924

926925
residuals_by_month_detailed = cv_prediction.with_columns(
927926
[
928-
(pl.col("load_mw") - pl.col("predicted_load_mw")).alias("residual"),
927+
(pl.col("predicted_load_mw") - pl.col("load_mw")).alias("residual"),
929928
pl.col("prediction_time").dt.month().alias("month_of_year"),
930929
]
931930
)

0 commit comments

Comments
 (0)