Skip to content

Commit 17a73ce

Browse files
authored
Expose PolicyEngine bundle metadata in calculate responses (#1459)
* Expose policyengine bundle metadata * Expose deterministic household bundle metadata
1 parent 8408d47 commit 17a73ce

9 files changed

Lines changed: 126 additions & 30 deletions

File tree

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- bump: patch
2+
changes:
3+
added:
4+
- Return PolicyEngine bundle metadata from household calculate responses.

policyengine_household_api/constants.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@
1515
"policyengine_ng",
1616
"policyengine_il",
1717
)
18-
try:
19-
COUNTRY_PACKAGE_VERSIONS = {
20-
country: version(package_name)
21-
for country, package_name in zip(COUNTRIES, COUNTRY_PACKAGE_NAMES)
22-
}
23-
except:
24-
COUNTRY_PACKAGE_VERSIONS = {country: "0.0.0" for country in COUNTRIES}
18+
COUNTRY_PACKAGE_VERSIONS = {}
19+
for country, package_name in zip(COUNTRIES, COUNTRY_PACKAGE_NAMES):
20+
try:
21+
COUNTRY_PACKAGE_VERSIONS[country] = version(package_name)
22+
except Exception:
23+
COUNTRY_PACKAGE_VERSIONS[country] = "0.0.0"
2524
__version__ = VERSION

policyengine_household_api/country.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,12 @@
2121
ParameterScale,
2222
ParameterScaleBracket,
2323
)
24-
from typing import Annotated
2524
from policyengine_core.parameters import get_parameter
26-
from importlib.metadata import version
2725
from policyengine_core.model_api import Reform, Enum
2826
from policyengine_core.periods import instant
2927
import dpath
3028
import math
3129
from uuid import UUID, uuid4
32-
import policyengine_uk
33-
import policyengine_us
34-
import policyengine_canada
35-
import policyengine_ng
36-
import policyengine_il
3730

3831

3932
class PolicyEngineCountry:
@@ -44,8 +37,16 @@ def __init__(self, country_package_name: str, country_id: str):
4437
self.tax_benefit_system: TaxBenefitSystem = (
4538
self.country_package.CountryTaxBenefitSystem()
4639
)
40+
self.policyengine_bundle = self.build_policyengine_bundle()
4741
self.build_metadata()
4842

43+
def build_policyengine_bundle(self) -> dict:
44+
return {
45+
"model_version": COUNTRY_PACKAGE_VERSIONS[self.country_id],
46+
"data_version": None,
47+
"dataset": None,
48+
}
49+
4950
def build_metadata(self):
5051
self.metadata = dict(
5152
status="ok",
@@ -65,7 +66,7 @@ def build_metadata(self):
6566
}[self.country_id],
6667
basicInputs=self.tax_benefit_system.basic_inputs,
6768
modelled_policies=self.tax_benefit_system.modelled_policies,
68-
version=version(self.country_package_name),
69+
version=self.policyengine_bundle["model_version"],
6970
),
7071
)
7172

@@ -314,11 +315,11 @@ def calculate(
314315
system.parameters, parameter_name
315316
)
316317
node_type = type(parameter.values_list[-1].value)
317-
if node_type == int:
318+
if node_type is int:
318319
node_type = float
319320
try:
320321
value = float(value)
321-
except:
322+
except (TypeError, ValueError):
322323
pass
323324
parameter.update(
324325
start=instant(start_instant),
@@ -373,14 +374,14 @@ def calculate(
373374
entity_index = population.get_index(entity_id)
374375
if variable.value_type == Enum:
375376
entity_result = result.decode()[entity_index].name
376-
elif variable.value_type == float:
377+
elif variable.value_type is float:
377378
entity_result = float(str(result[entity_index]))
378379
# Convert infinities to JSON infinities
379380
if entity_result == float("inf"):
380381
entity_result = "Infinity"
381382
elif entity_result == float("-inf"):
382383
entity_result = "-Infinity"
383-
elif variable.value_type == str:
384+
elif variable.value_type is str:
384385
entity_result = str(result[entity_index])
385386
else:
386387
entity_result = result.tolist()[entity_index]
@@ -459,7 +460,7 @@ def modify_parameters(parameters: ParameterNode) -> ParameterNode:
459460
for period, value in values.items():
460461
start, end = period.split(".")
461462
node_type = type(node.values_list[-1].value)
462-
if node_type == int:
463+
if node_type is int:
463464
node_type = float # '0' is of type int by default, but usually we want to cast to float.
464465
node.update(
465466
start=instant(start),

policyengine_household_api/endpoints/household.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from uuid import UUID
44
from policyengine_household_api.country import COUNTRIES
55
from policyengine_household_api.utils.validate_country import validate_country
6-
import json
76
import logging
87

98

@@ -47,6 +46,7 @@ def get_calculate(country_id: str, add_missing: bool = False) -> Response:
4746
status="ok",
4847
message=None,
4948
result=result,
49+
policyengine_bundle=dict(country.policyengine_bundle),
5050
)
5151

5252
if enable_ai_explainer:

policyengine_household_api/openapi_spec.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,17 @@ paths:
7373
nullable: true
7474
result:
7575
type: object
76+
policyengine_bundle:
77+
type: object
78+
properties:
79+
model_version:
80+
type: string
81+
data_version:
82+
type: string
83+
nullable: true
84+
dataset:
85+
type: string
86+
nullable: true
7687
properties:
7788
variables:
7889
type: object

policyengine_household_api/utils/computation_tree.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@
1212
from policyengine_household_api.utils.config_loader import get_config_value
1313

1414

15+
def _get_anthropic_api_key() -> str:
16+
api_key = get_config_value("ai.anthropic.api_key")
17+
if not api_key:
18+
raise ValueError("Anthropic api_key is not configured.")
19+
return api_key
20+
21+
1522
def trigger_streaming_ai_analysis(
1623
prompt: str,
1724
) -> Generator[str, None, None] | None:
@@ -30,9 +37,7 @@ def trigger_streaming_ai_analysis(
3037
return None
3138

3239
# Configure a Claude client
33-
claude_client = anthropic.Anthropic(
34-
api_key=get_config_value("ai.anthropic.api_key")
35-
)
40+
claude_client = anthropic.Anthropic(api_key=_get_anthropic_api_key())
3641

3742
def generate():
3843
"""
@@ -76,9 +81,7 @@ def trigger_buffered_ai_analysis(prompt: str) -> str | None:
7681
return None
7782

7883
# Configure a Claude client
79-
claude_client = anthropic.Anthropic(
80-
api_key=get_config_value("ai.anthropic.api_key")
81-
)
84+
claude_client = anthropic.Anthropic(api_key=_get_anthropic_api_key())
8285

8386
# Pass the prompt to Claude for analysis
8487
response: Message = claude_client.messages.create(

tests/to_refactor/python/test_sync.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import requests
33
import json
44
import sys
5+
from policyengine_household_api.constants import COUNTRY_PACKAGE_VERSIONS
56
from policyengine_household_api.utils.config_loader import get_config_value
67
from tests.to_refactor.fixtures import client, extract_json_from_file
78

@@ -33,5 +34,12 @@ def test_calculate_sync(client):
3334
json=input_data,
3435
).get_json()
3536

36-
# Compare the outputs
37+
policyengine_bundle = resLight.pop("policyengine_bundle")
38+
39+
# Compare the legacy response body and assert the new provenance separately.
3740
assert resAPI == resLight
41+
assert policyengine_bundle == {
42+
"model_version": COUNTRY_PACKAGE_VERSIONS[country_id],
43+
"data_version": None,
44+
"dataset": None,
45+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import json
2+
3+
from policyengine_household_api.constants import COUNTRY_PACKAGE_VERSIONS
4+
from policyengine_household_api.utils.config_loader import get_config_value
5+
from tests.fixtures.country import (
6+
valid_household_requesting_ctc_calculation,
7+
)
8+
9+
10+
class TestCalculateEndpoint:
11+
auth_headers = {
12+
"Authorization": f"Bearer {get_config_value('auth.auth0.test_token')}",
13+
}
14+
15+
def test_returns_policyengine_bundle(self, client):
16+
response = client.post(
17+
"/us/calculate",
18+
json={"household": valid_household_requesting_ctc_calculation},
19+
headers=self.auth_headers,
20+
)
21+
22+
assert response.status_code == 200
23+
24+
payload = json.loads(response.data)
25+
assert payload["policyengine_bundle"] == {
26+
"model_version": COUNTRY_PACKAGE_VERSIONS["us"],
27+
"data_version": None,
28+
"dataset": None,
29+
}

tests/unit/test_country.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
country_package_name_us,
44
country_id_us,
55
)
6+
from importlib.metadata import PackageNotFoundError
67
from policyengine_household_api.country import PolicyEngineCountry
8+
from policyengine_household_api.constants import COUNTRY_PACKAGE_VERSIONS
79
from uuid import UUID
810

911

@@ -25,7 +27,7 @@ def test_calculate_no_tree(self):
2527
)
2628

2729
# Then a tuple of a valid response and None is returned
28-
assert test_uuid_value == None
30+
assert test_uuid_value is None
2931

3032
def test_calculate_tree_requested(self):
3133

@@ -39,5 +41,44 @@ def test_calculate_tree_requested(self):
3941
enable_ai_explainer=True,
4042
)
4143

42-
assert type(test_uuid_value) == str
44+
assert isinstance(test_uuid_value, str)
4345
assert UUID(test_uuid_value).version == 4
46+
47+
48+
class TestPolicyEngineBundle:
49+
50+
def test_country_exposes_policyengine_bundle(self):
51+
country = PolicyEngineCountry(country_package_name_us, country_id_us)
52+
53+
assert country.policyengine_bundle == {
54+
"model_version": COUNTRY_PACKAGE_VERSIONS[country_id_us],
55+
"data_version": None,
56+
"dataset": None,
57+
}
58+
assert (
59+
country.metadata["result"]["version"]
60+
== COUNTRY_PACKAGE_VERSIONS[country_id_us]
61+
)
62+
63+
64+
def test_country_package_versions_falls_back_per_package(monkeypatch):
65+
from policyengine_household_api import constants
66+
67+
def _fake_version(package_name: str) -> str:
68+
if package_name == "policyengine_us":
69+
return "1.602.0"
70+
raise PackageNotFoundError(package_name)
71+
72+
monkeypatch.setattr(constants, "version", _fake_version)
73+
74+
versions = {}
75+
for country, package_name in zip(
76+
constants.COUNTRIES, constants.COUNTRY_PACKAGE_NAMES
77+
):
78+
try:
79+
versions[country] = constants.version(package_name)
80+
except Exception:
81+
versions[country] = "0.0.0"
82+
83+
assert versions["us"] == "1.602.0"
84+
assert versions["uk"] == "0.0.0"

0 commit comments

Comments
 (0)