Skip to content

Commit f569062

Browse files
committed
Refactor data loading to use literal input vars for the full pipeline
1 parent f4f44ea commit f569062

1 file changed

Lines changed: 127 additions & 114 deletions

File tree

content/python_files/feature_engineering.py

Lines changed: 127 additions & 114 deletions
Original file line number | Diff line number | Diff line change
@@ -63,19 +63,34 @@
6363
# system will also be useful later.
6464

6565
# %%
66-
time_range_start = pl.datetime(2021, 3, 23, hour=0, time_zone="UTC")
67-
time_range_end = pl.datetime(2025, 5, 31, hour=23, time_zone="UTC")
68-
time = skrub.var(
69-
"time",
70-
pl.DataFrame().with_columns(
66+
historical_data_start_time = skrub.var(
67+
"historical_data_start_time", pl.datetime(2021, 3, 23, hour=0, time_zone="UTC")
68+
)
69+
historical_data_end_time = skrub.var(
70+
"historical_data_end_time", pl.datetime(2025, 5, 31, hour=23, time_zone="UTC")
71+
)
72+
73+
74+
# %%
75+
@skrub.deferred
76+
def build_historical_time_range(
77+
historical_data_start_time,
78+
historical_data_end_time,
79+
time_interval="1h",
80+
time_zone="UTC",
81+
):
82+
"""Define an historical time range shared by all data sources."""
83+
return pl.DataFrame().with_columns(
7184
pl.datetime_range(
72-
start=time_range_start,
73-
end=time_range_end,
74-
time_zone="UTC",
75-
interval="1h",
85+
start=historical_data_start_time,
86+
end=historical_data_end_time,
87+
time_zone=time_zone,
88+
interval=time_interval,
7689
).alias("time"),
77-
),
78-
)
90+
)
91+
92+
93+
time = build_historical_time_range(historical_data_start_time, historical_data_end_time)
7994
time
8095

8196
# %% [markdown]
@@ -86,8 +101,9 @@
86101
# if you want to re-run this notebook with more recent data.
87102

88103
# %%
89-
data_source_folder = Path("../datasets")
90-
for data_file in sorted(data_source_folder.iterdir()):
104+
data_source_folder = skrub.var("data_source_folder", Path("../datasets"))
105+
106+
for data_file in sorted(data_source_folder.skb.eval().iterdir()):
91107
print(data_file)
92108

93109
# %% [markdown]
@@ -97,58 +113,43 @@
97113
# electricity demand.
98114

99115
# %%
100-
city_names = [
101-
"paris",
102-
"lyon",
103-
"marseille",
104-
"toulouse",
105-
"lille",
106-
"limoges",
107-
"nantes",
108-
"strasbourg",
109-
"brest",
110-
"bayonne",
111-
]
112-
113-
# %%
114-
all_city_weather_raw = {}
115-
for city_name in city_names:
116-
# all_city_weather_raw[city_name] = skrub.var(
117-
# f"{city_name}_weather_raw",
118-
all_city_weather_raw[city_name] = (
119-
pl.from_arrow(read_table(f"../datasets/weather_{city_name}.parquet"))
120-
).with_columns(
121-
[
122-
pl.col("time").dt.cast_time_unit(
123-
"us"
124-
), # Ensure time column has the same type
125-
]
126-
)
127-
128-
# %%
129-
all_city_weather_raw["brest"]
116+
city_names = skrub.var(
117+
"city_names",
118+
[
119+
"paris",
120+
"lyon",
121+
"marseille",
122+
"toulouse",
123+
"lille",
124+
"limoges",
125+
"nantes",
126+
"strasbourg",
127+
"brest",
128+
"bayonne",
129+
],
130+
)
130131

131-
# %%
132-
all_city_weather_raw["brest"].drop_nulls(subset=["temperature_2m"])
133132

133+
@skrub.deferred
134+
def load_weather_data(time, city_names, data_source_folder):
135+
"""Load and horizontally stack historical weather forecast data for each city."""
136+
all_city_weather = time
137+
for city_name in city_names:
138+
all_city_weather = all_city_weather.join(
139+
pl.from_arrow(
140+
read_table(f"{data_source_folder}/weather_{city_name}.parquet")
141+
)
142+
.with_columns([pl.col("time").dt.cast_time_unit("us")])
143+
.rename(lambda x: x if x == "time" else "weather_" + x + "_" + city_name),
144+
on="time",
145+
)
146+
return all_city_weather
134147

135-
# %%
136-
all_city_weather = time.skb.eval()
137-
for city_name, city_weather_raw in all_city_weather_raw.items():
138-
all_city_weather = all_city_weather.join(
139-
city_weather_raw.rename(
140-
lambda x: x if x == "time" else "weather_" + x + "_" + city_name
141-
),
142-
on="time",
143-
how="inner",
144-
)
145148

146-
all_city_weather = skrub.var(
147-
"all_city_weather",
148-
all_city_weather,
149-
)
149+
all_city_weather = load_weather_data(time, city_names, data_source_folder)
150150
all_city_weather
151151

152+
152153
# %% [markdown]
153154
# ## Calendar and holidays features
154155
#
@@ -163,66 +164,70 @@
163164
# Similarly for the calendar features: all the time features are extracted from
164165
# the time in the French timezone.
165166
# %%
166-
holidays_fr = holidays.country_holidays("FR", years=range(2019, 2026))
167+
@skrub.deferred
168+
def prepare_french_calendar_data(time):
169+
fr_time = pl.col("time").dt.convert_time_zone("Europe/Paris")
170+
fr_year_min = time.select(fr_time.dt.year().min()).item()
171+
fr_year_max = time.select(fr_time.dt.year().max()).item()
172+
holidays_fr = holidays.country_holidays(
173+
"FR", years=range(fr_year_min, fr_year_max + 1)
174+
)
175+
return time.with_columns(
176+
[
177+
fr_time.dt.hour().alias("cal_hour_of_day"),
178+
fr_time.dt.weekday().alias("cal_day_of_week"),
179+
fr_time.dt.ordinal_day().alias("cal_day_of_year"),
180+
fr_time.dt.year().alias("cal_year"),
181+
fr_time.dt.date().is_in(holidays_fr.keys()).alias("cal_is_holiday"),
182+
],
183+
)
167184

168-
fr_time = pl.col("time").dt.convert_time_zone("Europe/Paris")
169-
calendar = time.with_columns(
170-
[
171-
fr_time.dt.hour().alias("cal_hour_of_day"),
172-
fr_time.dt.weekday().alias("cal_day_of_week"),
173-
fr_time.dt.ordinal_day().alias("cal_day_of_year"),
174-
fr_time.dt.year().alias("cal_year"),
175-
fr_time.dt.date().is_in(holidays_fr.keys()).alias("cal_is_holiday"),
176-
],
177-
)
185+
186+
calendar = prepare_french_calendar_data(time)
178187
calendar
179188

189+
180190
# %% [markdown]
181191
#
182192
# ## Electricity load data
183193
#
184194
# Finally, we load the electricity load data. This data will be used both as a
185195
# target variable and to craft some lagged and window-aggregated features.
186196
# %%
187-
load_data_files = [
188-
data_file
189-
for data_file in sorted(data_source_folder.iterdir())
190-
if data_file.name.startswith("Total Load - Day Ahead")
191-
and data_file.name.endswith(".csv")
192-
]
193-
# %%
194-
electricity_raw = skrub.var(
195-
"electricity_raw",
196-
pl.concat(
197-
[
198-
pl.from_pandas(pd.read_csv(data_file, na_values=["N/A", "-"])).drop(
199-
["Day-ahead Total Load Forecast [MW] - BZN|FR"]
197+
@skrub.deferred
198+
def load_electricity_load_data(time, data_source_folder):
199+
"""Load and aggregate historical load data from the raw CSV files."""
200+
load_data_files = [
201+
data_file
202+
for data_file in sorted(data_source_folder.iterdir())
203+
if data_file.name.startswith("Total Load - Day Ahead")
204+
and data_file.name.endswith(".csv")
205+
]
206+
return time.join(
207+
(
208+
pl.concat(
209+
[
210+
pl.from_pandas(pd.read_csv(data_file, na_values=["N/A", "-"])).drop(
211+
["Day-ahead Total Load Forecast [MW] - BZN|FR"]
212+
)
213+
for data_file in load_data_files
214+
]
215+
).select(
216+
[
217+
pl.col("Time (UTC)")
218+
.str.split(by=" - ")
219+
.list.first()
220+
.str.to_datetime("%d.%m.%Y %H:%M", time_zone="UTC")
221+
.alias("time"),
222+
pl.col("Actual Total Load [MW] - BZN|FR").alias("load_mw"),
223+
]
200224
)
201-
for data_file in load_data_files
202-
],
203-
how="vertical",
204-
),
205-
)
206-
electricity_raw
207-
208-
# %%
209-
electricity = (
210-
electricity_raw.with_columns(
211-
[
212-
pl.col("Time (UTC)")
213-
.str.split(by=" - ")
214-
.list.first()
215-
.str.to_datetime("%d.%m.%Y %H:%M", time_zone="UTC")
216-
.alias("time"),
217-
]
225+
),
226+
on="time",
218227
)
219-
.drop(["Time (UTC)"])
220-
.rename({"Actual Total Load [MW] - BZN|FR": "load_mw"})
221-
.filter(pl.col("time").dt.minute().eq(0))
222-
.filter(pl.col("time") >= time_range_start)
223-
.filter(pl.col("time") <= time_range_end)
224-
.select(["time", "load_mw"])
225-
)
228+
229+
230+
electricity = load_electricity_load_data(time, data_source_folder)
226231
electricity
227232

228233
# %%
@@ -392,17 +397,25 @@ def iqr(col, *, window_size: int):
392397
# models via backtesting.
393398

394399
# %%
395-
prediction_time = time = skrub.var(
396-
"prediction_time",
397-
pl.DataFrame().with_columns(
400+
prediction_start_time = skrub.var(
401+
"prediction_start_time", historical_data_start_time.skb.eval() + pl.duration(days=7)
402+
)
403+
prediction_end_time = skrub.var(
404+
"prediction_end_time", historical_data_end_time.skb.eval() - pl.duration(hours=24)
405+
)
406+
407+
@skrub.deferred
408+
def define_prediction_time_range(prediction_start_time, prediction_end_time):
409+
return pl.DataFrame().with_columns(
398410
pl.datetime_range(
399-
start=time_range_start + pl.duration(days=7),
400-
end=time_range_end - pl.duration(hours=24),
411+
start=prediction_start_time,
412+
end=prediction_end_time,
401413
time_zone="UTC",
402414
interval="1h",
403415
).alias("prediction_time"),
404-
),
405-
)
416+
)
417+
418+
prediction_time = define_prediction_time_range(prediction_start_time, prediction_end_time)
406419
prediction_time
407420

408421

0 commit comments

Comments
 (0)