Skip to content

Commit c8b4742

Browse files
Add sparse FOC Jacobian to the household solve
1 parent 4a8ef99 commit c8b4742

10 files changed

Lines changed: 512 additions & 52 deletions

File tree

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [0.15.14] - 2026-06-03 12:00:00
9+
10+
### Added
11+
12+
- Adds an optional `use_sparse_FOC_jac` `Specifications` parameter (default off, so default runs are unchanged) that accelerates the time path iteration (TPI) household solve. When True, `scipy.optimize.root` is given a sparse (banded) finite-difference Jacobian for the stacked household Euler and labor first order conditions: the sparsity pattern is auto-detected once per problem size and the solver then needs far fewer function evaluations per Jacobian build (about 20x fewer on the default S=80 cohort solve), with an automatic fallback to dense finite differences if the Jacobian is not sparse enough to benefit or if a solve fails. The result matches the dense-finite-difference solution to within the model's resource-constraint accuracy floor on every calibration tested (OG-Core standard example, OG-ETH, OG-ZAF, OG-PHL, OG-IDN), giving roughly a 1.9-2.4x TPI speedup at no accuracy cost.
13+
814
## [0.15.13] - 2026-05-15 06:00:00
915

1016
### Added

docs/book/content/theory/derivations.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,16 @@ In the Cobb-Douglas unit elasticity case ($\varepsilon=1$) of the CES production
114114
```
115115

116116
Again, even if this simple case, we cannot solve for $r$ as a function of $w$ for the reasons above.
117+
118+
119+
(SecAppDerivHHjac)=
120+
## Sparsity of the household equation Jacobian
121+
122+
Holding fixed the prices and policies a type-$j$ cohort faces, its $2S$ stationarized necessary conditions {eq}`EqStnrz_eul_n`, {eq}`EqStnrz_eul_b`, and {eq}`EqStnrz_eul_bS` in the $2S$ unknowns $\{n_{j,s},\hat b_{j,s+1}\}_{s=E+1}^{E+S}$ have a banded Jacobian. From the budget constraint {eq}`EqStnrzHHBC`, stationarized consumption at age $s$ depends on only three unknowns,
123+
124+
```{math}
125+
:label: EqAppDerivHHjac_cons
126+
\hat c_{j,s} = \frac{1}{p}\Bigl[(1+r_p)\hat b_{j,s} + \hat w\,e_{j,s}\,n_{j,s} - \widehat{tax}_{j,s} - e^{g_y}\hat b_{j,s+1}\Bigr] + X_{j,s},
127+
```
128+
129+
where $\widehat{tax}_{j,s}$ depends only on $(\hat b_{j,s}, n_{j,s})$ through labor and capital income (already in the active set, so it adds no further coupling), and $X_{j,s}$ collects terms fixed in the inner solve (bequests $\hat{bq}_{j,s}$, remittances $\hat{rm}_{j,s}$, government transfers $\hat{tr}_{j,s}$, UBI $\hat{ubi}_{j,s}$, the pension benefit $\theta_j$, and the $\hat c_{min,i}$ terms). The labor Euler equation {eq}`EqStnrz_eul_n` at age $s$ therefore depends on $\{\hat b_{j,s},\hat b_{j,s+1},n_{j,s}\}$ alone, and the savings Euler equation {eq}`EqStnrz_eul_b`---which links $\hat c_{j,s}$ to $\hat b_{j,s+1}$ and $\hat c_{j,s+1}$---depends on $\{\hat b_{j,s},\hat b_{j,s+1},\hat b_{j,s+2},n_{j,s},n_{j,s+1}\}$. The marginal tax rates $\tau^{mtrx}_s$ and $\tau^{mtry}_{s+1}$ are functions of own-age income (already in these sets), so they add no further coupling, and the terminal condition {eq}`EqStnrz_eul_bS` is sparser still. Each of the $2S$ equations therefore depends on at most five of the $2S$ unknowns, regardless of $S$, so the Jacobian has at most $10S$ nonzero entries rather than the $(2S)^2 = 4S^2$ of a fully coupled system. This is the per-cohort counterpart to the dense $2JS$ system noted at the start of Chapter {ref}`Chap_Eqm`: cohorts couple only through prices, which are held fixed in the inner solve. A finite-difference Jacobian can then be built from a number of evaluations set by the bandwidth---about seven at $S = 80$---rather than $2S$, by probing together unknowns that affect no common equation (Figure {numref}`FigHHjacSparsity`).

docs/book/content/theory/equilibrium.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ In all of the specifications of `OG-Core`, we use a two-stage fixed point algori
2525

2626
Our approach is to choose the minimum number of macroeconomic variables in an outer loop in order to be able to solve the household's $2JS$ Euler equations in terms of only the $\bar{n}_{j,s}$ and $\bar{b}_{j,s+1}$ variables directly, holding all other variables constant. The household system of Euler equations has a provable root solution and is orders of magnitude more tractable (less nonlinear) to solve holding these outer loop variables constant.
2727

28+
Moreover, with the outer-loop variables held fixed, each cohort's system of $2S$ Euler equations is not only less nonlinear but structurally sparse: every equation involves at most five of the $2S$ unknowns---a household's own age and its immediate neighbors. The root finder normally probes each unknown separately when building each step ($2S = 160$ evaluations of the system when $S = 80$), but with most equations depending on only a handful of unknowns, those affecting no common equation can be probed together, cutting the count to about seven at $S = 80$---a number set by how many neighbors couple, not by $S$. The parameter `use_sparse_FOC_jac` (default `False`) turns this on; the solver falls back to the standard calculation otherwise. The structure is derived in Appendix {ref}`SecAppDerivHHjac`.
29+
30+
```{figure} ./images/HH_jac_sparsity.png
31+
---
32+
name: FigHHjacSparsity
33+
---
34+
Sparsity pattern of the household equation Jacobian, at $S = 12$. Left: the standard finite-difference solve treats every entry of the $2S\times 2S$ matrix as live ($(2S)^2 = 576$ entries). Right: the actual structure---each Euler equation depends only on a household's own age and its immediate neighbors, leaving most entries zero (92 of 576 here; 636 of 25{,}600 at the default $S = 80$).
35+
```
36+
2837
The steady-state solution method for each of the cases above is associated with a solution method that has a subset of the following outer-loop variables $\Bigl\{\bar{r}_p, \bar{r}, \bar{w}, \{\bar{p}_m\}_{m=1}^{M-1}, \bar{Y}, \overline{TR}, \overline{BQ}, factor\Bigr\}$.
2938

3039

65.5 KB
Loading
Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
"""
2+
Side-by-side comparison: a baseline run with ``use_sparse_FOC_jac`` off
3+
(default) and on. Reports the wall-time speedup, diffs the converged
4+
steady-state and TPI paths, prints the resource-constraint residual, and
5+
issues a NO DRIFT / DRIFT DETECTED verdict against a 0.1% threshold.
6+
7+
With no arguments, runs OG-Core's standard example baseline (the same
8+
configuration as ``run_ogcore_example.py``). With a country package name
9+
(e.g. ``ogphl``, ``ogzaf``, ``ogidn``, ``ogeth``) as a single argument,
10+
runs that country's packaged baseline twice. The country package must be
11+
importable in the active environment; outputs land in the current working
12+
directory.
13+
14+
The reform leg is skipped; this is about solver speed and correctness on
15+
a single run.
16+
17+
Run from the repo root:
18+
19+
python examples/run_sparse_FOC_jac_compare.py # OG-Core
20+
python examples/run_sparse_FOC_jac_compare.py ogphl # PHL
21+
"""
22+
23+
# import modules
24+
import importlib
25+
import json
26+
import multiprocessing
27+
import os
28+
import sys
29+
import time
30+
from importlib.resources import files
31+
32+
import numpy as np
33+
from distributed import Client
34+
35+
from ogcore.execute import runner
36+
from ogcore.parameters import Specifications
37+
from ogcore.utils import safe_read_pickle
38+
39+
40+
# Default config for OG-Core mode (no country arg). Matches
41+
# run_ogcore_example.py.
42+
_alpha_T = np.zeros(50)
43+
_alpha_T[0:2] = 0.09
44+
_alpha_T[2:10] = 0.09 + 0.01
45+
_alpha_T[10:40] = 0.09 - 0.01
46+
_alpha_T[40:] = 0.09
47+
_alpha_G = np.zeros(7)
48+
_alpha_G[0:3] = 0.05 - 0.01
49+
_alpha_G[3:6] = 0.05 - 0.005
50+
_alpha_G[6:] = 0.05
51+
OGCORE_SPEC = {
52+
"frisch": 0.41,
53+
"start_year": 2021,
54+
"cit_rate": [[0.21]],
55+
"debt_ratio_ss": 1.0,
56+
"alpha_T": _alpha_T.tolist(),
57+
"alpha_G": _alpha_G.tolist(),
58+
"initial_guess_r_SS": 0.04,
59+
}
60+
61+
KEY_AGGREGATES = (
62+
"Y",
63+
"C",
64+
"K",
65+
"L",
66+
"B",
67+
"I_total",
68+
"r",
69+
"w",
70+
"r_p",
71+
"r_gov",
72+
"TR",
73+
"total_tax_revenue",
74+
"D",
75+
"BQ",
76+
)
77+
78+
# "No drift" threshold: aggregate differences within 0.1% are economically
79+
# indistinguishable from the model's own convergence noise.
80+
NO_DRIFT_THRESHOLD = 1e-3
81+
82+
83+
def _max_rel_diff(a, b):
84+
a = np.asarray(a, dtype=float)
85+
b = np.asarray(b, dtype=float)
86+
if a.shape != b.shape or a.size == 0:
87+
return float("nan")
88+
scale = max(float(np.max(np.abs(a))), 1e-300)
89+
return float(np.max(np.abs(a - b))) / scale
90+
91+
92+
def _diff_dict(d_dense, d_sparse, var_list=None):
93+
"""Return [(var, rel_diff), ...] sorted by rel_diff descending."""
94+
keys = (
95+
var_list
96+
if var_list is not None
97+
else sorted(set(d_dense) & set(d_sparse))
98+
)
99+
out = []
100+
for var in keys:
101+
if var in d_dense and var in d_sparse:
102+
try:
103+
rel = _max_rel_diff(d_dense[var], d_sparse[var])
104+
except (TypeError, ValueError):
105+
continue
106+
if rel == rel:
107+
out.append((var, rel))
108+
out.sort(key=lambda x: -x[1])
109+
return out
110+
111+
112+
def _load_country_defaults(pkg):
113+
"""Load <pkg>_default_parameters.json, with a 2-D shim for older country
114+
calibrations whose replacement_rate_adjust is still 1-D."""
115+
with files(pkg).joinpath(f"{pkg}_default_parameters.json").open("r") as f:
116+
defaults = json.load(f)
117+
rra = defaults.get("replacement_rate_adjust")
118+
if isinstance(rra, list) and rra and not isinstance(rra[0], list):
119+
defaults["replacement_rate_adjust"] = [rra]
120+
return defaults
121+
122+
123+
def _apply_country_calibration(pkg, p):
124+
"""Try the country's offline Calibration; quietly skip on error."""
125+
try:
126+
Cal = importlib.import_module(pkg + ".calibrate").Calibration
127+
try:
128+
c = Cal(p, update_from_api=False)
129+
except TypeError:
130+
c = Cal(p)
131+
p.update_specifications(c.get_dict())
132+
except Exception as e:
133+
print(f" (calibration skipped: {type(e).__name__}: {str(e)[:80]})")
134+
135+
136+
def _run_one(label, country_pkg, out_dir, num_workers, client, sparse_jac):
137+
p = Specifications(
138+
baseline=True,
139+
num_workers=num_workers,
140+
baseline_dir=out_dir,
141+
output_base=out_dir,
142+
)
143+
if country_pkg is None:
144+
p.update_specifications(OGCORE_SPEC)
145+
else:
146+
p.update_specifications(_load_country_defaults(country_pkg))
147+
_apply_country_calibration(country_pkg, p)
148+
p.update_specifications({"use_sparse_FOC_jac": bool(sparse_jac)})
149+
print(f"\n[{label}] use_sparse_FOC_jac = {p.use_sparse_FOC_jac}")
150+
start = time.time()
151+
runner(p, time_path=True, client=client)
152+
wall = time.time() - start
153+
print(f"[{label}] wall time = {wall:.2f} s")
154+
ss = safe_read_pickle(os.path.join(out_dir, "SS", "SS_vars.pkl"))
155+
tpi = safe_read_pickle(os.path.join(out_dir, "TPI", "TPI_vars.pkl"))
156+
return wall, ss, tpi
157+
158+
159+
def main(country_pkg=None):
160+
num_workers = min(multiprocessing.cpu_count(), 7)
161+
label = country_pkg if country_pkg else "ogcore (standard example)"
162+
print(f"Workers: {num_workers} | model: {label}")
163+
client = Client(n_workers=num_workers, threads_per_worker=1)
164+
165+
# Outputs land in the current working directory so they're easy to find
166+
# regardless of where this script file lives.
167+
root = os.path.join(
168+
os.getcwd(),
169+
"sparse-FOC-jac-compare",
170+
country_pkg if country_pkg else "ogcore",
171+
)
172+
dense_dir = os.path.join(root, "dense")
173+
sparse_dir = os.path.join(root, "sparse")
174+
175+
t_dense, ss_dense, tpi_dense = _run_one(
176+
f"{label} DENSE (default)",
177+
country_pkg,
178+
dense_dir,
179+
num_workers,
180+
client,
181+
False,
182+
)
183+
t_sparse, ss_sparse, tpi_sparse = _run_one(
184+
f"{label} SPARSE (use_sparse_FOC_jac=True)",
185+
country_pkg,
186+
sparse_dir,
187+
num_workers,
188+
client,
189+
True,
190+
)
191+
192+
tpi_diffs = _diff_dict(tpi_dense, tpi_sparse, KEY_AGGREGATES)
193+
ss_diffs = _diff_dict(ss_dense, ss_sparse)
194+
195+
tpi_worst_var, tpi_worst = tpi_diffs[0] if tpi_diffs else ("n/a", 0.0)
196+
ss_worst_var, ss_worst = ss_diffs[0] if ss_diffs else ("n/a", 0.0)
197+
worst = max(tpi_worst, ss_worst)
198+
199+
rc_d = float(
200+
np.max(np.abs(tpi_dense.get("resource_constraint_error", np.zeros(1))))
201+
)
202+
rc_s = float(
203+
np.max(
204+
np.abs(tpi_sparse.get("resource_constraint_error", np.zeros(1)))
205+
)
206+
)
207+
208+
speedup = t_dense / t_sparse if t_sparse > 0 else float("inf")
209+
bar = "=" * 64
210+
print()
211+
print(bar)
212+
print(f" MODEL: {label}")
213+
print(" SPEED")
214+
print(f" dense : {t_dense:7.2f} s")
215+
print(f" sparse : {t_sparse:7.2f} s -> {speedup:.2f}x faster")
216+
print()
217+
print(" DRIFT (max relative difference, sparse vs dense)")
218+
print(
219+
f" TPI worst: {tpi_worst_var:22s} "
220+
f"{tpi_worst * 100:9.4f}% ({tpi_worst:.2e})"
221+
)
222+
print(
223+
f" SS worst: {ss_worst_var:22s} "
224+
f"{ss_worst * 100:9.4f}% ({ss_worst:.2e})"
225+
)
226+
if tpi_diffs:
227+
print()
228+
print(" All TPI aggregates, sorted by drift:")
229+
for var, rel in tpi_diffs:
230+
print(f" {var:22s} {rel * 100:9.4f}% ({rel:.2e})")
231+
print()
232+
print(" ACCURACY FLOOR (resource-constraint residual)")
233+
print(f" dense : {rc_d:.2e}")
234+
print(f" sparse : {rc_s:.2e}")
235+
print()
236+
threshold_pct = NO_DRIFT_THRESHOLD * 100
237+
if worst <= NO_DRIFT_THRESHOLD:
238+
print(
239+
f" RESULT: NO DRIFT "
240+
f"(worst {worst * 100:.4f}% <= {threshold_pct:g}% threshold)"
241+
)
242+
else:
243+
print(
244+
f" RESULT: DRIFT DETECTED "
245+
f"(worst {worst * 100:.4f}% > {threshold_pct:g}% threshold) "
246+
f"-- investigate"
247+
)
248+
print(bar)
249+
250+
client.close()
251+
252+
253+
if __name__ == "__main__":
254+
main(sys.argv[1] if len(sys.argv) > 1 else None)

0 commit comments

Comments
 (0)