Skip to content

Commit 3dfe953

Browse files
committed
checkpoint
1 parent 3161c9d commit 3dfe953

2 files changed

Lines changed: 114 additions & 58 deletions

File tree

causal_automl/TutorTask401_EIA_metadata_downloader_pipeline/eia_utils.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
"""
66

77
import logging
8-
from typing import Any, Dict, List, Optional, Tuple
8+
import re
9+
from typing import Any, Dict, List, Optional, Tuple, cast
910

11+
import helpers.hdbg as hdbg
1012
import matplotlib.pyplot as plt
1113
import pandas as pd
1214
import requests
@@ -312,8 +314,11 @@ def build_full_url(
312314
"""
313315
Build an EIA v2 API URL to data endpoint.
314316
315-
This modifies the base metadata URL to point to the actual time series
316-
data endpoint, optionally appending facet values and date range.
317+
This function modifies the base metadata URL by:
318+
- Replacing the metadata endpoint with the actual data endpoint
319+
- Injecting the provided API key
320+
- Appending optional facet filters
321+
- Appending start and end timestamps formatted to match the series frequency
317322
318323
:param base_url: base API URL with frequency and metric, excluding
319324
facet values,
@@ -325,13 +330,17 @@ def build_full_url(
325330
:return: full EIA API URL to data endpoint,
326331
e.g., "https://api.eia.gov/v2/electricity/retail-sales/data?api_key=abcd1234xyz&frequency=monthly&data[0]=price&facets[stateid][]=KS&facets[sectorid][]=OTH"
327332
"""
333+
match = cast(re.Match[str], re.search(r"frequency=([a-zA-Z\-]+)", base_url))
334+
frequency = match.group(1)
328335
base_url = base_url.replace("?", "/data?")
329336
url = base_url.replace("{API_KEY}", api_key)
330337
query_parts = []
331338
if start_timestamp:
332-
query_parts.append(f"&start={start_timestamp}")
339+
formatted_start = _format_timestamp(start_timestamp, frequency)
340+
query_parts.append(f"&start={formatted_start}")
333341
if end_timestamp:
334-
query_parts.append(f"&end={end_timestamp}")
342+
formatted_end = _format_timestamp(end_timestamp, frequency)
343+
query_parts.append(f"&end={formatted_end}")
335344
if facet_input:
336345
# Add facet values when specified.
337346
for facet_id, value in facet_input.items():
@@ -340,6 +349,41 @@ def build_full_url(
340349
return full_url
341350

342351

352+
def _format_timestamp(timestamp: pd.Timestamp, frequency: str) -> pd.Timestamp:
353+
"""
354+
Format a timestamp based on the EIA time series frequency.
355+
356+
Supported formats:
357+
- "annual": "YYYY"
358+
- "quarterly": "YYYY-QN"
359+
- "monthly": "YYYY-MM"
360+
- "daily": "YYYY-MM-DD"
361+
- "hourly": "YYYY-MM-DDTHH"
362+
- "local-hourly": "YYYY-MM-DDTHH-ZZ" (fixed timezone offset, e.g., -00)
363+
364+
:param timestamp: the timestamp to format
365+
:param frequency: the frequency type (e.g., "monthly", "local-hourly")
366+
:return: formatted timestamp
367+
"""
368+
result = ""
369+
if frequency == "annual":
370+
result = timestamp.strftime("%Y")
371+
elif frequency == "monthly":
372+
result = timestamp.strftime("%Y-%m")
373+
elif frequency == "quarterly":
374+
q = (timestamp.month - 1) // 3 + 1
375+
result = f"{timestamp.year}-Q{q}"
376+
elif frequency == "daily":
377+
result = timestamp.strftime("%Y-%m-%d")
378+
elif frequency == "hourly":
379+
result = timestamp.strftime("%Y-%m-%dT%H")
380+
elif frequency == "local-hourly":
381+
result = timestamp.strftime("%Y-%m-%dT%H") + "-00"
382+
else:
383+
raise ValueError(f"Unsupported frequency: {frequency}")
384+
return result
385+
386+
343387
def plot_distribution(df_metadata: pd.DataFrame, column: str, title: str) -> None:
344388
"""
345389
Plot a distribution count for a specified metadata column.

causal_automl/download_eia_data.py

Lines changed: 65 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ class EiaDataDownloader:
2929
Download historical data from EIA.
3030
"""
3131

32-
def __init__(
33-
self, *, aws_profile: Optional[str] = "ck"
34-
) -> None:
32+
def __init__(self, *, aws_profile: str = "ck") -> None:
3533
"""
3634
Initialize the EIA data downloader with the API key and AWS profile.
3735
@@ -47,7 +45,7 @@ def __init__(
4745
self._api_key = os.getenv("EIA_API_KEY")
4846
self._client = myeia.API(token=self._api_key)
4947
self._aws_profile = aws_profile
50-
self.base_url = "https://api.eia.gov/v2/"
48+
self._metadata_index_by_category: Dict[str, pd.DataFrame] = {}
5149

5250
def filter_series(
5351
self,
@@ -58,29 +56,26 @@ def filter_series(
5856
"""
5957
Filter and clean a single time series from an EIA dataset.
6058
61-
Apply facet filters (e.g., state, sector) to select one unique
62-
series, drop missing values, and convert the time column to a
63-
UTC-indexed datetime format.
59+
This function performs data post-processing:
60+
- Filter by facet values (e.g., "stateid", "sectorid")
61+
- Retain only the period and metric column
62+
- Convert the period column to UTC datetime
63+
- Set the period as the index and sort chronologically
6464
6565
:param df: EIA series data
6666
:param id_: EIA series ID, e.g.,
6767
"electricity.retail_sales.monthly.price"
68-
:param facets: facet filters,
68+
:param facets: facet filters,
6969
e.g., {"stateid": "WI", "sectorid": "ALL"}
7070
:return: data of single time series with one facet value per
7171
facet type
7272
7373
Example output:
7474
```
75-
period stateid stateDescription sectorid
76-
2001-01-01T00:00:00+00:00 WI Wisconsin ALL
77-
2001-02-01T00:00:00+00:00 WI Wisconsin ALL
78-
2001-03-01T00:00:00+00:00 WI Wisconsin ALL
79-
80-
sectorName price price-units
81-
all sectors 5.9 cents per kilowatt-hour
82-
all sectors 5.98 cents per kilowatt-hour
83-
all sectors 5.93 cents per kilowatt-hour
75+
period price
76+
2001-01-01T00:00:00+00:00 5.9
77+
2001-02-01T00:00:00+00:00 5.98
78+
2001-03-01T00:00:00+00:00 5.93
8479
```
8580
"""
8681
# Filter data with given facet values.
@@ -94,16 +89,15 @@ def filter_series(
9489
)
9590
df = df[df[key] == val]
9691
# Detect the metric column.
97-
_, data_identifier = self._parse_id(id_)
92+
_, _, _, data_identifier = self._parse_id(id_)
93+
df = df[["period", data_identifier]]
9894
# Drop rows with missing value.
9995
df = df.dropna(subset=[data_identifier])
10096
if df.empty:
10197
_LOG.warning("No data remaining after applying facets.")
102-
# Convert to datetime index.
103-
df["period"] = pd.to_datetime(df["period"])
104-
df = df.rename(columns={"period": "period (UTC)"})
105-
df = df.set_index("period (UTC)")
106-
df.index = df.index.tz_localize("UTC")
98+
# Convert to datetime and index.
99+
df["period"] = pd.to_datetime(df["period"]).dt.tz_localize("UTC")
100+
df = df.set_index("period")
107101
df = df.sort_index()
108102
return df
109103

@@ -118,19 +112,23 @@ def download_series(
118112
"""
119113
Download EIA historical series data.
120114
121-
This method retrieves the full set of time series linked to an
122-
EIA identifier, including all combinations of facet values
123-
(e.g., `stateid`, `sectorid`). When no start and end timestamps are
124-
passed, the entire time series is downloaded.
115+
This method retrieves the full set of time series linked to an
116+
EIA identifier, including all combinations of facet values
117+
(e.g., `stateid`, `sectorid`). When no start and end timestamps are
118+
passed, the entire time series is downloaded.
125119
126-
:param id_: EIA series ID, e.g.,
127-
"electricity.retail_sales.monthly.price"
128-
:param start_timestamp: first observation date
129-
:param end_timestamp: last observation date
130-
:param max_rows_per_call: max data rows per api call
131-
:return: full time series data with all facets
120+
Pagination is handled internally. The `max_rows_per_call` parameter
121+
controls the page size for each API request, but the method will
122+
continue fetching until all available data is retrieved.
132123
133-
Example output:
124+
:param id_: EIA series ID, e.g.,
125+
"electricity.retail_sales.monthly.price"
126+
:param start_timestamp: first observation date
127+
:param end_timestamp: last observation date
128+
:param max_rows_per_call: max data rows per API call
129+
:return: full time series data with all facets
130+
131+
Example output:
134132
```
135133
period stateid stateDescription sectorid sectorName
136134
2020-09 WI Wisconsin IND industrial
@@ -145,7 +143,7 @@ def download_series(
145143
"""
146144
# Get base url from metadata index.
147145
base_url = self._get_metadata_url(id_)
148-
# Build URL query with api key and timestamps.
146+
# Build URL query with API key and timestamps.
149147
url = catemdpeu.build_full_url(
150148
base_url,
151149
self._api_key,
@@ -160,7 +158,7 @@ def download_series(
160158
data = self._client.get_response(paginated_url, self._client.header)
161159
data_chunks.append(data)
162160
if len(data) < max_rows_per_call:
163-
# Exit loop when its the final page of data.
161+
# Exit loop when it's the final page of data.
164162
break
165163
offset += max_rows_per_call
166164
if not data_chunks:
@@ -169,31 +167,40 @@ def download_series(
169167
_LOG.debug("Downloaded %d rows for id=%s", len(df), id_)
170168
return df
171169

172-
def _parse_id(self, id_: str) -> Tuple[str, str]:
170+
def _parse_id(self, id_: str) -> Tuple[str, str, str, str]:
173171
"""
174172
Parse an EIA time series ID into its components.
175173
174+
EIA time series IDs follow the format:
175+
<category>.<subroute>.<frequency>.<data_identifier>
176+
177+
Underscores are converted to dashes to match the EIA API format.
178+
176179
:param id_: EIA time series ID,
177180
e.g., "electricity.retail_sales.monthly.price"
178181
:return:
179182
- top-level EIA category, e.g., "electricity"
183+
- subroute in the category, e.g., "retail-sales"
184+
- reporting frequency, e.g., "monthly"
180185
- data identifier, e.g., "price"
181186
"""
182187
id_ = id_.replace("_", "-")
183188
parts = id_.split(".")
184189
category = parts[0]
190+
frequency = parts[-2]
185191
data_identifier = parts[-1]
186-
return category, data_identifier
192+
route_parts = parts[1:-2]
193+
subroute = "/".join(route_parts)
194+
return category, subroute, frequency, data_identifier
187195

188-
def _get_latest_metadata_s3_path(self, category: str) -> str:
196+
def _get_latest_metadata_from_s3(self, category: str) -> pd.DataFrame:
189197
"""
190-
Get the latest versioned metadata file S3 path for a given category.
198+
Get the latest versioned metadata index file from S3 for a category.
191199
192200
:param category: top-level EIA category, e.g., "electricity"
193-
:return: full S3 path to the latest version of the metadata CSV
194-
e.g., "eia_electricity_metadata_original_v2.0.csv"
201+
:return: latest versioned metadata index
195202
"""
196-
# Get file names from s3 bucket.
203+
# Get file names from S3 bucket.
197204
base_dir = "s3://causify-data-collaborators/causal_automl/metadata"
198205
pattern = f"eia_{category}_metadata_original_v*"
199206
files = hs3.listdir(
@@ -211,25 +218,30 @@ def _get_latest_metadata_s3_path(self, category: str) -> str:
211218
# Get latest file version.
212219
files.sort(reverse=True)
213220
s3_path = f"s3://{files[0]}"
214-
return s3_path
221+
# Load latest metadata index file from S3.
222+
csv_str = hs3.from_file(s3_path, aws_profile=self._aws_profile)
223+
df = pd.read_csv(io.StringIO(csv_str))
224+
return df
215225

216226
def _get_metadata_url(self, id_: str) -> str:
217227
"""
228+
Get base URL for given series ID from the metadata index.
229+
218230
:param id_: EIA time series ID,
219231
e.g., "electricity.retail_sales.monthly.price"
220-
:param category: top-level EIA category, e.g., "electricity"
221232
:return: base API URL with frequency and metric, excluding facet values,
222233
e.g., "https://api.eia.gov/v2/electricity/retail-sales?api_key={API_KEY}&frequency=monthly&data[0]=revenue"
223234
"""
224-
category, _ = self._parse_id(id_)
225-
# Load latest metadata index file from s3.
226-
s3_path = self._get_latest_metadata_s3_path(category)
227-
csv_str = hs3.from_file(s3_path, aws_profile=self._aws_profile)
228-
df = pd.read_csv(io.StringIO(csv_str))
235+
category, _, _, _ = self._parse_id(id_)
236+
# Load latest metadata index file from S3.
237+
if category not in self._metadata_index_by_category:
238+
self._metadata_index_by_category[category] = (
239+
self._get_latest_metadata_from_s3(category)
240+
)
241+
df = self._metadata_index_by_category[category]
229242
# Filter for exact ID match.
230243
match = df[df["id"] == id_]
231244
if match.empty:
232-
raise ValueError(f"Invalid id: '{id_}'")
233-
row = match.iloc[0]
234-
base_url = str(row["url"])
245+
raise ValueError(f"Invalid ID: '{id_}'")
246+
base_url: str = match.iloc[0]["url"]
235247
return base_url

0 commit comments

Comments
 (0)