Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

import logging
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import pandas as pd
Expand Down Expand Up @@ -295,27 +295,38 @@ def _get_facet_values(
def build_full_url(
base_url: str,
api_key: str,
facet_input: Dict[str, str],
*,
facet_input: Optional[Dict[str, str]] = None,
start_timestamp: Optional[pd.Timestamp] = None,
end_timestamp: Optional[pd.Timestamp] = None,
) -> str:
"""
Build a full EIA v2 API URL by appending one facet value per facet type.
Build an EIA v2 API URL to data endpoint.

This modifies the base metadata URL to point to the actual time series
data endpoint.
data endpoint, optionally appending facet values and date range.

:param base_url: base API URL with frequency and metric, excluding
facet values,
e.g., "https://api.eia.gov/v2/electricity/retail-sales?api_key={API_KEY}&frequency=monthly&data[0]=revenue"
:param api_key: EIA API key, e.g., "abcd1234xyz"
:param facet_input: specified facet values, e.g., {"stateid": "KS", "sectorid": "COM"}
:return: full EIA API URL with all required facet parameters,
:param start_timestamp: first observation date
:param end_timestamp: last observation date
:return: full EIA API URL to data endpoint,
e.g, "https://api.eia.gov/v2/electricity/retail-sales/data?api_key=abcd1234xyz&frequency=monthly&data[0]=price&facets[stateid][]=KS&facets[sectorid][]=OTH"
"""
base_url = base_url.replace("?", "/data?")
url = base_url.replace("{API_KEY}", api_key)
query_parts = []
for facet_id, value in facet_input.items():
query_parts.append(f"&facets[{facet_id}][]={value}")
if start_timestamp:
query_parts.append(f"&start={start_timestamp}")
if end_timestamp:
query_parts.append(f"&end={end_timestamp}")
Comment thread
aangelo9 marked this conversation as resolved.
Outdated
if facet_input:
# Add facet values when specified.
for facet_id, value in facet_input.items():
query_parts.append(f"&facets[{facet_id}][]={value}")
full_url = url + "".join(query_parts)
return full_url

Expand Down
235 changes: 235 additions & 0 deletions causal_automl/download_eia_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
"""
Import as:

import causal_automl.download_eia_data as cadoeida
"""

import io
import logging
import os
from typing import Dict, Optional, Tuple

import helpers.hdbg as hdbg
import helpers.hs3 as hs3
import myeia
import pandas as pd

import causal_automl.TutorTask401_EIA_metadata_downloader_pipeline.eia_utils as catemdpeu

_LOG = logging.getLogger(__name__)


# #############################################################################
# EiaDataDownloader
# #############################################################################


class EiaDataDownloader:
"""
Download historical data from EIA.
"""

def __init__(
self, *, aws_profile: Optional[str] = "ck"
Comment thread
aangelo9 marked this conversation as resolved.
Outdated
) -> None:
"""
Initialize the EIA data downloader with the API key and AWS profile.

EIA API key is read from the environment variable.

:param aws_profile: AWS CLI profile name used for authentication
"""
hdbg.dassert_in(
"EIA_API_KEY",
os.environ,
msg="EIA_API_KEY is not found in environment variables",
)
self._api_key = os.getenv("EIA_API_KEY")
self._client = myeia.API(token=self._api_key)
self._aws_profile = aws_profile
self.base_url = "https://api.eia.gov/v2/"

def filter_series(
self,
df: pd.DataFrame,
id_: str,
facets: Dict[str, str],
) -> pd.DataFrame:
"""
Filter and clean a single time series from an EIA dataset.

Apply facet filters (e.g., state, sector) to select one unique
series, drop missing values, and convert the time column to a
UTC-indexed datetime format.
Comment thread
aangelo9 marked this conversation as resolved.
Outdated

:param df: EIA series data
:param id_: EIA series ID, e.g.,
"electricity.retail_sales.monthly.price"
:param facets: facet filters,
e.g., {"stateid": "WI", "sectorid": "ALL"}
:return: data of single time series with one facet value per
facet type

Example output:
```
period stateid stateDescription sectorid
2001-01-01T00:00:00+00:00 WI Wisconsin ALL
2001-02-01T00:00:00+00:00 WI Wisconsin ALL
2001-03-01T00:00:00+00:00 WI Wisconsin ALL

sectorName price price-units
all sectors 5.9 cents per kilowatt-hour
all sectors 5.98 cents per kilowatt-hour
all sectors 5.93 cents per kilowatt-hour
Comment thread
aangelo9 marked this conversation as resolved.
Outdated
```
"""
# Filter data with given facet values.
for key, val in facets.items():
hdbg.dassert_in(
key,
df.columns,
"Facet '%s' not found in data columns=%s",
key,
list(df.columns),
)
df = df[df[key] == val]
# Detect the metric column.
_, data_identifier = self._parse_id(id_)
# Drop rows with missing value.
df = df.dropna(subset=[data_identifier])
Comment thread
aangelo9 marked this conversation as resolved.
if df.empty:
_LOG.warning("No data remaining after applying facets.")
Comment thread
aangelo9 marked this conversation as resolved.
Outdated
# Convert to datetime index.
df["period"] = pd.to_datetime(df["period"])
df = df.rename(columns={"period": "period (UTC)"})
Comment thread
aangelo9 marked this conversation as resolved.
Outdated
df = df.set_index("period (UTC)")
df.index = df.index.tz_localize("UTC")
Comment thread
aangelo9 marked this conversation as resolved.
Outdated
df = df.sort_index()
return df

def download_series(
self,
id_: str,
*,
start_timestamp: Optional[pd.Timestamp] = None,
end_timestamp: Optional[pd.Timestamp] = None,
max_rows_per_call: int = 5000,
) -> pd.DataFrame:
"""
Download EIA historical series data.

This method retrieves the full set of time series linked to an
EIA identifier, including all combinations of facet values
(e.g., `stateid`, `sectorid`). When no start and end timestamps are
passed, the entire time series is downloaded.

:param id_: EIA series ID, e.g.,
"electricity.retail_sales.monthly.price"
:param start_timestamp: first observation date
:param end_timestamp: last observation date
:param max_rows_per_call: max data rows per api call
Comment thread
aangelo9 marked this conversation as resolved.
Outdated
Comment thread
aangelo9 marked this conversation as resolved.
Outdated
:return: full time series data with all facets

Example output:
```
period stateid stateDescription sectorid sectorName
2020-09 WI Wisconsin IND industrial
2020-09 WY Wyoming ALL all sectors
2020-09 IA Iowa RES Residential

price price-units
7.45 cents per kilowatt-hour
8.55 cents per kilowatt-hour
12.65 cents per kilowatt-hour
```
"""
# Get base url from metadata index.
base_url = self._get_metadata_url(id_)
# Build URL query with api key and timestamps.
Comment thread
aangelo9 marked this conversation as resolved.
Outdated
url = catemdpeu.build_full_url(
base_url,
self._api_key,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
)
data_chunks = []
offset = 0
while True:
# Construct the paginated URL for the current offset.
paginated_url = f"{url}&offset={offset}&length={max_rows_per_call}"
data = self._client.get_response(paginated_url, self._client.header)
Comment thread
aangelo9 marked this conversation as resolved.
data_chunks.append(data)
if len(data) < max_rows_per_call:
# Exit loop when its the final page of data.
Comment thread
aangelo9 marked this conversation as resolved.
Outdated
break
offset += max_rows_per_call
if not data_chunks:
_LOG.warning("No data returned under given id.")
df = pd.concat(data_chunks, ignore_index=True)
_LOG.debug("Downloaded %d rows for id=%s", len(df), id_)
return df

def _parse_id(self, id_: str) -> Tuple[str, str]:
"""
Parse an EIA time series ID into its components.
Comment thread
aangelo9 marked this conversation as resolved.

:param id_: EIA time series ID,
e.g., "electricity.retail_sales.monthly.price"
:return:
- top-level EIA category, e.g., "electricity"
- data identifier, e.g., "price"
"""
id_ = id_.replace("_", "-")
parts = id_.split(".")
category = parts[0]
data_identifier = parts[-1]
return category, data_identifier
Comment thread
aangelo9 marked this conversation as resolved.
Outdated

def _get_latest_metadata_s3_path(self, category: str) -> str:
"""
Get the latest versioned metadata file S3 path for a given category.

:param category: top-level EIA category, e.g., "electricity"
:return: full S3 path to the latest version of the metadata CSV
e.g., "eia_electricity_metadata_original_v2.0.csv"
"""
# Get file names from s3 bucket.
base_dir = "s3://causify-data-collaborators/causal_automl/metadata"
pattern = f"eia_{category}_metadata_original_v*"
files = hs3.listdir(
dir_name=base_dir,
pattern=pattern,
only_files=True,
use_relative_paths=False,
aws_profile=self._aws_profile,
maxdepth=1,
)
if not files:
raise FileNotFoundError(
f"No metadata index file found for category: '{category}' in S3."
)
# Get latest file version.
files.sort(reverse=True)
s3_path = f"s3://{files[0]}"
return s3_path

def _get_metadata_url(self, id_: str) -> str:
"""
:param id_: EIA time series ID,
e.g., "electricity.retail_sales.monthly.price"
:param category: top-level EIA category, e.g., "electricity"
Comment thread
aangelo9 marked this conversation as resolved.
Outdated
:return: base API URL with frequency and metric, excluding facet values,
e.g., "https://api.eia.gov/v2/electricity/retail-sales?api_key={API_KEY}&frequency=monthly&data[0]=revenue"
"""
category, _ = self._parse_id(id_)
# Load latest metadata index file from s3.
s3_path = self._get_latest_metadata_s3_path(category)
csv_str = hs3.from_file(s3_path, aws_profile=self._aws_profile)
df = pd.read_csv(io.StringIO(csv_str))
Comment thread
aangelo9 marked this conversation as resolved.
Outdated
# Filter for exact ID match.
match = df[df["id"] == id_]
if match.empty:
raise ValueError(f"Invalid id: '{id_}'")
row = match.iloc[0]
base_url = str(row["url"])
Comment thread
aangelo9 marked this conversation as resolved.
Outdated
Comment thread
aangelo9 marked this conversation as resolved.
Outdated
return base_url