|
| 1 | +""" |
| 2 | +Side-by-side comparison: a baseline run with ``use_sparse_FOC_jac`` off |
| 3 | +(default) and on. Reports the wall-time speedup, diffs the converged |
| 4 | +steady-state and TPI paths, prints the resource-constraint residual, and |
| 5 | +issues a NO DRIFT / DRIFT DETECTED verdict against a 0.1% threshold. |
| 6 | +
|
| 7 | +With no arguments, runs OG-Core's standard example baseline (the same |
| 8 | +configuration as ``run_ogcore_example.py``). With a country package name |
| 9 | +(e.g. ``ogphl``, ``ogzaf``, ``ogidn``, ``ogeth``) as a single argument, |
| 10 | +runs that country's packaged baseline twice. The country package must be |
| 11 | +importable in the active environment; outputs land in the current working |
| 12 | +directory. |
| 13 | +
|
| 14 | +The reform leg is skipped; this is about solver speed and correctness on |
| 15 | +a single run. |
| 16 | +
|
| 17 | +Run from the repo root: |
| 18 | +
|
| 19 | + python examples/run_sparse_FOC_jac_compare.py # OG-Core |
| 20 | + python examples/run_sparse_FOC_jac_compare.py ogphl # PHL |
| 21 | +""" |
| 22 | + |
| 23 | +# import modules |
| 24 | +import importlib |
| 25 | +import json |
| 26 | +import multiprocessing |
| 27 | +import os |
| 28 | +import sys |
| 29 | +import time |
| 30 | +from importlib.resources import files |
| 31 | + |
| 32 | +import numpy as np |
| 33 | +from distributed import Client |
| 34 | + |
| 35 | +from ogcore.execute import runner |
| 36 | +from ogcore.parameters import Specifications |
| 37 | +from ogcore.utils import safe_read_pickle |
| 38 | + |
| 39 | + |
| 40 | +# Default config for OG-Core mode (no country arg). Matches |
| 41 | +# run_ogcore_example.py. |
| 42 | +_alpha_T = np.zeros(50) |
| 43 | +_alpha_T[0:2] = 0.09 |
| 44 | +_alpha_T[2:10] = 0.09 + 0.01 |
| 45 | +_alpha_T[10:40] = 0.09 - 0.01 |
| 46 | +_alpha_T[40:] = 0.09 |
| 47 | +_alpha_G = np.zeros(7) |
| 48 | +_alpha_G[0:3] = 0.05 - 0.01 |
| 49 | +_alpha_G[3:6] = 0.05 - 0.005 |
| 50 | +_alpha_G[6:] = 0.05 |
| 51 | +OGCORE_SPEC = { |
| 52 | + "frisch": 0.41, |
| 53 | + "start_year": 2021, |
| 54 | + "cit_rate": [[0.21]], |
| 55 | + "debt_ratio_ss": 1.0, |
| 56 | + "alpha_T": _alpha_T.tolist(), |
| 57 | + "alpha_G": _alpha_G.tolist(), |
| 58 | + "initial_guess_r_SS": 0.04, |
| 59 | +} |
| 60 | + |
| 61 | +KEY_AGGREGATES = ( |
| 62 | + "Y", |
| 63 | + "C", |
| 64 | + "K", |
| 65 | + "L", |
| 66 | + "B", |
| 67 | + "I_total", |
| 68 | + "r", |
| 69 | + "w", |
| 70 | + "r_p", |
| 71 | + "r_gov", |
| 72 | + "TR", |
| 73 | + "total_tax_revenue", |
| 74 | + "D", |
| 75 | + "BQ", |
| 76 | +) |
| 77 | + |
| 78 | +# "No drift" threshold: aggregate differences within 0.1% are economically |
| 79 | +# indistinguishable from the model's own convergence noise. |
| 80 | +NO_DRIFT_THRESHOLD = 1e-3 |
| 81 | + |
| 82 | + |
| 83 | +def _max_rel_diff(a, b): |
| 84 | + a = np.asarray(a, dtype=float) |
| 85 | + b = np.asarray(b, dtype=float) |
| 86 | + if a.shape != b.shape or a.size == 0: |
| 87 | + return float("nan") |
| 88 | + scale = max(float(np.max(np.abs(a))), 1e-300) |
| 89 | + return float(np.max(np.abs(a - b))) / scale |
| 90 | + |
| 91 | + |
| 92 | +def _diff_dict(d_dense, d_sparse, var_list=None): |
| 93 | + """Return [(var, rel_diff), ...] sorted by rel_diff descending.""" |
| 94 | + keys = ( |
| 95 | + var_list |
| 96 | + if var_list is not None |
| 97 | + else sorted(set(d_dense) & set(d_sparse)) |
| 98 | + ) |
| 99 | + out = [] |
| 100 | + for var in keys: |
| 101 | + if var in d_dense and var in d_sparse: |
| 102 | + try: |
| 103 | + rel = _max_rel_diff(d_dense[var], d_sparse[var]) |
| 104 | + except (TypeError, ValueError): |
| 105 | + continue |
| 106 | + if rel == rel: |
| 107 | + out.append((var, rel)) |
| 108 | + out.sort(key=lambda x: -x[1]) |
| 109 | + return out |
| 110 | + |
| 111 | + |
| 112 | +def _load_country_defaults(pkg): |
| 113 | + """Load <pkg>_default_parameters.json, with a 2-D shim for older country |
| 114 | + calibrations whose replacement_rate_adjust is still 1-D.""" |
| 115 | + with files(pkg).joinpath(f"{pkg}_default_parameters.json").open("r") as f: |
| 116 | + defaults = json.load(f) |
| 117 | + rra = defaults.get("replacement_rate_adjust") |
| 118 | + if isinstance(rra, list) and rra and not isinstance(rra[0], list): |
| 119 | + defaults["replacement_rate_adjust"] = [rra] |
| 120 | + return defaults |
| 121 | + |
| 122 | + |
| 123 | +def _apply_country_calibration(pkg, p): |
| 124 | + """Try the country's offline Calibration; quietly skip on error.""" |
| 125 | + try: |
| 126 | + Cal = importlib.import_module(pkg + ".calibrate").Calibration |
| 127 | + try: |
| 128 | + c = Cal(p, update_from_api=False) |
| 129 | + except TypeError: |
| 130 | + c = Cal(p) |
| 131 | + p.update_specifications(c.get_dict()) |
| 132 | + except Exception as e: |
| 133 | + print(f" (calibration skipped: {type(e).__name__}: {str(e)[:80]})") |
| 134 | + |
| 135 | + |
| 136 | +def _run_one(label, country_pkg, out_dir, num_workers, client, sparse_jac): |
| 137 | + p = Specifications( |
| 138 | + baseline=True, |
| 139 | + num_workers=num_workers, |
| 140 | + baseline_dir=out_dir, |
| 141 | + output_base=out_dir, |
| 142 | + ) |
| 143 | + if country_pkg is None: |
| 144 | + p.update_specifications(OGCORE_SPEC) |
| 145 | + else: |
| 146 | + p.update_specifications(_load_country_defaults(country_pkg)) |
| 147 | + _apply_country_calibration(country_pkg, p) |
| 148 | + p.update_specifications({"use_sparse_FOC_jac": bool(sparse_jac)}) |
| 149 | + print(f"\n[{label}] use_sparse_FOC_jac = {p.use_sparse_FOC_jac}") |
| 150 | + start = time.time() |
| 151 | + runner(p, time_path=True, client=client) |
| 152 | + wall = time.time() - start |
| 153 | + print(f"[{label}] wall time = {wall:.2f} s") |
| 154 | + ss = safe_read_pickle(os.path.join(out_dir, "SS", "SS_vars.pkl")) |
| 155 | + tpi = safe_read_pickle(os.path.join(out_dir, "TPI", "TPI_vars.pkl")) |
| 156 | + return wall, ss, tpi |
| 157 | + |
| 158 | + |
| 159 | +def main(country_pkg=None): |
| 160 | + num_workers = min(multiprocessing.cpu_count(), 7) |
| 161 | + label = country_pkg if country_pkg else "ogcore (standard example)" |
| 162 | + print(f"Workers: {num_workers} | model: {label}") |
| 163 | + client = Client(n_workers=num_workers, threads_per_worker=1) |
| 164 | + |
| 165 | + # Outputs land in the current working directory so they're easy to find |
| 166 | + # regardless of where this script file lives. |
| 167 | + root = os.path.join( |
| 168 | + os.getcwd(), |
| 169 | + "sparse-FOC-jac-compare", |
| 170 | + country_pkg if country_pkg else "ogcore", |
| 171 | + ) |
| 172 | + dense_dir = os.path.join(root, "dense") |
| 173 | + sparse_dir = os.path.join(root, "sparse") |
| 174 | + |
| 175 | + t_dense, ss_dense, tpi_dense = _run_one( |
| 176 | + f"{label} DENSE (default)", |
| 177 | + country_pkg, |
| 178 | + dense_dir, |
| 179 | + num_workers, |
| 180 | + client, |
| 181 | + False, |
| 182 | + ) |
| 183 | + t_sparse, ss_sparse, tpi_sparse = _run_one( |
| 184 | + f"{label} SPARSE (use_sparse_FOC_jac=True)", |
| 185 | + country_pkg, |
| 186 | + sparse_dir, |
| 187 | + num_workers, |
| 188 | + client, |
| 189 | + True, |
| 190 | + ) |
| 191 | + |
| 192 | + tpi_diffs = _diff_dict(tpi_dense, tpi_sparse, KEY_AGGREGATES) |
| 193 | + ss_diffs = _diff_dict(ss_dense, ss_sparse) |
| 194 | + |
| 195 | + tpi_worst_var, tpi_worst = tpi_diffs[0] if tpi_diffs else ("n/a", 0.0) |
| 196 | + ss_worst_var, ss_worst = ss_diffs[0] if ss_diffs else ("n/a", 0.0) |
| 197 | + worst = max(tpi_worst, ss_worst) |
| 198 | + |
| 199 | + rc_d = float( |
| 200 | + np.max(np.abs(tpi_dense.get("resource_constraint_error", np.zeros(1)))) |
| 201 | + ) |
| 202 | + rc_s = float( |
| 203 | + np.max( |
| 204 | + np.abs(tpi_sparse.get("resource_constraint_error", np.zeros(1))) |
| 205 | + ) |
| 206 | + ) |
| 207 | + |
| 208 | + speedup = t_dense / t_sparse if t_sparse > 0 else float("inf") |
| 209 | + bar = "=" * 64 |
| 210 | + print() |
| 211 | + print(bar) |
| 212 | + print(f" MODEL: {label}") |
| 213 | + print(" SPEED") |
| 214 | + print(f" dense : {t_dense:7.2f} s") |
| 215 | + print(f" sparse : {t_sparse:7.2f} s -> {speedup:.2f}x faster") |
| 216 | + print() |
| 217 | + print(" DRIFT (max relative difference, sparse vs dense)") |
| 218 | + print( |
| 219 | + f" TPI worst: {tpi_worst_var:22s} " |
| 220 | + f"{tpi_worst * 100:9.4f}% ({tpi_worst:.2e})" |
| 221 | + ) |
| 222 | + print( |
| 223 | + f" SS worst: {ss_worst_var:22s} " |
| 224 | + f"{ss_worst * 100:9.4f}% ({ss_worst:.2e})" |
| 225 | + ) |
| 226 | + if tpi_diffs: |
| 227 | + print() |
| 228 | + print(" All TPI aggregates, sorted by drift:") |
| 229 | + for var, rel in tpi_diffs: |
| 230 | + print(f" {var:22s} {rel * 100:9.4f}% ({rel:.2e})") |
| 231 | + print() |
| 232 | + print(" ACCURACY FLOOR (resource-constraint residual)") |
| 233 | + print(f" dense : {rc_d:.2e}") |
| 234 | + print(f" sparse : {rc_s:.2e}") |
| 235 | + print() |
| 236 | + threshold_pct = NO_DRIFT_THRESHOLD * 100 |
| 237 | + if worst <= NO_DRIFT_THRESHOLD: |
| 238 | + print( |
| 239 | + f" RESULT: NO DRIFT " |
| 240 | + f"(worst {worst * 100:.4f}% <= {threshold_pct:g}% threshold)" |
| 241 | + ) |
| 242 | + else: |
| 243 | + print( |
| 244 | + f" RESULT: DRIFT DETECTED " |
| 245 | + f"(worst {worst * 100:.4f}% > {threshold_pct:g}% threshold) " |
| 246 | + f"-- investigate" |
| 247 | + ) |
| 248 | + print(bar) |
| 249 | + |
| 250 | + client.close() |
| 251 | + |
| 252 | + |
| 253 | +if __name__ == "__main__": |
| 254 | + main(sys.argv[1] if len(sys.argv) > 1 else None) |
0 commit comments