Skip to content

Commit a48d4d0

Browse files
committed
Small improvements in horizon plots
1 parent 1e16b62 commit a48d4d0

2 files changed

Lines changed: 9 additions & 10 deletions

File tree

content/python_files/feature_engineering.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ def load_electricity_load_data(time, data_source_folder):
267267
on="time",
268268
)
269269

270+
270271
# %% [markdown]
271272
#
272273
# Let's load the data and check if there are missing values since we will use
@@ -1195,12 +1196,10 @@ def build_targets(prediction_time, electricity, horizons):
11951196

11961197
# %%
11971198
plot_at_time = datetime.datetime(2021, 4, 19, 0, 0, tzinfo=datetime.timezone.utc)
1198-
historical_timedelta = datetime.timedelta(hours=24 * 5)
11991199
plot_horizon_forecast(
12001200
targets,
12011201
named_predictions,
12021202
plot_at_time,
1203-
historical_timedelta,
12041203
target_column_name_pattern,
12051204
).skb.preview()
12061205

@@ -1283,7 +1282,7 @@ def scoring(regressor, X, y):
12831282
chart = (
12841283
altair.Chart(
12851284
data_long,
1286-
title=f"{dataset_type.title()} {metric_name.upper()} Scores by Horizon",
1285+
title=f"{dataset_type.title()} {metric_name.upper()} scores by horizon",
12871286
)
12881287
.mark_boxplot(extent="min-max")
12891288
.encode(
@@ -1357,12 +1356,10 @@ def scoring(regressor, X, y):
13571356

13581357
# %%
13591358
plot_at_time = datetime.datetime(2021, 4, 24, 0, 0, tzinfo=datetime.timezone.utc)
1360-
historical_timedelta = datetime.timedelta(hours=24 * 5)
13611359
plot_horizon_forecast(
13621360
targets,
13631361
named_predictions_rf,
13641362
plot_at_time,
1365-
historical_timedelta,
13661363
target_column_name_pattern,
13671364
).skb.preview()
13681365

@@ -1384,7 +1381,9 @@ def scoring(regressor, X, y):
13841381

13851382
for metric_name, dataset_type in itertools.product(["mape", "r2"], ["train", "test"]):
13861383
columns = multioutput_cv_results_rf.columns[
1387-
multioutput_cv_results.columns.str.startswith(f"{dataset_type}_{metric_name}")
1384+
multioutput_cv_results_rf.columns.str.startswith(
1385+
f"{dataset_type}_{metric_name}"
1386+
)
13881387
]
13891388
data_to_plot = multioutput_cv_results_rf[columns]
13901389
data_to_plot.columns = [

content/python_files/tutorial_helpers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -491,8 +491,8 @@ def plot_horizon_forecast(
491491
targets,
492492
named_predictions,
493493
plot_at_time,
494-
historical_timedelta,
495494
target_column_name_pattern,
495+
past_hours=7 * 24,
496496
):
497497
"""Plot the true target and the forecast values for different horizons.
498498
@@ -504,10 +504,10 @@ def plot_horizon_forecast(
504504
A DataFrame containing the predicted values.
505505
plot_at_time : datetime.datetime
506506
The time at which to plot the forecast.
507-
historical_timedelta : datetime.timedelta
508-
The historical timedelta to use for the plot.
509507
target_column_name_pattern : str
510508
The pattern to use for the target column names.
509+
past_hours : int, default=7 * 24
510+
The number of past hours to include in the plot.
511511
512512
Returns
513513
-------
@@ -521,7 +521,7 @@ def plot_horizon_forecast(
521521
],
522522
how="horizontal",
523523
)
524-
start_time = plot_at_time - historical_timedelta
524+
start_time = plot_at_time - datetime.timedelta(hours=past_hours)
525525
end_time = plot_at_time + datetime.timedelta(hours=named_predictions.shape[1])
526526
true_values_past = merged_data.filter(
527527
pl.col("prediction_time").is_between(start_time, plot_at_time, closed="both")

0 commit comments

Comments
 (0)