Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ jobs:
- uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: Install pylcm (feature branch — revert to @main once pylcm#350 merges)
- name: Install pylcm (feature branch — revert to @main once pylcm#348/#350 merge)
run: >-
pip install "pylcm @
git+https://github.com/OpenSourceEconomics/pylcm.git@feat/categorical-scalarint"
git+https://github.com/OpenSourceEconomics/pylcm.git@feat/runtime-grid-extra-params"
- name: Install aca-model with test deps
run: pip install -e . pytest pdbp
- name: Run pytest
Expand Down
Binary file modified src/aca_model/_benchmark_data/benchmark_params.pkl
Binary file not shown.
11 changes: 1 addition & 10 deletions src/aca_model/baseline/regimes/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,16 +203,6 @@ class Grids:
_AIME_PIECE_N_POINTS: tuple[int, int, int] = (10, 11, 11)


MAX_CONSUMPTION_DOLLARS: float = 300_000.0
"""Upper bound of the runtime consumption_dollars grid in $/year.

Lives here next to the other grid bounds (assets `stop=500_000.0`,
AIME `stop=8_000.0`).

TODO: route through `fixed_params` once pylcm#348 lands (so the bound
can vary across optimizer iterations without re-importing this module).
"""

# AR(1) persistence of the Rouwenhorst shocks. Calibrated once; not
# routed through fixed_params because they shape the grid topology
# rather than feed any DAG function. The Rouwenhorst innovation std is
Expand Down Expand Up @@ -275,6 +265,7 @@ def build_grids(
aime=_build_aime_grid(grid_config=grid_config, fixed_params=fixed_params),
consumption_dollars=IrregSpacedGrid(
n_points=grid_config.n_consumption_dollars_gridpoints,
extra_param_names=("max_consumption_dollars",),
),
wage_res=wage_res,
hcc_persistent=hcc_persistent,
Expand Down
33 changes: 19 additions & 14 deletions src/aca_model/consumption_dollars_grid.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Runtime-supplied gridpoints for the consumption_dollars action.

Consumption is declared as `IrregSpacedGrid(n_points=N)` in
Consumption is declared as `IrregSpacedGrid(n_points=N,
extra_param_names=("max_consumption_dollars",))` in
`baseline.regimes._common.build_grids` so the bounds can track
runtime parameters: the lower bound from the per-iteration
`consumption_equiv_floor` parameter (and its couples-scaled twin),
the upper bound from `MAX_CONSUMPTION_DOLLARS` in
`baseline.regimes._common`. Callers must inject the actual gridpoints
into `params` via `inject_consumption_dollars_points` before calling
`model.solve()` / `model.simulate()`.
the upper bound from `max_consumption_dollars` carried through
`fixed_params` (per pylcm#348). Callers must inject the actual
gridpoints into `params` via `inject_consumption_dollars_points`
before calling `model.solve()` / `model.simulate()`.

The grid pins the two regime-relevant transfer-floor levels exactly
on the action grid so the borrowing constraint's
Expand All @@ -16,7 +17,7 @@

- `pts[0] = consumption_equiv_floor` (single household: equiv_scale=1)
- `pts[1] = consumption_equiv_floor * 2 ** exponent` (married)
- `pts[2:] = geomspace(pts[1], MAX_CONSUMPTION_DOLLARS, n_points - 1)`
- `pts[2:] = geomspace(pts[1], max_consumption_dollars, n_points - 1)`
"""

from collections.abc import Mapping
Expand All @@ -26,8 +27,6 @@
from jax import Array
from lcm import IrregSpacedGrid, Model

from aca_model.baseline.regimes._common import MAX_CONSUMPTION_DOLLARS


def inject_consumption_dollars_points(
*,
Expand All @@ -41,15 +40,16 @@ def inject_consumption_dollars_points(

The lower two gridpoints are the single and married Dollar-valued
transfer floors; the rest are geomspaced from the married floor up
to `MAX_CONSUMPTION_DOLLARS`.
to `model.fixed_params["max_consumption_dollars"]`.

Args:
params: Existing params mapping with `consumption_equiv_floor`
(per-equivalent floor, varies per iteration). Returned as a
new dict; the input is not mutated.
model: Model whose regimes carry the runtime-points grid and
whose `fixed_params["exponent"]` sets the married
equivalence-scale exponent.
whose `fixed_params` supplies `exponent` (married
equivalence-scale exponent) and `max_consumption_dollars`
(grid upper bound).

Returns:
New params dict with consumption_dollars points injected.
Expand All @@ -61,6 +61,9 @@ def inject_consumption_dollars_points(
"""
consumption_equiv_floor = jnp.asarray(params["consumption_equiv_floor"])
exponent = jnp.asarray(model.fixed_params["exponent"])
max_consumption_dollars = jnp.asarray(
model.fixed_params["max_consumption_dollars"]
)
out: dict[str, Any] = dict(params)
for regime_name, regime in model.regimes.items():
if regime.terminal:
Expand All @@ -85,6 +88,7 @@ def inject_consumption_dollars_points(
points = _compute_consumption_dollars_points(
consumption_equiv_floor=consumption_equiv_floor,
exponent=exponent,
max_consumption_dollars=max_consumption_dollars,
n_points=grid.n_points,
)
regime_entry = dict(out.get(regime_name, {}))
Expand All @@ -97,6 +101,7 @@ def _compute_consumption_dollars_points(
*,
consumption_equiv_floor: Array,
exponent: Array,
max_consumption_dollars: Array,
n_points: int,
) -> Array:
"""Return log-spaced consumption_dollars gridpoints with both floors pinned.
Expand All @@ -108,12 +113,12 @@ def _compute_consumption_dollars_points(
a feasible action; otherwise sub-ULP drift can flip the `<=`
comparison for subjects with very negative cash. The geomspace
tail starts at the married floor and runs to
`MAX_CONSUMPTION_DOLLARS` so the two pinned points stay strictly
`max_consumption_dollars` so the two pinned points stay strictly
increasing.
"""
married_dollar_floor = consumption_equiv_floor * jnp.asarray(2.0) ** exponent
tail = jnp.geomspace(
married_dollar_floor, MAX_CONSUMPTION_DOLLARS, num=n_points - 1
married_dollar_floor, max_consumption_dollars, num=n_points - 1
)
pts = jnp.concatenate([consumption_equiv_floor[None], tail])
# `jnp.geomspace` returns `start * r^0` for the first tail element,
Expand All @@ -129,7 +134,7 @@ def _compute_consumption_dollars_points(
msg = (
f"consumption_dollars grid is not strictly increasing at the "
f"married-floor kink: pts[1]={float(married_dollar_floor):.6g}, "
f"pts[2]={float(pts[2]):.6g}. Either `MAX_CONSUMPTION_DOLLARS` "
f"pts[2]={float(pts[2]):.6g}. Either `max_consumption_dollars` "
f"is too close to the married floor or `n_points` is too small."
)
raise ValueError(msg)
Expand Down
13 changes: 9 additions & 4 deletions tests/test_consumption_dollars_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,21 @@
rejects every action for the affected subjects.

`_compute_consumption_dollars_points` therefore prepends the singles'
floor as `pts[0]`, runs `geomspace` from the married floor up to
`MAX_CONSUMPTION_DOLLARS` for the rest, and pins the geomspace start
back to the married floor exactly. Test those invariants directly.
floor as `pts[0]`, runs `geomspace` from the married floor up to the
caller-supplied `max_consumption_dollars` for the rest, and pins the
geomspace start back to the married floor exactly. Test those invariants
directly.
"""

import jax.numpy as jnp
import pytest

from aca_model.baseline.regimes._common import MAX_CONSUMPTION_DOLLARS
from aca_model.consumption_dollars_grid import _compute_consumption_dollars_points

EXPONENT = 0.7 # production value (env_constants["exponent"])
SINGLE_FLOOR = 1597.0921419521899 # production value
MARRIED_SCALE = 2.0**EXPONENT
MAX_CONSUMPTION_DOLLARS = 300_000.0 # production value (env_constants)


@pytest.mark.parametrize("n_points", [5, 16, 64, 70, 100])
Expand All @@ -46,6 +47,7 @@ def test_compute_consumption_dollars_points_first_equals_singles_floor(
pts = _compute_consumption_dollars_points(
consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR),
exponent=jnp.asarray(EXPONENT),
max_consumption_dollars=jnp.asarray(MAX_CONSUMPTION_DOLLARS),
n_points=n_points,
)
assert float(pts[0]) == SINGLE_FLOOR
Expand All @@ -59,6 +61,7 @@ def test_compute_consumption_dollars_points_second_equals_married_floor(
pts = _compute_consumption_dollars_points(
consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR),
exponent=jnp.asarray(EXPONENT),
max_consumption_dollars=jnp.asarray(MAX_CONSUMPTION_DOLLARS),
n_points=n_points,
)
expected = float(jnp.asarray(SINGLE_FLOOR) * jnp.asarray(2.0) ** EXPONENT)
Expand All @@ -70,6 +73,7 @@ def test_compute_consumption_dollars_points_strictly_increasing() -> None:
pts = _compute_consumption_dollars_points(
consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR),
exponent=jnp.asarray(EXPONENT),
max_consumption_dollars=jnp.asarray(MAX_CONSUMPTION_DOLLARS),
n_points=70,
)
diffs = jnp.diff(pts)
Expand All @@ -81,6 +85,7 @@ def test_compute_consumption_dollars_points_last_equals_max() -> None:
pts = _compute_consumption_dollars_points(
consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR),
exponent=jnp.asarray(EXPONENT),
max_consumption_dollars=jnp.asarray(MAX_CONSUMPTION_DOLLARS),
n_points=70,
)
assert float(pts[-1]) == pytest.approx(MAX_CONSUMPTION_DOLLARS)