Skip to content

Commit fc935cc

Browse files
committed
TutorTask527: Reviewer changes + minor equivalent change in the FRED downloader
Pre-commit checks: All checks passed ✅
1 parent 0093139 commit fc935cc

2 files changed

Lines changed: 51 additions & 19 deletions

File tree

causal_automl/download_fred_data.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,6 @@ class FredDataDownloader:
3030
def __init__(self) -> None:
3131
"""
3232
Initialize the FRED data downloader with the API key.
33-
34-
If no FRED API key is passed as a parameter, it is read from the
35-
environment variable.
36-
37-
:param api_key: FRED API key
3833
"""
3934
hdbg.dassert_in(
4035
"FRED_API_KEY",

causal_automl/download_gridstatus_data.py

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import logging
88
import os
99
import time
10-
from typing import Dict, Optional
10+
from typing import Dict, Optional, Union
1111

1212
import gridstatusio
1313
import helpers.hdbg as hdbg
@@ -30,11 +30,6 @@ class GridstatusDataDownloader:
3030
def __init__(self) -> None:
3131
"""
3232
Initialize the GridStatus data downloader with the API key.
33-
34-
If no API key is passed as a parameter, it is read from the
35-
GRIDSTATUS_API_KEY environment variable.
36-
37-
:param api_key: GridStatus API key
3833
"""
3934
hdbg.dassert_in(
4035
"GRIDSTATUS_API_KEY",
@@ -49,8 +44,8 @@ def __init__(self) -> None:
4944
def download_series(
5045
self,
5146
id_: str,
52-
start_timestamp: Optional[pd.Timestamp] = None,
53-
end_timestamp: Optional[pd.Timestamp] = None,
47+
start_timestamp: Optional[Union[str, pd.Timestamp]] = None,
48+
end_timestamp: Optional[Union[str, pd.Timestamp]] = None,
5449
) -> Optional[pd.DataFrame]:
5550
"""
5651
Download historical series data.
@@ -69,13 +64,14 @@ def download_series(
6964
0.5
7065
```
7166
72-
:param id_: GridStatus dataset identifier (e.g., "caiso_as_prices.spinning_reserves")
73-
:param start_timestamp: first observation timestamp (e.g., "2010-01-01 08:00:00+00:00")
67+
:param id_: Gridstatus series identifier (e.g., "caiso_as_prices.spinning_reserves")
68+
:param start_timestamp: first observation timestamp
69+
(e.g., "2010-01-01 08:00:00+00:00" or pd.Timestamp("2023-04-01 01:00:00"))
7470
:param end_timestamp: last observation timestamp
75-
:return: relevant GridStatus series data
71+
:return: relevant Gridstatus series data
7672
"""
7773
# Build request parameters.
78-
id_series, name_series = id_.split(".", 1)
74+
id_dataset, name_series = id_.split(".", 1)
7975
request_kwargs: Dict[str, str] = {}
8076
if start_timestamp is not None:
8177
request_kwargs["start"] = start_timestamp
@@ -89,7 +85,7 @@ def download_series(
8985
try:
9086
# Download the data for the dataset.
9187
df = self._client.get_dataset(
92-
dataset=id_series,
88+
dataset=id_dataset,
9389
columns=[name_series],
9490
**request_kwargs,
9591
)
@@ -106,11 +102,52 @@ def download_series(
106102
continue
107103
# Log success and return.
108104
_LOG.info(
109-
"Downloaded dataset %s with %d records",
105+
"Downloaded series %s with %d records",
110106
id_,
111107
len(df),
112108
)
113109
return df
114110
raise RuntimeError(
115111
f"Failed to fetch after {max_attempts} attempts. Errors per run: {err_msgs}"
116112
)
113+
114+
def filter_series(
    self,
    df: pd.DataFrame,
    id_: str,
    filters: Dict[str, str],
) -> Optional[pd.DataFrame]:
    """
    Extract a single time series from a Gridstatus dataset.

    Keep only the rows matching every `column == value` filter, drop the
    rows where the series value is missing, and return the series indexed
    by `interval_end_utc` in ascending order. The input `df` is not
    modified.

    :param df: Gridstatus dataset to filter; it must contain the
        `interval_end_utc` column, the series column and every filter
        column
    :param id_: Gridstatus series identifier in the form
        "<dataset>.<series>" (e.g., "caiso_as_prices.spinning_reserves")
    :param filters: column-to-value filters to apply on the dataset
        (e.g., {"region":"AS_CAISO_EXP", "market":"DAM"})
    :return: filtered Gridstatus series, or `None` when no rows remain
        after filtering and dropping missing values
    :raises ValueError: if `id_` has no "." separating dataset and series
    """
    # Find the series name first, so that a malformed id fails before any
    # filtering work is done.
    _, separator, name_series = id_.partition(".")
    if not separator:
        raise ValueError(
            f"Invalid series id '{id_}': expected '<dataset>.<series>'"
        )
    # Filter data.
    # No defensive copy is needed: boolean indexing and `dropna()` return
    # new objects, so the caller's `df` is never mutated.
    filtered_data = df
    for column, value in filters.items():
        # Every filter must refer to an existing column.
        hdbg.dassert_in(
            column,
            filtered_data.columns,
            "%s not found in columns: %s",
            column,
            list(filtered_data.columns),
        )
        filtered_data = filtered_data[filtered_data[column] == value]
    # Drop missing value rows.
    filtered_data = filtered_data.dropna(subset=[name_series])
    if filtered_data.empty:
        _LOG.warning("No data remaining after applying filters")
        return None
    # Keep only the series values, indexed by the interval end timestamp.
    filtered_data = filtered_data[["interval_end_utc", name_series]]
    filtered_data = filtered_data.set_index("interval_end_utc")
    filtered_data = filtered_data.sort_index()
    return filtered_data

0 commit comments

Comments
 (0)