Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: patch
changes:
added:
- Return PolicyEngine bundle metadata from household calculate responses.
13 changes: 6 additions & 7 deletions policyengine_household_api/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
"policyengine_ng",
"policyengine_il",
)
try:
COUNTRY_PACKAGE_VERSIONS = {
country: version(package_name)
for country, package_name in zip(COUNTRIES, COUNTRY_PACKAGE_NAMES)
}
except:
COUNTRY_PACKAGE_VERSIONS = {country: "0.0.0" for country in COUNTRIES}
COUNTRY_PACKAGE_VERSIONS = {}
for country, package_name in zip(COUNTRIES, COUNTRY_PACKAGE_NAMES):
try:
COUNTRY_PACKAGE_VERSIONS[country] = version(package_name)
except Exception:
COUNTRY_PACKAGE_VERSIONS[country] = "0.0.0"
__version__ = VERSION
27 changes: 14 additions & 13 deletions policyengine_household_api/country.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,12 @@
ParameterScale,
ParameterScaleBracket,
)
from typing import Annotated
from policyengine_core.parameters import get_parameter
from importlib.metadata import version
from policyengine_core.model_api import Reform, Enum
from policyengine_core.periods import instant
import dpath
import math
from uuid import UUID, uuid4
import policyengine_uk
import policyengine_us
import policyengine_canada
import policyengine_ng
import policyengine_il


class PolicyEngineCountry:
Expand All @@ -44,8 +37,16 @@ def __init__(self, country_package_name: str, country_id: str):
self.tax_benefit_system: TaxBenefitSystem = (
self.country_package.CountryTaxBenefitSystem()
)
self.policyengine_bundle = self.build_policyengine_bundle()
self.build_metadata()

def build_policyengine_bundle(self) -> dict:
return {
"model_version": COUNTRY_PACKAGE_VERSIONS[self.country_id],
"data_version": None,
"dataset": None,
}

def build_metadata(self):
self.metadata = dict(
status="ok",
Expand All @@ -65,7 +66,7 @@ def build_metadata(self):
}[self.country_id],
basicInputs=self.tax_benefit_system.basic_inputs,
modelled_policies=self.tax_benefit_system.modelled_policies,
version=version(self.country_package_name),
version=self.policyengine_bundle["model_version"],
),
)

Expand Down Expand Up @@ -314,11 +315,11 @@ def calculate(
system.parameters, parameter_name
)
node_type = type(parameter.values_list[-1].value)
if node_type == int:
if node_type is int:
node_type = float
try:
value = float(value)
except:
except (TypeError, ValueError):
pass
parameter.update(
start=instant(start_instant),
Expand Down Expand Up @@ -373,14 +374,14 @@ def calculate(
entity_index = population.get_index(entity_id)
if variable.value_type == Enum:
entity_result = result.decode()[entity_index].name
elif variable.value_type == float:
elif variable.value_type is float:
entity_result = float(str(result[entity_index]))
# Convert infinities to JSON infinities
if entity_result == float("inf"):
entity_result = "Infinity"
elif entity_result == float("-inf"):
entity_result = "-Infinity"
elif variable.value_type == str:
elif variable.value_type is str:
entity_result = str(result[entity_index])
else:
entity_result = result.tolist()[entity_index]
Expand Down Expand Up @@ -459,7 +460,7 @@ def modify_parameters(parameters: ParameterNode) -> ParameterNode:
for period, value in values.items():
start, end = period.split(".")
node_type = type(node.values_list[-1].value)
if node_type == int:
if node_type is int:
node_type = float # '0' is of type int by default, but usually we want to cast to float.
node.update(
start=instant(start),
Expand Down
2 changes: 1 addition & 1 deletion policyengine_household_api/endpoints/household.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from uuid import UUID
from policyengine_household_api.country import COUNTRIES
from policyengine_household_api.utils.validate_country import validate_country
import json
import logging


Expand Down Expand Up @@ -47,6 +46,7 @@ def get_calculate(country_id: str, add_missing: bool = False) -> Response:
status="ok",
message=None,
result=result,
policyengine_bundle=dict(country.policyengine_bundle),
)

if enable_ai_explainer:
Expand Down
46 changes: 40 additions & 6 deletions policyengine_household_api/openapi_spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,39 @@ servers:
paths:
/:
get:
summary: Get the home page of the PolicyEngine API
summary: Get service metadata for the PolicyEngine household API
operationId: get_home
description: Returns the home page of the PolicyEngine API as an HTML string.
description: Returns service metadata, documentation links, and self-hosting hints.
responses:
200:
description: The home page.
description: Service metadata.
content:
text/html:
application/json:
schema:
type: string
type: object
properties:
status:
type: string
message:
type: string
result:
type: object
properties:
docs_url:
type: string
container_image:
type: string
hosted_calculate_url:
type: string
local_calculate_url:
type: string
health_checks:
type: object
properties:
liveness:
type: string
readiness:
type: string
/{country_id}/metadata:
get:
summary: Get metadata for a country
Expand Down Expand Up @@ -50,6 +73,17 @@ paths:
nullable: true
result:
type: object
policyengine_bundle:
type: object
properties:
model_version:
type: string
data_version:
type: string
nullable: true
dataset:
type: string
nullable: true
properties:
variables:
type: object
Expand Down Expand Up @@ -841,4 +875,4 @@ paths:
paths:
type: object
servers:
type: array
type: array
15 changes: 9 additions & 6 deletions policyengine_household_api/utils/computation_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@
from policyengine_household_api.utils.config_loader import get_config_value


def _get_anthropic_api_key() -> str:
api_key = get_config_value("ai.anthropic.api_key")
if not api_key:
raise ValueError("Anthropic api_key is not configured.")
return api_key


def trigger_streaming_ai_analysis(
prompt: str,
) -> Generator[str, None, None] | None:
Expand All @@ -30,9 +37,7 @@ def trigger_streaming_ai_analysis(
return None

# Configure a Claude client
claude_client = anthropic.Anthropic(
api_key=get_config_value("ai.anthropic.api_key")
)
claude_client = anthropic.Anthropic(api_key=_get_anthropic_api_key())

def generate():
"""
Expand Down Expand Up @@ -76,9 +81,7 @@ def trigger_buffered_ai_analysis(prompt: str) -> str | None:
return None

# Configure a Claude client
claude_client = anthropic.Anthropic(
api_key=get_config_value("ai.anthropic.api_key")
)
claude_client = anthropic.Anthropic(api_key=_get_anthropic_api_key())

# Pass the prompt to Claude for analysis
response: Message = claude_client.messages.create(
Expand Down
10 changes: 9 additions & 1 deletion tests/to_refactor/python/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import requests
import json
import sys
from policyengine_household_api.constants import COUNTRY_PACKAGE_VERSIONS
from policyengine_household_api.utils.config_loader import get_config_value
from tests.to_refactor.fixtures import client, extract_json_from_file

Expand Down Expand Up @@ -33,5 +34,12 @@ def test_calculate_sync(client):
json=input_data,
).get_json()

# Compare the outputs
policyengine_bundle = resLight.pop("policyengine_bundle")

# Compare the legacy response body and assert the new provenance separately.
assert resAPI == resLight
assert policyengine_bundle == {
"model_version": COUNTRY_PACKAGE_VERSIONS[country_id],
"data_version": None,
"dataset": None,
}
29 changes: 29 additions & 0 deletions tests/unit/endpoints/test_household.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import json

from policyengine_household_api.constants import COUNTRY_PACKAGE_VERSIONS
from policyengine_household_api.utils.config_loader import get_config_value
from tests.fixtures.country import (
valid_household_requesting_ctc_calculation,
)


class TestCalculateEndpoint:
auth_headers = {
"Authorization": f"Bearer {get_config_value('auth.auth0.test_token')}",
}

def test_returns_policyengine_bundle(self, client):
response = client.post(
"/us/calculate",
json={"household": valid_household_requesting_ctc_calculation},
headers=self.auth_headers,
)

assert response.status_code == 200

payload = json.loads(response.data)
assert payload["policyengine_bundle"] == {
"model_version": COUNTRY_PACKAGE_VERSIONS["us"],
"data_version": None,
"dataset": None,
}
45 changes: 43 additions & 2 deletions tests/unit/test_country.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
country_package_name_us,
country_id_us,
)
from importlib.metadata import PackageNotFoundError
from policyengine_household_api.country import PolicyEngineCountry
from policyengine_household_api.constants import COUNTRY_PACKAGE_VERSIONS
from uuid import UUID


Expand All @@ -25,7 +27,7 @@ def test_calculate_no_tree(self):
)

# Then a tuple of a valid response and None is returned
assert test_uuid_value == None
assert test_uuid_value is None

def test_calculate_tree_requested(self):

Expand All @@ -39,5 +41,44 @@ def test_calculate_tree_requested(self):
enable_ai_explainer=True,
)

assert type(test_uuid_value) == str
assert isinstance(test_uuid_value, str)
assert UUID(test_uuid_value).version == 4


class TestPolicyEngineBundle:

def test_country_exposes_policyengine_bundle(self):
country = PolicyEngineCountry(country_package_name_us, country_id_us)

assert country.policyengine_bundle == {
"model_version": COUNTRY_PACKAGE_VERSIONS[country_id_us],
"data_version": None,
"dataset": None,
}
assert (
country.metadata["result"]["version"]
== COUNTRY_PACKAGE_VERSIONS[country_id_us]
)


def test_country_package_versions_falls_back_per_package(monkeypatch):
from policyengine_household_api import constants

def _fake_version(package_name: str) -> str:
if package_name == "policyengine_us":
return "1.602.0"
raise PackageNotFoundError(package_name)

monkeypatch.setattr(constants, "version", _fake_version)

versions = {}
for country, package_name in zip(
constants.COUNTRIES, constants.COUNTRY_PACKAGE_NAMES
):
try:
versions[country] = constants.version(package_name)
except Exception:
versions[country] = "0.0.0"

assert versions["us"] == "1.602.0"
assert versions["uk"] == "0.0.0"
Loading