diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 67c82fa..1148345 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -26,10 +26,10 @@ jobs: - uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - - name: Install pylcm + - name: Install pylcm (feature branch — revert to @main once pylcm#350 merges) run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@main" + git+https://github.com/OpenSourceEconomics/pylcm.git@feat/categorical-scalarint" - name: Install aca-model with test deps run: pip install -e . pytest pdbp - name: Run pytest diff --git a/src/aca_model/_benchmark_data/benchmark_params.pkl b/src/aca_model/_benchmark_data/benchmark_params.pkl index d0d9c1d..650c390 100644 Binary files a/src/aca_model/_benchmark_data/benchmark_params.pkl and b/src/aca_model/_benchmark_data/benchmark_params.pkl differ diff --git a/src/aca_model/aca/health_insurance.py b/src/aca_model/aca/health_insurance.py index 1aa4133..f32d54b 100644 --- a/src/aca_model/aca/health_insurance.py +++ b/src/aca_model/aca/health_insurance.py @@ -10,7 +10,14 @@ import jax.numpy as jnp from lcm.params import MappingLeaf -from lcm.typing import BoolND, ContinuousState, DiscreteAction, DiscreteState, FloatND +from lcm.typing import ( + BoolND, + ContinuousState, + DiscreteAction, + DiscreteState, + FloatND, + ScalarFloat, +) from aca_model.baseline.health_insurance import BuyPrivate, oop_costs @@ -136,9 +143,9 @@ def primary_oop( total_health_costs: FloatND, cost_sharing_scale: FloatND, buy_private: DiscreteAction, - deductible: float, - coinsurance_rate: float, - oop_max: float, + deductible: ScalarFloat, + coinsurance_rate: ScalarFloat, + oop_max: ScalarFloat, ) -> FloatND: """Compute primary OOP costs with ACA cost-sharing reductions. diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index 92e9abb..e517fb5 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -9,12 +9,13 @@ ContinuousAction, ContinuousState, FloatND, + ScalarFloat, ) def capital_income( assets: ContinuousState, - rate_of_return: float, + rate_of_return: ScalarFloat, ) -> FloatND: """Compute capital income from assets.""" return assets * rate_of_return @@ -36,7 +37,7 @@ def cash_on_hand( def consumption_dollars_floor( - consumption_equiv_floor: float, + consumption_equiv_floor: ScalarFloat, equivalence_scale: FloatND, ) -> FloatND: """Per-household $-floor on consumption.""" diff --git a/src/aca_model/agent/health.py b/src/aca_model/agent/health.py index e0edf5f..66f04fe 100644 --- a/src/aca_model/agent/health.py +++ b/src/aca_model/agent/health.py @@ -6,28 +6,28 @@ import jax.numpy as jnp from lcm import categorical -from lcm.typing import DiscreteState, FloatND, IntND, Period +from lcm.typing import DiscreteState, FloatND, IntND, Period, ScalarInt @categorical(ordered=True) class HealthWithDisability: - disabled: int - bad: int - good: int + disabled: ScalarInt + bad: ScalarInt + good: ScalarInt @categorical(ordered=True) class Health: - bad: int - good: int + bad: ScalarInt + good: ScalarInt @categorical(ordered=True) class GoodHealth: """Derived categorical for good_health DAG output (0=no, 1=yes).""" - no: int - yes: int + no: ScalarInt + yes: ScalarInt def is_good_health_3(health: DiscreteState) -> IntND: diff --git a/src/aca_model/agent/labor_market.py b/src/aca_model/agent/labor_market.py index 6e36947..b260c43 100644 --- a/src/aca_model/agent/labor_market.py +++ b/src/aca_model/agent/labor_market.py @@ -12,29 +12,31 @@ FloatND, IntND, Period, + ScalarFloat, + ScalarInt, ) @categorical(ordered=True) class LaborSupply: - do_not_work: int - h1000: int - h1500: int - h2000: int - h2500: int + do_not_work: ScalarInt + h1000: ScalarInt + h1500: ScalarInt + h2000: ScalarInt + h2500: ScalarInt @categorical(ordered=False) class LaggedLaborSupply: - did_not_work: int - worked: int + did_not_work: ScalarInt + worked: ScalarInt @categorical(ordered=False) class SpousalIncome: - single: int - married_no_inc: int - married_has_inc: int + single: ScalarInt + married_no_inc: ScalarInt + married_has_inc: ScalarInt HOURS_VALUES = jnp.array([0.0, 1000.0, 1500.0, 2000.0, 2500.0]) @@ -52,8 +54,8 @@ def income( good_health: IntND, log_ft_wage_mean: FloatND, log_ft_wage_std: FloatND, - adj_wage_hours_exp: float, - adj_wage_hours_int: float, + adj_wage_hours_exp: ScalarFloat, + adj_wage_hours_int: ScalarFloat, ) -> FloatND: """Labor income with wage-hours interaction (French & Jones 2011). @@ -88,8 +90,8 @@ def next_lagged_supply(labor_supply: DiscreteAction) -> DiscreteState: class IsMarried: """Derived categorical for is_married DAG output (0=no, 1=yes).""" - no: int - yes: int + no: ScalarInt + yes: ScalarInt def is_married(spousal_income: DiscreteState) -> IntND: diff --git a/src/aca_model/agent/preferences.py b/src/aca_model/agent/preferences.py index 5c08541..612896b 100644 --- a/src/aca_model/agent/preferences.py +++ b/src/aca_model/agent/preferences.py @@ -6,12 +6,15 @@ import jax.numpy as jnp from lcm import categorical from lcm.typing import ( + Age, BoolND, ContinuousAction, ContinuousState, DiscreteState, FloatND, IntND, + ScalarFloat, + ScalarInt, ) from aca_model.agent.labor_market import LaggedLaborSupply @@ -21,9 +24,9 @@ class PrefType: """Unobserved preference type for heterogeneity in estimation.""" - type_0: int - type_1: int - type_2: int + type_0: ScalarInt + type_1: ScalarInt + type_2: ScalarInt @categorical(ordered=False) @@ -37,8 +40,8 @@ class BenchmarkPrefType: measured. """ - type_0: int - type_1: int + type_0: ScalarInt + type_1: ScalarInt def positive_leisure(leisure: FloatND) -> BoolND: @@ -46,7 +49,7 @@ def positive_leisure(leisure: FloatND) -> BoolND: return leisure > 0 -def equivalence_scale(is_married: IntND, exponent: float) -> FloatND: +def equivalence_scale(is_married: IntND, exponent: ScalarFloat) -> FloatND: """Return the equivalence scale for household size adjustment. Single (is_married=False) → 1.0, married (is_married=True) → 2^exponent. @@ -54,69 +57,70 @@ def equivalence_scale(is_married: IntND, exponent: float) -> FloatND: return jnp.where(is_married, 2.0**exponent, 1.0) -def leisure( +def fixed_cost_of_work( + age: Age, + fixed_cost_of_work_intercept: ScalarFloat, + fixed_cost_of_work_age_trend: ScalarFloat, + reference_age: ScalarInt, +) -> ScalarFloat: + """Age-dependent fixed cost of working (intercept + trend slope on age).""" + return fixed_cost_of_work_intercept + fixed_cost_of_work_age_trend * ( + age - reference_age + ) + + +def leisure_canwork_retiree_or_nongroup( working_hours_value: FloatND, - age: int, good_health: IntND, lagged_labor_supply: DiscreteState, - time_endowment: float, - leisure_cost_of_bad_health: float, - fixed_cost_of_work_intercept: float, - fixed_cost_of_work_age_trend: float, - labor_force_reentry_cost: float, - reference_age: int, + time_endowment: ScalarFloat, + leisure_cost_of_bad_health: ScalarFloat, + fixed_cost_of_work: ScalarFloat, + labor_force_reentry_cost: ScalarFloat, ) -> FloatND: - """Compute leisure given hours worked and state variables. + """Compute leisure for canwork retiree / nongroup regimes. - Fixed cost of work is age-dependent: intercept + trend * (age - reference_age). Reentry cost applies when returning to work after not working last period. - Working status is derived from working_hours_value > 0. """ - is_working = working_hours_value > 0.0 health_loss = jnp.where(good_health, 0.0, leisure_cost_of_bad_health) - - fixed_cost = fixed_cost_of_work_intercept + fixed_cost_of_work_age_trend * ( - age - reference_age - ) reentry_cost = jnp.where( lagged_labor_supply == LaggedLaborSupply.did_not_work, labor_force_reentry_cost, 0.0, ) work_loss = jnp.where( - is_working, working_hours_value + fixed_cost + reentry_cost, 0.0 + working_hours_value > 0.0, + working_hours_value + fixed_cost_of_work + reentry_cost, + 0.0, ) return time_endowment - health_loss - work_loss -def leisure_tied( +def leisure_canwork_tied( working_hours_value: FloatND, - age: int, good_health: IntND, - time_endowment: float, - leisure_cost_of_bad_health: float, - fixed_cost_of_work_intercept: float, - fixed_cost_of_work_age_trend: float, - reference_age: int, + time_endowment: ScalarFloat, + leisure_cost_of_bad_health: ScalarFloat, + fixed_cost_of_work: ScalarFloat, ) -> FloatND: - """Compute leisure for tied regimes (no reentry cost, no lagged_labor_supply).""" + """Compute leisure for canwork tied regimes. + + No need to consider reentry costs. + """ health_loss = jnp.where(good_health, 0.0, leisure_cost_of_bad_health) - fixed_cost = fixed_cost_of_work_intercept + fixed_cost_of_work_age_trend * ( - age - reference_age - ) work_loss = jnp.where( - working_hours_value > 0.0, working_hours_value + fixed_cost, 0.0 + working_hours_value > 0.0, working_hours_value + fixed_cost_of_work, 0.0 ) return time_endowment - health_loss - work_loss -def leisure_retired( +def leisure_forcedout( good_health: IntND, - time_endowment: float, - leisure_cost_of_bad_health: float, + time_endowment: ScalarFloat, + leisure_cost_of_bad_health: ScalarFloat, ) -> FloatND: - """Compute leisure for retired agents (no work).""" + """Compute leisure for forcedout regimes (no work).""" health_loss = jnp.where(good_health, 0.0, leisure_cost_of_bad_health) return time_endowment - health_loss @@ -129,14 +133,18 @@ def consumption_equiv( return consumption_dollars / equivalence_scale -def u_can_work( +def u_alive( consumption_equiv: FloatND, leisure: FloatND, consumption_weight: FloatND, coefficient_rra: FloatND, utility_scale_factor: FloatND, ) -> FloatND: - """Within-period utility for canwork regimes: CES over consumption and leisure.""" + """Within-period utility for every non-dead regime: CES over consumption and leisure. + + `leisure` is a DAG input — supplied per-regime by `leisure_canwork_retiree_or_nongroup`, + `leisure_canwork_tied`, or `leisure_forcedout`. + """ composite = consumption_equiv**consumption_weight * leisure ** ( 1.0 - consumption_weight ) @@ -152,49 +160,6 @@ def u_can_work( return u * utility_scale_factor -def u_cannot_work( - consumption_equiv: FloatND, - good_health: IntND, - consumption_weight: FloatND, - coefficient_rra: FloatND, - utility_scale_factor: FloatND, - time_endowment: float, - leisure_cost_of_bad_health: float, -) -> FloatND: - """Within-period utility for forcedout regimes (no work, retired leisure).""" - leisure = leisure_retired( - good_health=good_health, - time_endowment=time_endowment, - leisure_cost_of_bad_health=leisure_cost_of_bad_health, - ) - return u_can_work( - consumption_equiv=consumption_equiv, - leisure=leisure, - consumption_weight=consumption_weight, - coefficient_rra=coefficient_rra, - utility_scale_factor=utility_scale_factor, - ) - - -def u_dead( - assets: ContinuousState, - bequest_shifter: float, - scaled_bequest_weight: float, - consumption_weight: FloatND, - coefficient_rra: FloatND, - utility_scale_factor: FloatND, -) -> FloatND: - """Terminal bequest utility for the dead regime.""" - return bequest( - assets=assets, - bequest_shifter=bequest_shifter, - scaled_bequest_weight=scaled_bequest_weight, - consumption_weight=consumption_weight, - coefficient_rra=coefficient_rra, - utility_scale_factor=utility_scale_factor, - ) - - def consumption_weight( consumption_weights: FloatND, pref_type: DiscreteState, @@ -233,16 +198,16 @@ def discount_factor( def utility_scale_factor( - average_consumption_dollars: float, + average_consumption_equiv: ScalarFloat, consumption_weight: FloatND, coefficient_rra: FloatND, - time_endowment: float, - fixed_cost_of_work_intercept: float, - reference_hours: float, + time_endowment: ScalarFloat, + fixed_cost_of_work_intercept: ScalarFloat, + reference_hours: ScalarFloat, ) -> FloatND: """Compute the scale factor so utility is approximately 1 at typical values.""" average_leisure = time_endowment - reference_hours - fixed_cost_of_work_intercept - u_cons = average_consumption_dollars**consumption_weight + u_cons = average_consumption_equiv**consumption_weight u_leisure = average_leisure ** (1.0 - consumption_weight) one_minus_rra = jnp.where( @@ -257,12 +222,12 @@ def utility_scale_factor( def scaled_bequest_weight( - bequest_weight: float, - consumption_weight: float, - coefficient_rra: float, - time_endowment: float, - time_discount_factor: float, - rate_of_return: float, + bequest_weight: ScalarFloat, + consumption_weight: ScalarFloat, + coefficient_rra: ScalarFloat, + time_endowment: ScalarFloat, + time_discount_factor: ScalarFloat, + rate_of_return: ScalarFloat, ) -> FloatND: """Transform raw bequest weight into the form used in the bequest function. @@ -283,8 +248,8 @@ def scaled_bequest_weight( def bequest( assets: ContinuousState, - bequest_shifter: float, - scaled_bequest_weight: float, + bequest_shifter: ScalarFloat, + scaled_bequest_weight: ScalarFloat, consumption_weight: FloatND, coefficient_rra: FloatND, utility_scale_factor: FloatND, diff --git a/src/aca_model/baseline/health_insurance.py b/src/aca_model/baseline/health_insurance.py index 3732d6d..d0e9322 100644 --- a/src/aca_model/baseline/health_insurance.py +++ b/src/aca_model/baseline/health_insurance.py @@ -15,6 +15,7 @@ import jax.numpy as jnp from lcm import categorical from lcm.typing import ( + Age, BoolND, ContinuousState, DiscreteAction, @@ -22,6 +23,9 @@ FloatND, IntND, Period, + ScalarBool, + ScalarFloat, + ScalarInt, ) from aca_model.agent.labor_market import LaborSupply @@ -29,15 +33,15 @@ @categorical(ordered=False) class BuyPrivate: - no: int - yes: int + no: ScalarInt + yes: ScalarInt @categorical(ordered=False) class HealthInsuranceState: - retiree: int - tied: int - nongroup: int + retiree: ScalarInt + tied: ScalarInt + nongroup: ScalarInt def countable_income( @@ -47,8 +51,8 @@ def countable_income( spousal_income_amounts: FloatND, ss_benefit: FloatND, pension_benefit: FloatND, - ssi_ignored_overall: float, - ssi_ignored_earned: float, + ssi_ignored_overall: ScalarFloat, + ssi_ignored_earned: ScalarFloat, ) -> FloatND: """Compute countable income for SSI eligibility test. @@ -69,7 +73,7 @@ def is_ssi_eligible( assets: ContinuousState, countable_income: FloatND, spousal_income: DiscreteState, - gets_medicare: bool, + gets_medicare: ScalarBool, ssi_assets_test: FloatND, ssi_maximum_benefit: FloatND, ) -> BoolND: @@ -99,21 +103,21 @@ def ssi_benefit( def premium( - age: int, + age: Age, good_health: IntND, is_married: IntND, labor_supply: DiscreteAction, buy_private: DiscreteAction, - premium_intercept: float, - premium_age: int, - premium_age_sq: float, - premium_age_cub: float, - premium_predicted_hcc: float, - premium_good_health: float, - premium_married: float, - premium_works: float, - premium_married_works: float, - premium_minimum: float, + premium_intercept: ScalarFloat, + premium_age: ScalarFloat, + premium_age_sq: ScalarFloat, + premium_age_cub: ScalarFloat, + premium_predicted_hcc: ScalarFloat, + premium_good_health: ScalarFloat, + premium_married: ScalarFloat, + premium_works: ScalarFloat, + premium_married_works: ScalarFloat, + premium_minimum: ScalarFloat, predicted_hcc_insurer: FloatND, ) -> FloatND: """Compute health insurance premium for canwork regimes. @@ -141,20 +145,20 @@ def premium( def premium_insured( - age: int, + age: Age, good_health: IntND, is_married: IntND, labor_supply: DiscreteAction, - premium_intercept: float, - premium_age: int, - premium_age_sq: float, - premium_age_cub: float, - premium_predicted_hcc: float, - premium_good_health: float, - premium_married: float, - premium_works: float, - premium_married_works: float, - premium_minimum: float, + premium_intercept: ScalarFloat, + premium_age: ScalarFloat, + premium_age_sq: ScalarFloat, + premium_age_cub: ScalarFloat, + premium_predicted_hcc: ScalarFloat, + premium_good_health: ScalarFloat, + premium_married: ScalarFloat, + premium_works: ScalarFloat, + premium_married_works: ScalarFloat, + premium_minimum: ScalarFloat, predicted_hcc_insurer: FloatND, ) -> FloatND: """Compute health insurance premium for canwork regimes without `buy_private`. @@ -178,17 +182,17 @@ def premium_insured( def premium_retired( - age: int, + age: Age, good_health: IntND, is_married: IntND, - premium_intercept: float, - premium_age: int, - premium_age_sq: float, - premium_age_cub: float, - premium_predicted_hcc: float, - premium_good_health: float, - premium_married: float, - premium_minimum: float, + premium_intercept: ScalarFloat, + premium_age: ScalarFloat, + premium_age_sq: ScalarFloat, + premium_age_cub: ScalarFloat, + premium_predicted_hcc: ScalarFloat, + premium_good_health: ScalarFloat, + premium_married: ScalarFloat, + premium_minimum: ScalarFloat, predicted_hcc_insurer: FloatND, ) -> FloatND: """Compute health insurance premium for forcedout regimes. @@ -209,9 +213,9 @@ def premium_retired( def oop_costs( total_health_costs: FloatND, - deductible: float | FloatND, - coinsurance_rate: float | FloatND, - oop_max: float | FloatND, + deductible: ScalarFloat | FloatND, + coinsurance_rate: ScalarFloat | FloatND, + oop_max: ScalarFloat | FloatND, ) -> FloatND: """Compute out-of-pocket health care costs. @@ -228,9 +232,9 @@ def oop_costs( def primary_oop( total_health_costs: FloatND, buy_private: DiscreteAction, - deductible: float, - coinsurance_rate: float, - oop_max: float, + deductible: ScalarFloat, + coinsurance_rate: ScalarFloat, + oop_max: ScalarFloat, ) -> FloatND: """Compute primary OOP costs. @@ -272,9 +276,9 @@ def target_his( def oop_with_medicaid( primary_oop: FloatND, is_medicaid_eligible: BoolND, - deductible_medicaid: float, - coinsurance_rate_medicaid: float, - oop_max_medicaid: float, + deductible_medicaid: ScalarFloat, + coinsurance_rate_medicaid: ScalarFloat, + oop_max_medicaid: ScalarFloat, ) -> FloatND: """Apply Medicaid cost-sharing on top of primary insurance OOP costs. @@ -352,7 +356,7 @@ def total_costs( log_std: FloatND, hcc_persistent: ContinuousState, hcc_transitory: ContinuousState, - std_xsect_persistent: float, + std_xsect_persistent: ScalarFloat, ) -> FloatND: """Compute total health care costs from log-normal model. diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index a2e3a13..836ff71 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -23,7 +23,7 @@ ) from lcm.grids.continuous import ContinuousGrid from lcm.grids.piecewise import Piece, PiecewiseLinSpacedGrid -from lcm.typing import BoolND, FloatND, RegimeName, UserParams +from lcm.typing import BoolND, FloatND, RegimeName, ScalarInt, UserParams from aca_model.agent import ( assets_and_income, @@ -42,25 +42,25 @@ @categorical(ordered=False) class RegimeId: - retiree_nomc_inelig_canwork: int - tied_nomc_inelig_canwork: int - nongroup_nomc_inelig_canwork: int - retiree_dimc_inelig_canwork: int - nongroup_dimc_inelig_canwork: int - retiree_nomc_choose_canwork: int - tied_nomc_choose_canwork: int - nongroup_nomc_choose_canwork: int - retiree_dimc_choose_canwork: int - nongroup_dimc_choose_canwork: int - retiree_oamc_choose_canwork: int - tied_oamc_choose_canwork: int - nongroup_oamc_choose_canwork: int - retiree_oamc_forced_canwork: int - tied_oamc_forced_canwork: int - nongroup_oamc_forced_canwork: int - retiree_oamc_forced_forcedout: int - nongroup_oamc_forced_forcedout: int - dead: int + retiree_nomc_inelig_canwork: ScalarInt + tied_nomc_inelig_canwork: ScalarInt + nongroup_nomc_inelig_canwork: ScalarInt + retiree_dimc_inelig_canwork: ScalarInt + nongroup_dimc_inelig_canwork: ScalarInt + retiree_nomc_choose_canwork: ScalarInt + tied_nomc_choose_canwork: ScalarInt + nongroup_nomc_choose_canwork: ScalarInt + retiree_dimc_choose_canwork: ScalarInt + nongroup_dimc_choose_canwork: ScalarInt + retiree_oamc_choose_canwork: ScalarInt + tied_oamc_choose_canwork: ScalarInt + nongroup_oamc_choose_canwork: ScalarInt + retiree_oamc_forced_canwork: ScalarInt + tied_oamc_forced_canwork: ScalarInt + nongroup_oamc_forced_canwork: ScalarInt + retiree_oamc_forced_forcedout: ScalarInt + nongroup_oamc_forced_forcedout: ScalarInt + dead: ScalarInt class RegimeSpec(TypedDict): @@ -213,6 +213,13 @@ class Grids: can vary across optimizer iterations without re-importing this module). """ +# AR(1) persistence of the Rouwenhorst shocks. Calibrated once; not +# routed through fixed_params because they shape the grid topology +# rather than feed any DAG function. The Rouwenhorst innovation std is +# `sqrt(1 - rho**2)` so the grid carries unit unconditional variance. +_HCC_RHO = 0.925 +_WAGE_RHO = 0.977 + def build_grids( *, @@ -240,7 +247,6 @@ def build_grids( # grid to have unconditional variance 1, the Rouwenhorst innovation # std must be √(1 − ρ²). Passing the σ_y itself (≈0.577 for hcc, # 0.5627 for wage) would mis-scale the grid. - _WAGE_RHO = 0.977 wage_res = lcm.shocks.ar1.Rouwenhorst( n_points=grid_config.n_wage_res_gridpoints, rho=_WAGE_RHO, @@ -277,9 +283,6 @@ def build_grids( ) -_HCC_RHO = 0.925 - - def get_hcc_persistent_shock(*, grid_config: GridConfig) -> lcm.shocks.ar1.Rouwenhorst: """Return the persistent-HCC AR(1) shock grid for a given `grid_config`. @@ -442,7 +445,7 @@ def build_dead_regime(grids: Grids) -> Regime: return Regime( transition=None, functions={ - "utility": preferences.u_dead, + "utility": preferences.bequest, "consumption_weight": preferences.consumption_weight, "coefficient_rra": preferences.coefficient_rra, "utility_scale_factor": preferences.utility_scale_factor, @@ -468,18 +471,13 @@ def select_ss_benefit(spec: RegimeSpec) -> Callable[..., Any]: return social_security.benefit_inelig_pre65 -def select_utility(spec: RegimeSpec) -> Callable[..., Any]: - """Select the utility function for a regime.""" - if spec["canwork"] != "canwork": - return preferences.u_cannot_work - return preferences.u_can_work - - def _select_leisure(spec: RegimeSpec) -> Callable[..., Any]: - """Select the leisure function for a canwork regime.""" + """Select the leisure function for a non-dead regime.""" + if spec["canwork"] == "forcedout": + return preferences.leisure_forcedout if spec["his"] == "tied": - return preferences.leisure_tied - return preferences.leisure + return preferences.leisure_canwork_tied + return preferences.leisure_canwork_retiree_or_nongroup def build_common_functions(spec: RegimeSpec) -> dict: @@ -503,9 +501,11 @@ def build_common_functions(spec: RegimeSpec) -> dict: if can_work: functions["working_hours_value"] = labor_market.working_hours_value - functions["leisure"] = _select_leisure(spec) functions["labor_income"] = labor_market.income + functions["fixed_cost_of_work"] = preferences.fixed_cost_of_work + functions["leisure"] = _select_leisure(spec) + functions["utility"] = preferences.u_alive functions["capital_income"] = assets_and_income.capital_income # spousal_income_amounts is a lookup table param, not a DAG function functions["is_married"] = labor_market.is_married @@ -552,7 +552,12 @@ def build_common_functions(spec: RegimeSpec) -> dict: def precompute_target_regimes(spec: RegimeSpec) -> MappingProxyType[str, int]: - """Pre-compute target regime IDs for each next-age bracket.""" + """Pre-compute target regime IDs for each next-age bracket. + + Coerces each `RegimeId.` (`ScalarInt`, post-pylcm#349) to a + Python `int` so the returned mapping's values can serve as dict + keys and `in`-set members downstream. + """ def _resolve(his_val: str, mc_val: str, ss_val: str, canwork_val: str) -> int: for name, s in REGIME_SPECS.items(): @@ -562,8 +567,8 @@ def _resolve(his_val: str, mc_val: str, ss_val: str, canwork_val: str) -> int: and s["ss"] == ss_val and s["canwork"] == canwork_val ): - return getattr(RegimeId, name) - return RegimeId.dead + return int(getattr(RegimeId, name)) + return int(RegimeId.dead) ng_his = "nongroup" if spec["his"] == "tied" else spec["his"] @@ -674,7 +679,7 @@ def _build_per_target_regime_assets( targets use the full `next_assets` with the pension correction. """ target_regimes = precompute_target_regimes(spec) - id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} + id_to_name = {int(getattr(RegimeId, name)): name for name in REGIME_SPECS} result: dict[RegimeName, Callable[..., FloatND]] = {} seen_ids: set[int] = set() @@ -701,7 +706,7 @@ def _build_per_target_regime_health( Cross-grid transitions (3->2) happen at the age-65 boundary. """ target_regimes = precompute_target_regimes(spec) - id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} + id_to_name = {int(getattr(RegimeId, name)): name for name in REGIME_SPECS} result: dict[RegimeName, MarkovTransition] = {} seen_ids: set[int] = set() @@ -737,7 +742,7 @@ def _build_per_target_regime_claimed_ss( return {} target_regimes = precompute_target_regimes(spec) - id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} + id_to_name = {int(getattr(RegimeId, name)): name for name in REGIME_SPECS} result: dict[RegimeName, Callable[..., BoolND]] = {} seen_ids: set[int] = set() @@ -778,7 +783,7 @@ def _build_per_target_regime_lagged_labor_supply( return {} target_regimes = precompute_target_regimes(spec) - id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} + id_to_name = {int(getattr(RegimeId, name)): name for name in REGIME_SPECS} result: dict[RegimeName, Callable[..., BoolND]] = {} seen_ids: set[int] = set() diff --git a/src/aca_model/baseline/regimes/_nongroup.py b/src/aca_model/baseline/regimes/_nongroup.py index a723b44..730dcc4 100644 --- a/src/aca_model/baseline/regimes/_nongroup.py +++ b/src/aca_model/baseline/regimes/_nongroup.py @@ -7,7 +7,7 @@ from collections.abc import Callable from lcm import MarkovTransition, Regime -from lcm.typing import DiscreteAction, FloatND, Period +from lcm.typing import Age, DiscreteAction, FloatND, Period from aca_model.agent import assets_and_income, preferences from aca_model.agent.labor_market import LaborSupply @@ -25,7 +25,6 @@ make_targets, select_ss_benefit, select_target_for_age, - select_utility, ) from aca_model.environment import pensions @@ -41,7 +40,7 @@ def _make_transition_canwork( """ def transition( - age: int, + age: Age, period: Period, labor_supply: DiscreteAction, survival_probs: FloatND, @@ -65,7 +64,7 @@ def _make_transition_forcedout( """ def transition( - age: int, + age: Age, period: Period, survival_probs: FloatND, ) -> FloatND: @@ -80,7 +79,6 @@ def _build_functions(spec: RegimeSpec) -> dict: can_work = spec["canwork"] == "canwork" functions = build_common_functions(spec) - functions["utility"] = select_utility(spec) functions["ss_benefit"] = select_ss_benefit(spec) # his and gets_medicare are fixed params (constants per regime), diff --git a/src/aca_model/baseline/regimes/_retiree.py b/src/aca_model/baseline/regimes/_retiree.py index 4f16faa..4cb52d9 100644 --- a/src/aca_model/baseline/regimes/_retiree.py +++ b/src/aca_model/baseline/regimes/_retiree.py @@ -8,7 +8,7 @@ import jax.numpy as jnp from lcm import MarkovTransition, Regime -from lcm.typing import BoolND, DiscreteAction, FloatND, Period +from lcm.typing import Age, BoolND, DiscreteAction, FloatND, Period from aca_model.agent import assets_and_income, preferences from aca_model.agent.labor_market import LaborSupply @@ -26,7 +26,6 @@ make_targets, select_ss_benefit, select_target_for_age, - select_utility, ) from aca_model.environment import pensions @@ -43,7 +42,7 @@ def _make_transition_canwork( """ def transition( - age: int, + age: Age, period: Period, labor_supply: DiscreteAction, is_medicaid_eligible: BoolND, @@ -72,7 +71,7 @@ def _make_transition_forcedout( """ def transition( - age: int, + age: Age, period: Period, is_medicaid_eligible: BoolND, survival_probs: FloatND, @@ -92,7 +91,6 @@ def _build_functions(spec: RegimeSpec) -> dict: can_work = spec["canwork"] == "canwork" functions = build_common_functions(spec) - functions["utility"] = select_utility(spec) functions["ss_benefit"] = select_ss_benefit(spec) # his and gets_medicare are fixed params (constants per regime), diff --git a/src/aca_model/baseline/regimes/_tied.py b/src/aca_model/baseline/regimes/_tied.py index df76fa4..c9eeecc 100644 --- a/src/aca_model/baseline/regimes/_tied.py +++ b/src/aca_model/baseline/regimes/_tied.py @@ -9,7 +9,7 @@ import jax.numpy as jnp from lcm import MarkovTransition, Regime -from lcm.typing import BoolND, DiscreteAction, FloatND, Period +from lcm.typing import Age, BoolND, DiscreteAction, FloatND, Period from aca_model.agent import assets_and_income, preferences from aca_model.agent.labor_market import LaborSupply @@ -27,7 +27,6 @@ make_targets, select_ss_benefit, select_target_for_age, - select_utility, ) from aca_model.environment import pensions @@ -44,7 +43,7 @@ def _make_transition_canwork( """ def transition( - age: int, + age: Age, period: Period, labor_supply: DiscreteAction, is_medicaid_eligible: BoolND, @@ -70,7 +69,6 @@ def _build_functions(spec: RegimeSpec) -> dict: """Build functions dict for a tied regime.""" functions = build_common_functions(spec) - functions["utility"] = select_utility(spec) functions["ss_benefit"] = select_ss_benefit(spec) # his and gets_medicare are fixed params (constants per regime), diff --git a/src/aca_model/consumption_dollars_grid.py b/src/aca_model/consumption_dollars_grid.py index 7487fd8..5de175d 100644 --- a/src/aca_model/consumption_dollars_grid.py +++ b/src/aca_model/consumption_dollars_grid.py @@ -39,7 +39,7 @@ def inject_consumption_dollars_points( Walks every regime, reads its `consumption_dollars` action grid, and writes `params[regime_name]["consumption_dollars"] = {"points": }`. - The lower two gridpoints are the single and married unequiv + The lower two gridpoints are the single and married Dollar-valued transfer floors; the rest are geomspaced from the married floor up to `MAX_CONSUMPTION_DOLLARS`. @@ -101,7 +101,7 @@ def _compute_consumption_dollars_points( ) -> Array: """Return log-spaced consumption_dollars gridpoints with both floors pinned. - Single and married households face different unequiv (in-$) floors + Single and married households face different Dollar-valued floors (`consumption_equiv_floor` and the married-scaled twin respectively). Both must land exactly on the action grid so the borrowing constraint's `max(cash_on_hand, floor)` kink boundary is @@ -111,14 +111,26 @@ def _compute_consumption_dollars_points( `MAX_CONSUMPTION_DOLLARS` so the two pinned points stay strictly increasing. """ - married_unequiv_floor = consumption_equiv_floor * jnp.asarray(2.0) ** exponent + married_dollar_floor = consumption_equiv_floor * jnp.asarray(2.0) ** exponent tail = jnp.geomspace( - married_unequiv_floor, MAX_CONSUMPTION_DOLLARS, num=n_points - 1 + married_dollar_floor, MAX_CONSUMPTION_DOLLARS, num=n_points - 1 ) pts = jnp.concatenate([consumption_equiv_floor[None], tail]) # `jnp.geomspace` returns `start * r^0` for the first tail element, - # which mathematically equals `married_unequiv_floor` but drifts by + # which mathematically equals `married_dollar_floor` but drifts by # sub-ULP on some XLA backends. Pin the slot back to the exact # arithmetic value so the borrowing-constraint kink boundary at the # married floor is exactly representable. - return pts.at[1].set(married_unequiv_floor) + pts = pts.at[1].set(married_dollar_floor) + # The runtime params are concrete, not JIT-traced — a Python `if` + # is fine. Guard against a degenerate grid where the geomspace step + # is too small for the next point to clear `married_dollar_floor`. + if not float(married_dollar_floor) < float(pts[2]): + msg = ( + f"consumption_dollars grid is not strictly increasing at the " + f"married-floor kink: pts[1]={float(married_dollar_floor):.6g}, " + f"pts[2]={float(pts[2]):.6g}. Either `MAX_CONSUMPTION_DOLLARS` " + f"is too close to the married floor or `n_points` is too small." + ) + raise ValueError(msg) + return pts diff --git a/src/aca_model/environment/pensions.py b/src/aca_model/environment/pensions.py index eef72d4..cb03a6c 100644 --- a/src/aca_model/environment/pensions.py +++ b/src/aca_model/environment/pensions.py @@ -4,7 +4,7 @@ """ import jax.numpy as jnp -from lcm.typing import ContinuousState, FloatND, IntND, Period +from lcm.typing import ContinuousState, FloatND, IntND, Period, ScalarFloat def benefit( @@ -131,7 +131,7 @@ def wealth_next_before_adjustment( pension_wealth: FloatND, pension_benefit: FloatND, pension_accrual: FloatND, - rate_of_return: float, + rate_of_return: ScalarFloat, unconditional_survival_prob: FloatND, period: Period, ) -> FloatND: diff --git a/src/aca_model/environment/social_security.py b/src/aca_model/environment/social_security.py index e3574cf..8b655d1 100644 --- a/src/aca_model/environment/social_security.py +++ b/src/aca_model/environment/social_security.py @@ -9,15 +9,24 @@ import jax.numpy as jnp from lcm import categorical -from lcm.typing import ContinuousState, DiscreteAction, DiscreteState, FloatND, Period +from lcm.typing import ( + Age, + ContinuousState, + DiscreteAction, + DiscreteState, + FloatND, + Period, + ScalarFloat, + ScalarInt, +) from aca_model.agent.labor_market import LaborSupply @categorical(ordered=False) class ClaimedSS: - no: int - yes: int + no: ScalarInt + yes: ScalarInt def next_claimed_ss( @@ -77,17 +86,17 @@ def benefit_forced( def benefit_choose_post65( pia: FloatND, - age: int, + age: Age, period: Period, claim_ss: DiscreteAction, claimed_ss: DiscreteState, labor_supply: DiscreteAction, labor_income: FloatND, early_ret_adjustment: FloatND, - normal_retirement_age: int, + normal_retirement_age: ScalarInt, earnings_test_threshold: FloatND, earnings_test_fraction: FloatND, - earnings_test_repealed_age: int, + earnings_test_repealed_age: ScalarInt, ) -> FloatND: """SS benefit for post-65, ss=choose: SS if claiming, 0 otherwise.""" ss = jnp.maximum(claim_ss, claimed_ss) @@ -110,7 +119,7 @@ def benefit_choose_post65( def benefit_choose_pre65( pia: FloatND, ssdi_pia: FloatND, - age: int, + age: Age, period: Period, claim_ss: DiscreteAction, claimed_ss: DiscreteState, @@ -118,11 +127,11 @@ def benefit_choose_pre65( labor_supply: DiscreteAction, labor_income: FloatND, early_ret_adjustment: FloatND, - normal_retirement_age: int, + normal_retirement_age: ScalarInt, earnings_test_threshold: FloatND, earnings_test_fraction: FloatND, - earnings_test_repealed_age: int, - ssdi_substantial_gainful_activity: float, + earnings_test_repealed_age: ScalarInt, + ssdi_substantial_gainful_activity: ScalarFloat, ) -> FloatND: """SS benefit for pre-65, ss=choose: SS if claiming, SSDI if disabled, else 0.""" ss = jnp.maximum(claim_ss, claimed_ss) @@ -160,7 +169,7 @@ def benefit_inelig_pre65( ssdi_pia: FloatND, health: DiscreteState, labor_income: FloatND, - ssdi_substantial_gainful_activity: float, + ssdi_substantial_gainful_activity: ScalarFloat, ) -> FloatND: """SS benefit for pre-65, ss=inelig: SSDI if disabled, else 0.""" is_disabled = health == 0 @@ -200,16 +209,16 @@ def benefit_withheld_fraction( def _apply_benefit_rules( *, pia: FloatND, - age: int, + age: Age, period: Period, ss: FloatND, work: FloatND, labor_income: FloatND, early_ret_adjustment: FloatND, - normal_retirement_age: int, + normal_retirement_age: ScalarInt, earnings_test_threshold: FloatND, earnings_test_fraction: FloatND, - earnings_test_repealed_age: int, + earnings_test_repealed_age: ScalarInt, ) -> FloatND: """Apply early retirement adjustment and earnings test to PIA. @@ -246,16 +255,16 @@ def next_aime( aime: ContinuousState, labor_income: FloatND, period: Period, - age: int, + age: Age, benefit_withheld_fraction: FloatND, earnings_test_credited_back: FloatND, - earnings_test_repealed_age: int, + earnings_test_repealed_age: ScalarInt, pia_table: FloatND, pia_aime_grid: FloatND, - aime_accrual_factor: float, - aggregate_wage_growth: float, - aime_last_age_with_indexing: int, - aime_kink_2: float, + aime_accrual_factor: ScalarFloat, + aggregate_wage_growth: ScalarFloat, + aime_last_age_with_indexing: ScalarInt, + aime_kink_2: ScalarFloat, ratio_lowest_earnings: FloatND, ) -> ContinuousState: """Compute next period's AIME given labor earnings. @@ -306,19 +315,19 @@ def next_aime_disabled( aime: ContinuousState, labor_income: FloatND, period: Period, - age: int, + age: Age, health: DiscreteState, benefit_withheld_fraction: FloatND, earnings_test_credited_back: FloatND, - earnings_test_repealed_age: int, + earnings_test_repealed_age: ScalarInt, pia_table: FloatND, pia_aime_grid: FloatND, - aime_accrual_factor: float, - aggregate_wage_growth: float, - aime_last_age_with_indexing: int, - aime_kink_2: float, + aime_accrual_factor: ScalarFloat, + aggregate_wage_growth: ScalarFloat, + aime_last_age_with_indexing: ScalarInt, + aime_kink_2: ScalarFloat, ratio_lowest_earnings: FloatND, - medicare_age: int, + medicare_age: ScalarInt, di_dropout_scale: FloatND, di_dropout_next_period_ratio: FloatND, ) -> ContinuousState: diff --git a/tests/test_baseline_equivalence.py b/tests/test_baseline_equivalence.py index 7d18e5c..5e68e86 100644 --- a/tests/test_baseline_equivalence.py +++ b/tests/test_baseline_equivalence.py @@ -75,9 +75,9 @@ def test_aca_cash_on_hand_matches_baseline_when_neutral() -> None: def test_baseline_primary_oop_no_cost_sharing_scale() -> None: """Baseline primary_oop applies raw deductible/coinsurance/oop_max.""" costs = jnp.array(5000.0) - deductible = 500.0 - coinsurance = 0.2 - oop_max_val = 3000.0 + deductible = jnp.asarray(500.0) + coinsurance = jnp.asarray(0.2) + oop_max_val = jnp.asarray(3000.0) result = health_insurance.primary_oop( total_health_costs=costs, buy_private=jnp.array(BuyPrivate.yes), @@ -97,9 +97,9 @@ def test_baseline_primary_oop_no_cost_sharing_scale() -> None: def test_aca_primary_oop_scaled_reduces_costs() -> None: """ACA primary_oop with scale < 1.0 reduces OOP costs.""" costs = jnp.array(5000.0) - deductible = 500.0 - coinsurance = 0.2 - oop_max_val = 3000.0 + deductible = jnp.asarray(500.0) + coinsurance = jnp.asarray(0.2) + oop_max_val = jnp.asarray(3000.0) oop_full = aca_hi.primary_oop( total_health_costs=costs, cost_sharing_scale=jnp.array(1.0), diff --git a/tests/test_health_insurance.py b/tests/test_health_insurance.py index cd89c9f..06a23b7 100644 --- a/tests/test_health_insurance.py +++ b/tests/test_health_insurance.py @@ -18,7 +18,7 @@ def test_ssi_eligible_assets_too_high() -> None: assets=jnp.array(5000.0), countable_income=jnp.array(1000.0), spousal_income=jnp.array(0), - gets_medicare=True, + gets_medicare=jnp.asarray(True), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, ) @@ -30,7 +30,7 @@ def test_ssi_eligible_income_too_high() -> None: assets=jnp.array(1000.0), countable_income=jnp.array(9000.0), spousal_income=jnp.array(0), - gets_medicare=True, + gets_medicare=jnp.asarray(True), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, ) @@ -42,7 +42,7 @@ def test_ssi_eligible_no_medicare() -> None: assets=jnp.array(1000.0), countable_income=jnp.array(1000.0), spousal_income=jnp.array(0), - gets_medicare=False, + gets_medicare=jnp.asarray(False), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, ) @@ -54,7 +54,7 @@ def test_ssi_eligible_all_pass() -> None: assets=jnp.array(1000.0), countable_income=jnp.array(1000.0), spousal_income=jnp.array(0), - gets_medicare=True, + gets_medicare=jnp.asarray(True), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, ) @@ -148,20 +148,20 @@ def test_compute_table_uniform_transition(table_inputs: dict) -> None: _PREMIUM_KWARGS: dict = { - "age": 60, + "age": jnp.int32(60), "good_health": jnp.array(True), "is_married": jnp.array(False), "labor_supply": jnp.array(LaborSupply.h2000), - "premium_intercept": 1000.0, - "premium_age": 0, - "premium_age_sq": 0.0, - "premium_age_cub": 0.0, - "premium_predicted_hcc": 0.0, - "premium_good_health": 0.0, - "premium_married": 0.0, - "premium_works": 0.0, - "premium_married_works": 0.0, - "premium_minimum": 500.0, + "premium_intercept": jnp.asarray(1000.0), + "premium_age": jnp.asarray(0.0), + "premium_age_sq": jnp.asarray(0.0), + "premium_age_cub": jnp.asarray(0.0), + "premium_predicted_hcc": jnp.asarray(0.0), + "premium_good_health": jnp.asarray(0.0), + "premium_married": jnp.asarray(0.0), + "premium_works": jnp.asarray(0.0), + "premium_married_works": jnp.asarray(0.0), + "premium_minimum": jnp.asarray(500.0), "predicted_hcc_insurer": jnp.array(0.0), } @@ -187,9 +187,9 @@ def test_primary_oop_insured_applies_deductible_coinsurance() -> None: result = health_insurance.primary_oop( total_health_costs=jnp.array(10000.0), buy_private=jnp.array(BuyPrivate.yes), - deductible=500.0, - coinsurance_rate=0.2, - oop_max=5000.0, + deductible=jnp.asarray(500.0), + coinsurance_rate=jnp.asarray(0.2), + oop_max=jnp.asarray(5000.0), ) expected = 500.0 + (10000.0 - 500.0) * 0.2 # 2400 assert jnp.isclose(result, expected, atol=ATOL) @@ -200,8 +200,8 @@ def test_primary_oop_uninsured_equals_total_costs() -> None: result = health_insurance.primary_oop( total_health_costs=total, buy_private=jnp.array(BuyPrivate.no), - deductible=500.0, - coinsurance_rate=0.2, - oop_max=5000.0, + deductible=jnp.asarray(500.0), + coinsurance_rate=jnp.asarray(0.2), + oop_max=jnp.asarray(5000.0), ) assert jnp.isclose(result, total) diff --git a/tests/test_model_components.py b/tests/test_model_components.py index 5b7df6a..b3569c5 100644 --- a/tests/test_model_components.py +++ b/tests/test_model_components.py @@ -7,77 +7,68 @@ def test_equivalence_scale_single() -> None: - result = preferences.equivalence_scale(jnp.array(False), 0.7) + result = preferences.equivalence_scale(jnp.array(False), jnp.asarray(0.7)) assert jnp.isclose(result, 1.0) def test_equivalence_scale_married() -> None: - result = preferences.equivalence_scale(jnp.array(True), 0.7) + result = preferences.equivalence_scale(jnp.array(True), jnp.asarray(0.7)) assert jnp.isclose(result, 2.0**0.7) def test_leisure_not_working() -> None: - result = preferences.leisure( + result = preferences.leisure_canwork_retiree_or_nongroup( working_hours_value=jnp.array(0.0), - age=60, good_health=jnp.array(1.0), lagged_labor_supply=jnp.array(0), - time_endowment=5000.0, - leisure_cost_of_bad_health=500.0, - fixed_cost_of_work_intercept=100.0, - fixed_cost_of_work_age_trend=5, - labor_force_reentry_cost=200.0, - reference_age=50, + time_endowment=jnp.asarray(5000.0), + leisure_cost_of_bad_health=jnp.asarray(500.0), + fixed_cost_of_work=jnp.asarray(150.0), + labor_force_reentry_cost=jnp.asarray(200.0), ) assert jnp.isclose(result, 5000.0) def test_leisure_working_good_health() -> None: - result = preferences.leisure( + result = preferences.leisure_canwork_retiree_or_nongroup( working_hours_value=jnp.array(2000.0), - age=60, good_health=jnp.array(1.0), lagged_labor_supply=jnp.array(1), - time_endowment=5000.0, - leisure_cost_of_bad_health=500.0, - fixed_cost_of_work_intercept=100.0, - fixed_cost_of_work_age_trend=5, - labor_force_reentry_cost=200.0, - reference_age=50, + time_endowment=jnp.asarray(5000.0), + leisure_cost_of_bad_health=jnp.asarray(500.0), + fixed_cost_of_work=jnp.asarray(150.0), + labor_force_reentry_cost=jnp.asarray(200.0), ) - # 5000 - 0 (good health) - (2000 + 100 + 5*(60-50) + 0 (lagged=1)) - expected = 5000.0 - 2000.0 - 100.0 - 50.0 + # 5000 - 0 (good health) - (2000 + 150 + 0 (lagged=1)) + expected = 5000.0 - 2000.0 - 150.0 assert jnp.isclose(result, expected) def test_leisure_reentry_cost() -> None: - result = preferences.leisure( + result = preferences.leisure_canwork_retiree_or_nongroup( working_hours_value=jnp.array(2000.0), - age=60, good_health=jnp.array(1.0), lagged_labor_supply=jnp.array(0), - time_endowment=5000.0, - leisure_cost_of_bad_health=500.0, - fixed_cost_of_work_intercept=100.0, - fixed_cost_of_work_age_trend=5, - labor_force_reentry_cost=200.0, - reference_age=50, + time_endowment=jnp.asarray(5000.0), + leisure_cost_of_bad_health=jnp.asarray(500.0), + fixed_cost_of_work=jnp.asarray(150.0), + labor_force_reentry_cost=jnp.asarray(200.0), ) - expected = 5000.0 - 2000.0 - 100.0 - 50.0 - 200.0 + expected = 5000.0 - 2000.0 - 150.0 - 200.0 assert jnp.isclose(result, expected) def test_leisure_bad_health() -> None: - result = preferences.leisure_retired( + result = preferences.leisure_forcedout( good_health=jnp.array(0.0), - time_endowment=5000.0, - leisure_cost_of_bad_health=500.0, + time_endowment=jnp.asarray(5000.0), + leisure_cost_of_bad_health=jnp.asarray(500.0), ) assert jnp.isclose(result, 4500.0) def test_utility_positive_leisure() -> None: - result = preferences.u_can_work( + result = preferences.u_alive( consumption_equiv=jnp.array(10000.0), leisure=jnp.array(3000.0), consumption_weight=jnp.array(0.4), @@ -88,7 +79,7 @@ def test_utility_positive_leisure() -> None: def test_utility_log_case() -> None: - result = preferences.u_can_work( + result = preferences.u_alive( consumption_equiv=jnp.array(10000.0), leisure=jnp.array(3000.0), consumption_weight=jnp.array(0.4), @@ -103,8 +94,8 @@ def test_utility_log_case() -> None: def test_bequest_positive_assets() -> None: result = preferences.bequest( assets=jnp.array(100000.0), - bequest_shifter=5000.0, - scaled_bequest_weight=0.5, + bequest_shifter=jnp.asarray(5000.0), + scaled_bequest_weight=jnp.asarray(0.5), consumption_weight=jnp.array(0.4), coefficient_rra=jnp.array(2.0), utility_scale_factor=jnp.array(1.0), @@ -115,8 +106,8 @@ def test_bequest_positive_assets() -> None: def test_bequest_zero_assets() -> None: result = preferences.bequest( assets=jnp.array(0.0), - bequest_shifter=5000.0, - scaled_bequest_weight=0.5, + bequest_shifter=jnp.asarray(5000.0), + scaled_bequest_weight=jnp.asarray(0.5), consumption_weight=jnp.array(0.4), coefficient_rra=jnp.array(2.0), utility_scale_factor=jnp.array(1.0), @@ -164,13 +155,13 @@ def test_next_aime_accrual() -> None: age=jnp.int32(55), benefit_withheld_fraction=jnp.array(0.0), earnings_test_credited_back=jnp.zeros(100), - earnings_test_repealed_age=70, + earnings_test_repealed_age=jnp.int32(70), pia_table=jnp.array([0.0, 711.9, 2115.1, 3015.1]), pia_aime_grid=jnp.array([0.0, 791.0, 4768.0, 8000.0]), - aime_accrual_factor=1 / 35, - aggregate_wage_growth=0.02, - aime_last_age_with_indexing=60, - aime_kink_2=8000.0, + aime_accrual_factor=jnp.asarray(1 / 35), + aggregate_wage_growth=jnp.asarray(0.02), + aime_last_age_with_indexing=jnp.int32(60), + aime_kink_2=jnp.asarray(8000.0), ratio_lowest_earnings=ratio, ) assert result > 1000.0 diff --git a/tests/test_pension_integration.py b/tests/test_pension_integration.py index 0f6c07d..ae27d3f 100644 --- a/tests/test_pension_integration.py +++ b/tests/test_pension_integration.py @@ -12,7 +12,7 @@ from aca_model.environment import pensions ATOL = 0.01 -RATE_OF_RETURN = 0.03 +RATE_OF_RETURN = jnp.asarray(0.03) # Pension imputation coefficients — two HIS types with different intercepts. # HIS 0 (retiree): intercept = -50, HIS 1 (nongroup): intercept = -80. diff --git a/tests/test_pensions.py b/tests/test_pensions.py index 514ab8c..286c72a 100644 --- a/tests/test_pensions.py +++ b/tests/test_pensions.py @@ -141,7 +141,7 @@ def test_pension_wealth_next_accrual_only() -> None: lli = math.log(10000) prob = math.exp(0.1) / (1 + math.exp(0.1)) accrual = lli * 0.5 * prob * 10000 - r = 0.03 + r = jnp.asarray(0.03) result = pensions.wealth_next_before_adjustment( pension_wealth=jnp.array(0.0), pension_benefit=jnp.array(0.0), @@ -157,7 +157,7 @@ def test_pension_wealth_next_with_benefit() -> None: lli = math.log(10000) prob = math.exp(0.1) / (1 + math.exp(0.1)) accrual = lli * 0.5 * prob * 10000 - r = 0.03 + r = jnp.asarray(0.03) result = pensions.wealth_next_before_adjustment( pension_wealth=jnp.array(3000.0), pension_benefit=jnp.array(2000.0), diff --git a/tests/test_preferences.py b/tests/test_preferences.py index 1b5107f..8017635 100644 --- a/tests/test_preferences.py +++ b/tests/test_preferences.py @@ -7,16 +7,18 @@ from aca_model.agent import preferences -# Struct-ret preference parameters -CONSUMPTION_WEIGHT = 0.6 -TIME_DISCOUNT_FACTOR = 0.85 -TIME_ENDOWMENT = 5000.0 -FIXED_COST_INTERCEPT = 0.0 -AVERAGE_CONSUMPTION = 10000.0 -RATE_OF_RETURN = 0.01 -BEQUEST_WEIGHT = 0.02 -BEQUEST_SHIFTER = 500_000.0 -REFERENCE_HOURS = 1000.0 +# Struct-ret preference parameters. Tests call DAG functions directly, so +# every scalar fixed_param is supplied as a 0-d jax array (the type pylcm +# casts user-provided Python scalars to before passing them into the DAG). +CONSUMPTION_WEIGHT = jnp.asarray(0.6) +TIME_DISCOUNT_FACTOR = jnp.asarray(0.85) +TIME_ENDOWMENT = jnp.asarray(5000.0) +FIXED_COST_INTERCEPT = jnp.asarray(0.0) +AVERAGE_CONSUMPTION = jnp.asarray(10000.0) +RATE_OF_RETURN = jnp.asarray(0.01) +BEQUEST_WEIGHT = jnp.asarray(0.02) +BEQUEST_SHIFTER = jnp.asarray(500_000.0) +REFERENCE_HOURS = jnp.asarray(1000.0) # --- utility_scale_factor --- @@ -24,9 +26,9 @@ def test_utility_scale_factor_crra() -> None: result = preferences.utility_scale_factor( - average_consumption_dollars=AVERAGE_CONSUMPTION, - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(5.0), + average_consumption_equiv=AVERAGE_CONSUMPTION, + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(5.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, @@ -36,9 +38,9 @@ def test_utility_scale_factor_crra() -> None: def test_utility_scale_factor_log() -> None: result = preferences.utility_scale_factor( - average_consumption_dollars=AVERAGE_CONSUMPTION, - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(1.0), + average_consumption_equiv=AVERAGE_CONSUMPTION, + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(1.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, @@ -53,7 +55,7 @@ def test_scaled_bequest_weight_positive() -> None: result = preferences.scaled_bequest_weight( bequest_weight=BEQUEST_WEIGHT, consumption_weight=CONSUMPTION_WEIGHT, - coefficient_rra=5.0, + coefficient_rra=jnp.asarray(5.0), time_endowment=TIME_ENDOWMENT, time_discount_factor=TIME_DISCOUNT_FACTOR, rate_of_return=RATE_OF_RETURN, @@ -65,7 +67,7 @@ def test_scaled_bequest_weight_log() -> None: result = preferences.scaled_bequest_weight( bequest_weight=BEQUEST_WEIGHT, consumption_weight=CONSUMPTION_WEIGHT, - coefficient_rra=1.0, + coefficient_rra=jnp.asarray(1.0), time_endowment=TIME_ENDOWMENT, time_discount_factor=TIME_DISCOUNT_FACTOR, rate_of_return=RATE_OF_RETURN, @@ -75,9 +77,9 @@ def test_scaled_bequest_weight_log() -> None: def test_scaled_bequest_weight_zero() -> None: result = preferences.scaled_bequest_weight( - bequest_weight=0.0, + bequest_weight=jnp.asarray(0.0), consumption_weight=CONSUMPTION_WEIGHT, - coefficient_rra=5.0, + coefficient_rra=jnp.asarray(5.0), time_endowment=TIME_ENDOWMENT, time_discount_factor=TIME_DISCOUNT_FACTOR, rate_of_return=RATE_OF_RETURN, @@ -90,18 +92,18 @@ def test_scaled_bequest_weight_zero() -> None: def test_utility_log_regression() -> None: scale = preferences.utility_scale_factor( - average_consumption_dollars=AVERAGE_CONSUMPTION, - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(1.0), + average_consumption_equiv=AVERAGE_CONSUMPTION, + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(1.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, ) - result = preferences.u_can_work( + result = preferences.u_alive( consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(1.0), + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(1.0), utility_scale_factor=scale, ) assert jnp.isclose(result, 1.005_046_313_660_588_5, rtol=1e-5) @@ -109,18 +111,18 @@ def test_utility_log_regression() -> None: def test_utility_crra_regression() -> None: scale = preferences.utility_scale_factor( - average_consumption_dollars=AVERAGE_CONSUMPTION, - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(5.0), + average_consumption_equiv=AVERAGE_CONSUMPTION, + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(5.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, ) - result = preferences.u_can_work( + result = preferences.u_alive( consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(5.0), + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(5.0), utility_scale_factor=scale, ) assert jnp.isclose(result, -0.836_511_642_073_019_1, rtol=1e-5) @@ -129,25 +131,25 @@ def test_utility_crra_regression() -> None: def test_utility_married_equivalence() -> None: """Married with equiv-scaled consumption_dollars should equal single utility.""" scale = preferences.utility_scale_factor( - average_consumption_dollars=AVERAGE_CONSUMPTION, - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(5.0), + average_consumption_equiv=AVERAGE_CONSUMPTION, + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(5.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, ) - single = preferences.u_can_work( + single = preferences.u_alive( consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(5.0), + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(5.0), utility_scale_factor=scale, ) - married = preferences.u_can_work( + married = preferences.u_alive( consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(5.0), + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(5.0), utility_scale_factor=scale, ) assert jnp.isclose(single, married, rtol=1e-5) @@ -158,9 +160,9 @@ def test_utility_married_equivalence() -> None: def test_bequest_log_regression() -> None: scale = preferences.utility_scale_factor( - average_consumption_dollars=AVERAGE_CONSUMPTION, - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(1.0), + average_consumption_equiv=AVERAGE_CONSUMPTION, + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(1.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, @@ -168,7 +170,7 @@ def test_bequest_log_regression() -> None: bwt = preferences.scaled_bequest_weight( bequest_weight=BEQUEST_WEIGHT, consumption_weight=CONSUMPTION_WEIGHT, - coefficient_rra=1.0, + coefficient_rra=jnp.asarray(1.0), time_endowment=TIME_ENDOWMENT, time_discount_factor=TIME_DISCOUNT_FACTOR, rate_of_return=RATE_OF_RETURN, @@ -176,9 +178,9 @@ def test_bequest_log_regression() -> None: result = preferences.bequest( assets=jnp.array(10000.0), bequest_shifter=BEQUEST_SHIFTER, - scaled_bequest_weight=bwt.item(), - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(1.0), + scaled_bequest_weight=bwt, + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(1.0), utility_scale_factor=scale, ) assert jnp.isclose(result, 86.539_249_963_643_88, rtol=1e-5) @@ -186,9 +188,9 @@ def test_bequest_log_regression() -> None: def test_bequest_crra_regression() -> None: scale = preferences.utility_scale_factor( - average_consumption_dollars=AVERAGE_CONSUMPTION, - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(5.0), + average_consumption_equiv=AVERAGE_CONSUMPTION, + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(5.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, @@ -196,7 +198,7 @@ def test_bequest_crra_regression() -> None: bwt = preferences.scaled_bequest_weight( bequest_weight=BEQUEST_WEIGHT, consumption_weight=CONSUMPTION_WEIGHT, - coefficient_rra=5.0, + coefficient_rra=jnp.asarray(5.0), time_endowment=TIME_ENDOWMENT, time_discount_factor=TIME_DISCOUNT_FACTOR, rate_of_return=RATE_OF_RETURN, @@ -204,9 +206,9 @@ def test_bequest_crra_regression() -> None: result = preferences.bequest( assets=jnp.array(10000.0), bequest_shifter=BEQUEST_SHIFTER, - scaled_bequest_weight=bwt.item(), - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(5.0), + scaled_bequest_weight=bwt, + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(5.0), utility_scale_factor=scale, ) assert jnp.isclose(result, -37.932_748_117_035_63, rtol=1e-5) diff --git a/tests/test_social_security.py b/tests/test_social_security.py index b8ac44a..c4c704f 100644 --- a/tests/test_social_security.py +++ b/tests/test_social_security.py @@ -23,10 +23,11 @@ PIA_CONVERSION_RATE_2 = 0.15 PIA_KINK_0 = 5151.6 PIA_KINK_1 = 14359.9 -AIME_ACCRUAL_FACTOR = 0.025 -AGGREGATE_WAGE_GROWTH = 0.03 -AIME_LAST_AGE_WITH_INDEXING = 59 -SSDI_SGA = 12840.0 +AIME_ACCRUAL_FACTOR = jnp.asarray(0.025) +AGGREGATE_WAGE_GROWTH = jnp.asarray(0.03) +AIME_LAST_AGE_WITH_INDEXING = jnp.int32(59) +AIME_KINK_2_SCALAR = jnp.asarray(AIME_KINK_2) +SSDI_SGA = jnp.asarray(12840.0) PIA_PARAMS = { "aime_kink_0": AIME_KINK_0, @@ -59,8 +60,8 @@ DI_SCALE = jnp.array( compute_di_dropout_scale( pd.Series(_RATIO_NP), - AIME_ACCRUAL_FACTOR, - start_age=jnp.int32(0), + AIME_ACCRUAL_FACTOR.item(), + start_age=0, n_periods=100, ) ) @@ -135,7 +136,7 @@ def test_next_aime_indexing_high_income() -> None: aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, - aime_kink_2=AIME_KINK_2, + aime_kink_2=AIME_KINK_2_SCALAR, ratio_lowest_earnings=RATIO, ) expected = 1000 * 1.03 + (20000 - 0.2 * 1000 * 1.03) * 0.025 @@ -156,7 +157,7 @@ def test_next_aime_indexing_low_income() -> None: aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, - aime_kink_2=AIME_KINK_2, + aime_kink_2=AIME_KINK_2_SCALAR, ratio_lowest_earnings=RATIO, ) assert jnp.isclose(result, 10000 * 1.03, atol=ATOL) @@ -176,7 +177,7 @@ def test_next_aime_no_indexing_high_income() -> None: aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, - aime_kink_2=AIME_KINK_2, + aime_kink_2=AIME_KINK_2_SCALAR, ratio_lowest_earnings=RATIO, ) expected = 1000 + (20000 - 0.4 * 1000) * 0.025 @@ -197,7 +198,7 @@ def test_next_aime_no_indexing_low_income() -> None: aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, - aime_kink_2=AIME_KINK_2, + aime_kink_2=AIME_KINK_2_SCALAR, ratio_lowest_earnings=RATIO, ) assert jnp.isclose(result, 1000, atol=ATOL) @@ -217,7 +218,7 @@ def test_next_aime_cap_high_aime_high_income() -> None: aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, - aime_kink_2=AIME_KINK_2, + aime_kink_2=AIME_KINK_2_SCALAR, ratio_lowest_earnings=RATIO, ) assert jnp.isclose(result, 39000, atol=ATOL) @@ -237,7 +238,7 @@ def test_next_aime_cap_high_aime_low_income() -> None: aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, - aime_kink_2=AIME_KINK_2, + aime_kink_2=AIME_KINK_2_SCALAR, ratio_lowest_earnings=RATIO, ) assert jnp.isclose(result, 39000, atol=ATOL) diff --git a/tests/test_ss_benefit_integration.py b/tests/test_ss_benefit_integration.py index 5e74e9a..81a1b61 100644 --- a/tests/test_ss_benefit_integration.py +++ b/tests/test_ss_benefit_integration.py @@ -64,7 +64,7 @@ def test_earnings_test_reduces_benefit_before_fra() -> None: earnings_test_threshold=jnp.array([17640.0]), earnings_test_fraction=jnp.array([0.5]), earnings_test_repealed_age=jnp.int32(66), - ssdi_substantial_gainful_activity=13560.0, + ssdi_substantial_gainful_activity=jnp.asarray(13560.0), ) benefit_not_working = social_security.benefit_choose_pre65( @@ -82,7 +82,7 @@ def test_earnings_test_reduces_benefit_before_fra() -> None: earnings_test_threshold=jnp.array([17640.0]), earnings_test_fraction=jnp.array([0.5]), earnings_test_repealed_age=jnp.int32(66), - ssdi_substantial_gainful_activity=13560.0, + ssdi_substantial_gainful_activity=jnp.asarray(13560.0), ) assert benefit_working < benefit_not_working