Skip to content

Commit d533848

Browse files
Fix Pass data and country package versions to APIv2 #2500 (#2528)
* Fix: Pass data and country package versions to APIv2 (#2500)
* Add missing file
* Adjust UK address
* Typo
* Use token
* Rename function
* Fix accidental uk/us mixup
1 parent 2afc982 commit d533848

3 files changed

Lines changed: 173 additions & 5 deletions

File tree

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- bump: patch
2+
changes:
3+
fixed:
4+
- Pass country and data versions to APIv2.

policyengine_api/jobs/calculate_economy_simulation_job.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,16 @@
2626
compute_difference,
2727
)
2828
from policyengine_core.simulations import Microsimulation
29-
from policyengine_core.tools.hugging_face import download_huggingface_dataset
29+
from policyengine_core.tools.hugging_face import (
30+
download_huggingface_dataset,
31+
)
32+
from policyengine_api.utils.hugging_face import get_latest_commit_tag
3033
import h5py
3134

3235
from policyengine_us import Microsimulation
3336
from policyengine_uk import Microsimulation
3437
import logging
38+
import huggingface_hub
3539

3640
load_dotenv()
3741

@@ -44,6 +48,34 @@
4448
CPS = "hf://policyengine/policyengine-us-data/cps_2023.h5"
4549
POOLED_CPS = "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5"
4650

51+
def _pin_dataset_version(uri: str, version: str | None) -> str:
    """Return *uri* suffixed with ``@version``, or unchanged if no version."""
    # get_latest_commit_tag returns None when the latest commit is untagged;
    # formatting that into the URI would yield the literal suffix "@None".
    return uri if version is None else f"{uri}@{version}"


# Hugging Face dataset URIs available for each country, keyed by short name.
datasets = {
    "uk": {
        "enhanced_frs": ENHANCED_FRS,
        "frs": FRS,
    },
    "us": {
        "enhanced_cps": ENHANCED_CPS,
        "cps": CPS,
        "pooled_cps": POOLED_CPS,
    },
}

# NOTE: these lookups run at import time and query the Hugging Face Hub.
# Either value may be None when the latest commit carries no tag.
us_dataset_version = get_latest_commit_tag(
    repo_id="policyengine/policyengine-us-data",
    repo_type="model",
)
uk_dataset_version = get_latest_commit_tag(
    repo_id="policyengine/policyengine-uk-data-private",
    repo_type="model",
)

# Pin every dataset URI to the tag resolved for its country.
for _country, _version in (
    ("uk", uk_dataset_version),
    ("us", us_dataset_version),
):
    datasets[_country] = {
        _name: _pin_dataset_version(_uri, _version)
        for _name, _uri in datasets[_country].items()
    }
77+
78+
4779
check_against_api_v2 = (
4880
os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") is not None
4981
)
@@ -189,6 +221,12 @@ def run(
189221
time_period=time_period,
190222
region=region,
191223
dataset=dataset,
224+
model_version=COUNTRY_PACKAGE_VERSIONS[country_id],
225+
data_version=(
226+
uk_dataset_version
227+
if country_id == "uk"
228+
else us_dataset_version
229+
),
192230
)
193231

194232
try:
@@ -454,7 +492,7 @@ def _create_simulation_uk(
454492

455493
simulation = CountryMicrosimulation(
456494
reform=reform,
457-
dataset=ENHANCED_FRS,
495+
dataset=datasets["uk"]["enhanced_frs"],
458496
)
459497
simulation.default_calculation_period = time_period
460498
if region != "uk":
@@ -514,7 +552,7 @@ def _create_simulation_us(
514552
if dataset in DATASETS:
515553
print(f"Running simulation using {dataset} dataset")
516554

517-
sim_options["dataset"] = ENHANCED_CPS
555+
sim_options["dataset"] = datasets["us"]["enhanced_cps"]
518556

519557
# Handle region settings
520558
if region != "us":
@@ -526,7 +564,7 @@ def _create_simulation_us(
526564
if "dataset" in sim_options:
527565
filter_dataset = sim_options["dataset"]
528566
else:
529-
filter_dataset = POOLED_CPS
567+
filter_dataset = datasets["us"]["pooled_cps"]
530568

531569
# Run sim to filter by region
532570
region_sim = Microsimulation(
@@ -547,7 +585,7 @@ def _create_simulation_us(
547585
sim_options["dataset"] = df[state_code == region.upper()]
548586

549587
if dataset == "default" and region == "us":
550-
sim_options["dataset"] = CPS
588+
sim_options["dataset"] = datasets["us"]["cps"]
551589

552590
# Return completed simulation
553591
return Microsimulation(**sim_options)
@@ -723,6 +761,8 @@ def _setup_sim_options(
723761
dataset: str,
724762
time_period: str,
725763
scope: Literal["macro", "household"] = "macro",
764+
model_version: str | None = None,
765+
data_version: str | None = None,
726766
) -> dict[str, Any]:
727767
"""
728768
Set up the simulation options for the APIv2 job.
@@ -738,6 +778,8 @@ def _setup_sim_options(
738778
"data": self._setup_data(
739779
dataset=dataset, country_id=country_id, region=region
740780
),
781+
"model_version": model_version,
782+
"data_version": data_version,
741783
}
742784

743785
def _setup_region(self, country_id: str, region: str) -> str:
policyengine_api/utils/hugging_face.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
from huggingface_hub import (
2+
hf_hub_download,
3+
model_info,
4+
ModelInfo,
5+
HfApi,
6+
)
7+
from huggingface_hub.errors import RepositoryNotFoundError
8+
from getpass import getpass
9+
import os
10+
import warnings
11+
import traceback
12+
13+
with warnings.catch_warnings():
14+
warnings.simplefilter("ignore")
15+
16+
17+
def get_latest_commit_tag(
    repo_id: str, repo_type: str = "model"
) -> str | None:
    """
    Get the tag associated with the latest commit in a HF repo.

    Args:
        repo_id (str): The Hugging Face repo name, in format "{org}/{repo}".
        repo_type (str): The repo type passed to the Hub API. Defaults to "model".

    Returns:
        str | None: The name of the tag pointing at the latest commit, or
            None if the repo has no commits or no tag targets that commit.
    """
    api = HfApi()

    # Only private repos need a token; public repos are queried anonymously.
    authentication_token: str | None = None
    if check_is_repo_private(repo_id):
        authentication_token = get_or_prompt_hf_token()

    # Get list of commits
    commits = api.list_repo_commits(
        repo_id=repo_id, repo_type=repo_type, token=authentication_token
    )
    if not commits:
        return None

    latest_commit_id = commits[0].commit_id  # Most recent commit is first

    # Get all tags in the repository
    tags = api.list_repo_refs(
        repo_id=repo_id, repo_type=repo_type, token=authentication_token
    ).tags

    # Find tag that points to the latest commit
    for tag in tags:
        if tag.target_commit == latest_commit_id:
            # Strip only the leading ref namespace, not any later occurrence.
            return tag.ref.removeprefix("refs/tags/")

    return None
51+
52+
53+
def check_is_repo_private(repo: str) -> bool:
    """
    Check if a Hugging Face repository is private.

    Args:
        repo (str): The Hugging Face repo name, in format "{org}/{repo}".

    Returns:
        bool: True if the repo is private, False otherwise.

    Raises:
        Exception: If the lookup fails for any reason other than the
            repository not being found; the original error is chained.
    """
    try:
        fetched_model_info: ModelInfo = model_info(repo)
    except RepositoryNotFoundError:
        return True  # If repo not found, assume it's private
    except Exception as e:
        # Chain the cause so the underlying error stays attached to the
        # re-raised exception in addition to the formatted traceback text.
        raise Exception(
            f"Unable to check if repo {repo} is private. The full error is {traceback.format_exc()}"
        ) from e
    return fetched_model_info.private
72+
73+
74+
def download_huggingface_dataset(
    repo: str,
    repo_filename: str,
    version: str = None,
    local_dir: str | None = None,
):
    """
    Fetch a single dataset file from the Hugging Face Hub.

    Args:
        repo (str): The Hugging Face repo name, in format "{org}/{repo}".
        repo_filename (str): The filename of the dataset within the repo.
        version (str, optional): The revision to download. Defaults to None.
        local_dir (str, optional): Directory to save the file to. Defaults to None.

    Returns:
        Whatever ``hf_hub_download`` returns for the requested file.
    """
    # A token is looked up (or prompted for) only when the repo is private.
    token = get_or_prompt_hf_token() if check_is_repo_private(repo) else None

    download_kwargs = {
        "repo_id": repo,
        "repo_type": "model",
        "filename": repo_filename,
        "revision": version,
        "token": token,
        "local_dir": local_dir,
    }
    return hf_hub_download(**download_kwargs)
103+
104+
105+
def get_or_prompt_hf_token() -> str:
    """
    Either get the Hugging Face token from the environment,
    or prompt the user for it and store it in the environment.

    An unset or empty HUGGING_FACE_TOKEN variable both count as "missing",
    so an empty value triggers the prompt rather than being returned as a
    token.

    Returns:
        str: The Hugging Face token.
    """
    token = os.environ.get("HUGGING_FACE_TOKEN")
    if not token:
        token = getpass(
            "Enter your Hugging Face token (or set HUGGING_FACE_TOKEN environment variable): "
        )
        # Store in env so subsequent calls in the same session skip the prompt
        os.environ["HUGGING_FACE_TOKEN"] = token

    return token

0 commit comments

Comments
 (0)