Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
"""

import logging
from typing import Any, Dict, List, Tuple
import re
from typing import Any, Dict, List, Optional, Tuple, cast

import helpers.hdbg as hdbg
import matplotlib.pyplot as plt
import pandas as pd
import requests
Expand Down Expand Up @@ -118,11 +120,17 @@ def _get_api_request(self, route: str) -> Dict[str, Any]:
# Build the full API request URL.
url = f"{self._base_url}/{route}?api_key={self._api_key}"
# Send HTTP GET request to the EIA API.
# TODO(alvino): Add error handling for the HTTP request to handle
# potential exceptions such as connection errors or timeouts.
response = requests.get(url, timeout=20)
# Parse JSON content.
# TODO(alvino): Check if the response is successful (e.g.,
# `response.status_code == 200`) before attempting to parse the JSON
# content.
json_data = response.json()
# Get response from parsed payload.
data: Dict[str, Any] = {}
# TODO(alvino): Add error handling for JSON parsing to manage potential parsing errors.
data = json_data.get("response", {})
return data

Expand Down Expand Up @@ -238,6 +246,8 @@ def _extract_metadata(
# Determine parameter CSV path for associated facet values.
param_file_path = f"eia_parameters_v{self._version_num}/{dataset_id_clean}_parameters.csv"
# Flattened metadata row for one frequency and metric combination.
# TODO(gp): `.get()` will use `None` if there is a missing
# value in the dictionary. Is this the intended behavior?
Comment thread
aangelo9 marked this conversation as resolved.
Outdated
metadata = {
"url": url,
"id": f"{route_clean}.{frequency_id}.{metric_id_clean}",
Expand Down Expand Up @@ -270,6 +280,7 @@ def _get_facet_values(
:param route: dataset route under the EIA v2 API
:return: data containing all facet values
"""
hdbg.dassert_in("facets", metadata)
Comment thread
aangelo9 marked this conversation as resolved.
Outdated
facets = metadata["facets"]
rows = []
for facet in facets:
Expand All @@ -295,31 +306,84 @@ def _get_facet_values(
def build_full_url(
base_url: str,
api_key: str,
facet_input: Dict[str, str],
*,
facet_input: Optional[Dict[str, str]] = None,
start_timestamp: Optional[pd.Timestamp] = None,
end_timestamp: Optional[pd.Timestamp] = None,
) -> str:
"""
Build a full EIA v2 API URL by appending one facet value per facet type.
Build an EIA v2 API URL to data endpoint.

This modifies the base metadata URL to point to the actual time series
data endpoint.
This function modifies the base metadata URL by:
- Replacing the metadata endpoint with the actual data endpoint
- Injecting the provided API key
- Appending optional facet filters
- Appending start and end timestamps formatted to match the series frequency

:param base_url: base API URL with frequency and metric, excluding
facet values,
e.g., "https://api.eia.gov/v2/electricity/retail-sales?api_key={API_KEY}&frequency=monthly&data[0]=revenue"
:param api_key: EIA API key, e.g., "abcd1234xyz"
:param facet_input: specified facet values, e.g., {"stateid": "KS", "sectorid": "COM"}
:return: full EIA API URL with all required facet parameters,
:param start_timestamp: first observation date
:param end_timestamp: last observation date
:return: full EIA API URL to data endpoint,
e.g, "https://api.eia.gov/v2/electricity/retail-sales/data?api_key=abcd1234xyz&frequency=monthly&data[0]=price&facets[stateid][]=KS&facets[sectorid][]=OTH"
"""
match = cast(re.Match[str], re.search(r"frequency=([a-zA-Z\-]+)", base_url))
frequency = match.group(1)
base_url = base_url.replace("?", "/data?")
url = base_url.replace("{API_KEY}", api_key)
query_parts = []
for facet_id, value in facet_input.items():
query_parts.append(f"&facets[{facet_id}][]={value}")
if start_timestamp:
formatted_start = _format_timestamp(start_timestamp, frequency)
query_parts.append(f"&start={formatted_start}")
if end_timestamp:
formatted_end = _format_timestamp(end_timestamp, frequency)
query_parts.append(f"&end={formatted_end}")
if facet_input:
# Add facet values when specified.
for facet_id, value in facet_input.items():
query_parts.append(f"&facets[{facet_id}][]={value}")
full_url = url + "".join(query_parts)
return full_url


def _format_timestamp(timestamp: pd.Timestamp, frequency: str) -> pd.Timestamp:
"""
Format a timestamp based on the EIA time series frequency.

Supported formats:
- "annual": "YYYY"
- "quarterly": "YYYY-QN"
- "monthly": "YYYY-MM"
- "daily": "YYYY-MM-DD"
- "hourly": "YYYY-MM-DDTHH"
- "local-hourly": "YYYY-MM-DDTHH-ZZ" (fixed timezone offset, e.g., "-00")

:param timestamp: the timestamp to format
:param frequency: the frequency type (e.g., "monthly", "local-hourly")
:return: formatted timestamp
"""
result = ""
if frequency == "annual":
result = timestamp.strftime("%Y")
elif frequency == "monthly":
result = timestamp.strftime("%Y-%m")
elif frequency == "quarterly":
q = (timestamp.month - 1) // 3 + 1
result = f"{timestamp.year}-Q{q}"
elif frequency == "daily":
result = timestamp.strftime("%Y-%m-%d")
elif frequency == "hourly":
result = timestamp.strftime("%Y-%m-%dT%H")
elif frequency == "local-hourly":
result = timestamp.strftime("%Y-%m-%dT%H") + "-00"
else:
raise ValueError(f"Unsupported frequency: {frequency}")
return result


def plot_distribution(df_metadata: pd.DataFrame, column: str, title: str) -> None:
"""
Plot a distribution count for a specified metadata column.
Expand All @@ -329,8 +393,7 @@ def plot_distribution(df_metadata: pd.DataFrame, column: str, title: str) -> Non
'frequency_id', 'data_units')
:param title: title for the plot
"""
if column not in df_metadata.columns:
raise ValueError(f"Column '{column}' not found in metadata index.")
hdbg.dassert_in(column, df_metadata.columns)
Comment thread
aangelo9 marked this conversation as resolved.
Outdated
counts = df_metadata[column].value_counts()
ax = counts.plot(kind="bar", figsize=(8, 4), title=title)
ax.set_xlabel(column.replace("_", " ").title())
Expand Down
247 changes: 247 additions & 0 deletions causal_automl/download_eia_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
"""
Import as:

import causal_automl.download_eia_data as cadoeida
"""

import io
import logging
import os
from typing import Dict, Optional, Tuple

import helpers.hdbg as hdbg
import helpers.hs3 as hs3
import myeia
import pandas as pd

import causal_automl.TutorTask401_EIA_metadata_downloader_pipeline.eia_utils as catemdpeu

_LOG = logging.getLogger(__name__)


# #############################################################################
# EiaDataDownloader
# #############################################################################


class EiaDataDownloader:
"""
Download historical data from EIA.
"""

def __init__(self, *, aws_profile: str = "ck") -> None:
"""
Initialize the EIA data downloader with the API key and AWS profile.

EIA API key is read from the environment variable.

:param aws_profile: AWS CLI profile name used for authentication
"""
hdbg.dassert_in(
"EIA_API_KEY",
os.environ,
msg="EIA_API_KEY is not found in environment variables",
)
self._api_key = os.getenv("EIA_API_KEY")
self._client = myeia.API(token=self._api_key)
self._aws_profile = aws_profile
self._metadata_index_by_category: Dict[str, pd.DataFrame] = {}

def filter_series(
self,
df: pd.DataFrame,
id_: str,
facets: Dict[str, str],
) -> pd.DataFrame:
"""
Filter and clean a single time series from an EIA dataset.

This function performs data post-processing:
- Filter by facet values (e.g., "stateid", "sectorid")
- Retain only the period and metric column
- Convert the period column to UTC datetime
- Set the period as the index and sort chronologically

:param df: EIA series data
:param id_: EIA series ID, e.g.,
"electricity.retail_sales.monthly.price"
:param facets: facet filters,
e.g., {"stateid": "WI", "sectorid": "ALL"}
:return: data of single time series with one facet value per
facet type

Example output:
```
period price
2001-01-01T00:00:00+00:00 5.9
2001-02-01T00:00:00+00:00 5.98
2001-03-01T00:00:00+00:00 5.93
```
"""
# Filter data with given facet values.
for key, val in facets.items():
hdbg.dassert_in(
key,
df.columns,
"Facet '%s' not found in data columns=%s",
key,
list(df.columns),
)
df = df[df[key] == val]
# Detect the metric column.
_, _, _, data_identifier = self._parse_id(id_)
df = df[["period", data_identifier]]
# Drop rows with missing value.
df = df.dropna(subset=[data_identifier])
Comment thread
aangelo9 marked this conversation as resolved.
if df.empty:
_LOG.warning("No data remaining after applying facets.")
Comment thread
aangelo9 marked this conversation as resolved.
Outdated
# Convert to datetime and index.
df["period"] = pd.to_datetime(df["period"]).dt.tz_localize("UTC")
df = df.set_index("period")
df = df.sort_index()
return df

def download_series(
self,
id_: str,
*,
start_timestamp: Optional[pd.Timestamp] = None,
end_timestamp: Optional[pd.Timestamp] = None,
max_rows_per_call: int = 5000,
) -> pd.DataFrame:
"""
Download EIA historical series data.

This method retrieves the full set of time series linked to an
EIA identifier, including all combinations of facet values
(e.g., `stateid`, `sectorid`). When no start and end timestamps are
passed, the entire time series is downloaded.

Pagination is handled internally. The `max_rows_per_call` parameter
controls the page size for each API request, but the method will
continue fetching until all available data is retrieved.

:param id_: EIA series ID, e.g.,
"electricity.retail_sales.monthly.price"
:param start_timestamp: first observation date
:param end_timestamp: last observation date
:param max_rows_per_call: max data rows per API call
:return: full time series data with all facets

Example output:
```
period stateid stateDescription sectorid sectorName
2020-09 WI Wisconsin IND industrial
2020-09 WY Wyoming ALL all sectors
2020-09 IA Iowa RES Residential

price price-units
7.45 cents per kilowatt-hour
8.55 cents per kilowatt-hour
12.65 cents per kilowatt-hour
```
"""
# Get base url from metadata index.
base_url = self._get_metadata_url(id_)
# Build URL query with API key and timestamps.
url = catemdpeu.build_full_url(
base_url,
self._api_key,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
)
data_chunks = []
offset = 0
while True:
# Construct the paginated URL for the current offset.
paginated_url = f"{url}&offset={offset}&length={max_rows_per_call}"
data = self._client.get_response(paginated_url, self._client.header)
Comment thread
aangelo9 marked this conversation as resolved.
data_chunks.append(data)
if len(data) < max_rows_per_call:
# Exit loop when it's the final page of data.
break
offset += max_rows_per_call
if not data_chunks:
_LOG.warning("No data returned under given id.")
df = pd.concat(data_chunks, ignore_index=True)
_LOG.debug("Downloaded %d rows for id=%s", len(df), id_)
return df

def _parse_id(self, id_: str) -> Tuple[str, str, str, str]:
"""
Parse an EIA time series ID into its components.
Comment thread
aangelo9 marked this conversation as resolved.

EIA time series IDs follow the format:
<category>.<subroute>.<frequency>.<data_identifier>

Underscores are converted to dashes to match the EIA API format.

:param id_: EIA time series ID,
e.g., "electricity.retail_sales.monthly.price"
:return:
- top-level EIA category, e.g., "electricity"
- subroute in the category, e.g., "retail-sales"
- reporting frequency, e.g., "monthly"
- data identifier, e.g., "price"
"""
id_ = id_.replace("_", "-")
parts = id_.split(".")
category = parts[0]
frequency = parts[-2]
data_identifier = parts[-1]
route_parts = parts[1:-2]
subroute = "/".join(route_parts)
return category, subroute, frequency, data_identifier

def _get_latest_metadata_from_s3(self, category: str) -> pd.DataFrame:
"""
Get the latest versioned metadata index file from S3 for a category.

:param category: top-level EIA category, e.g., "electricity"
:return: latest versioned metadata index
"""
# Get file names from S3 bucket.
base_dir = "s3://causify-data-collaborators/causal_automl/metadata"
pattern = f"eia_{category}_metadata_original_v*"
files = hs3.listdir(
dir_name=base_dir,
pattern=pattern,
only_files=True,
use_relative_paths=False,
aws_profile=self._aws_profile,
maxdepth=1,
)
if not files:
raise FileNotFoundError(
f"No metadata index file found for category: '{category}' in S3."
)
# Get latest file version.
files.sort(reverse=True)
s3_path = f"s3://{files[0]}"
# Load latest metadata index file from S3.
csv_str = hs3.from_file(s3_path, aws_profile=self._aws_profile)
df = pd.read_csv(io.StringIO(csv_str))
return df

def _get_metadata_url(self, id_: str) -> str:
"""
Get base URL for given series ID from the metadata index.

:param id_: EIA time series ID,
e.g., "electricity.retail_sales.monthly.price"
:return: base API URL with frequency and metric, excluding facet values,
e.g., "https://api.eia.gov/v2/electricity/retail-sales?api_key={API_KEY}&frequency=monthly&data[0]=revenue"
"""
category, _, _, _ = self._parse_id(id_)
# Load latest metadata index file from S3.
if category not in self._metadata_index_by_category:
self._metadata_index_by_category[category] = (
self._get_latest_metadata_from_s3(category)
)
df = self._metadata_index_by_category[category]
# Filter for exact ID match.
match = df[df["id"] == id_]
if match.empty:
raise ValueError(f"Invalid ID: '{id_}'")
base_url: str = match.iloc[0]["url"]
return base_url