Skip to content

Commit e639f18

Browse files
hmgaudeckerclaude
andcommitted
IrregSpacedGrid: extra_param_names for user-side runtime scalars (#348)
`broadcast_to_template` rejects any `fixed_params` key that no DAG function lists in its signature. That assumption breaks for runtime-supplied grids whose points are computed by user-side code from per-iteration scalars (e.g., a grid upper bound) — pylcm itself never reads those scalars, but they still need to flow through the params machinery instead of being funneled through `simulate(params=...)` where they compete with estimation parameters. This adds an `extra_param_names: tuple[str, ...]` keyword to `IrregSpacedGrid`. When `pass_points_at_runtime=True`, each name in the tuple surfaces as a `ScalarFloat` slot in the action/state's pseudo-function template entry alongside `points`. The user's injection code reads the resolved values from `model.fixed_params` (or per-iteration params) before computing the points; pylcm just threads them through `broadcast_to_template` unchanged. Rejected on fixed-points grids — a baked-in grid has no user-side runtime computation to feed. Closes #348. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 3f1d1e4 commit e639f18

3 files changed

Lines changed: 101 additions & 2 deletions

File tree

src/lcm/grids/continuous.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,10 +227,21 @@ class IrregSpacedGrid(ContinuousGrid):
227227
omitted and only `n_points` is given, the points must be supplied at
228228
runtime via the params.
229229
230+
`extra_param_names` declares additional scalar parameters consumed by
231+
user-side code that *constructs* the runtime points (e.g., a grid
232+
upper bound that changes across optimizer iterations). These names
233+
enter the params template alongside `points` and pass through
234+
`broadcast_to_template` even though no DAG function references them
235+
— the user's injection code reads them from the resolved params /
236+
`model.fixed_params` to derive the points before calling
237+
`solve` / `simulate`.
238+
230239
Example:
231240
--------
232241
Fixed grid: `IrregSpacedGrid(points=[-1.73, -0.58, 0.58, 1.73])` Grid that is only
233242
completed at runtime via params: `IrregSpacedGrid(n_points=4)`
243+
Grid that pairs runtime `points` with an extra scalar bound:
244+
`IrregSpacedGrid(n_points=4, extra_param_names=("max_consumption",))`
234245
235246
"""
236247

@@ -240,12 +251,23 @@ class IrregSpacedGrid(ContinuousGrid):
240251
n_points: int
241252
"""Number of points. Derived from `len(points)` when points are given."""
242253

254+
extra_param_names: tuple[str, ...]
255+
"""Names of additional scalar params surfaced in the template.
256+
257+
Only meaningful when points are supplied at runtime
258+
(`pass_points_at_runtime=True`); pylcm itself never reads these
259+
values — they're carried through the params machinery so user-side
260+
injection code can pick them up without fighting the `Unknown keys`
261+
validator.
262+
"""
263+
243264
def __init__(
244265
self,
245266
*,
246267
points: Sequence[float] | Float1D | None = None,
247268
n_points: int | None = None,
248269
batch_size: int = 0,
270+
extra_param_names: tuple[str, ...] = (),
249271
) -> None:
250272
if points is not None:
251273
_validate_irreg_spaced_grid(points)
@@ -269,9 +291,16 @@ def __init__(
269291
)
270292
else:
271293
stored_points = None
294+
if extra_param_names and stored_points is not None:
295+
raise GridInitializationError(
296+
"`extra_param_names` is only valid when points are supplied at "
297+
"runtime (i.e. `points=None`); a fixed-points grid has no "
298+
"user-side params to thread through."
299+
)
272300
object.__setattr__(self, "points", stored_points)
273301
object.__setattr__(self, "n_points", n_points)
274302
object.__setattr__(self, "batch_size", batch_size)
303+
object.__setattr__(self, "extra_param_names", tuple(extra_param_names))
275304

276305
@property
277306
def pass_points_at_runtime(self) -> bool:

src/lcm/params/regime_template.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def _add_runtime_grid_params(
9292
_fail_if_runtime_grid_shadows_function(
9393
function_params=function_params, name=state_name, kind="state"
9494
)
95-
function_params[state_name] = {"points": "Float1D"}
95+
function_params[state_name] = _irreg_grid_template_entry(grid)
9696
elif isinstance(grid, _ShockGrid) and grid.params_to_pass_at_runtime:
9797
_fail_if_runtime_grid_shadows_function(
9898
function_params=function_params,
@@ -108,7 +108,22 @@ def _add_runtime_grid_params(
108108
_fail_if_runtime_grid_shadows_function(
109109
function_params=function_params, name=action_name, kind="action"
110110
)
111-
function_params[action_name] = {"points": "Float1D"}
111+
function_params[action_name] = _irreg_grid_template_entry(grid)
112+
113+
114+
def _irreg_grid_template_entry(grid: IrregSpacedGrid) -> dict[str, str]:
115+
"""Template slots for a runtime-points `IrregSpacedGrid`.
116+
117+
Always exposes `points: Float1D`. Each entry in
118+
`grid.extra_param_names` adds a `ScalarFloat` slot so user-side
119+
injection code can thread its scalar bounds through `fixed_params`
120+
or per-iteration params without tripping `broadcast_to_template`'s
121+
unknown-keys check.
122+
"""
123+
entry: dict[str, str] = {"points": "Float1D"}
124+
for name in grid.extra_param_names:
125+
entry[name] = "ScalarFloat"
126+
return entry
112127

113128

114129
def _fail_if_runtime_grid_shadows_function(

tests/test_runtime_params.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,58 @@ def test_simulate_with_runtime_action_grid_no_nan() -> None:
313313
)
314314
df = result.to_dataframe()
315315
assert not df["value"].isna().any()
316+
317+
318+
# --- extra_param_names: scalar params consumed by user-side injection code ---
319+
320+
321+
def test_extra_param_names_added_to_template():
322+
"""`extra_param_names` show up in the action's template alongside `points`."""
323+
model = _make_action_grid_model(
324+
consumption_grid=IrregSpacedGrid(
325+
n_points=5,
326+
extra_param_names=("max_consumption",),
327+
),
328+
)
329+
alive_template = model._params_template["alive"]
330+
assert alive_template["consumption"] == {
331+
"points": "Float1D",
332+
"max_consumption": "ScalarFloat",
333+
}
334+
335+
336+
def test_extra_param_names_accepted_via_fixed_params():
337+
"""Model-level `fixed_params` with extra-grid-param keys broadcast cleanly."""
338+
grid = IrregSpacedGrid(n_points=5, extra_param_names=("max_consumption",))
339+
model_fixed = _make_action_grid_model(consumption_grid=grid)
340+
# Build via fresh Model to inject `fixed_params`.
341+
alive = model_fixed.regimes["alive"]
342+
dead = model_fixed.regimes["dead"]
343+
model = Model(
344+
regimes={"alive": alive, "dead": dead},
345+
ages=AgeGrid(start=0, stop=2, step="Y"),
346+
regime_id_class=RegimeId,
347+
fixed_params={"max_consumption": 5.0},
348+
)
349+
params = {
350+
"discount_factor": 0.95,
351+
"interest_rate": 0.05,
352+
"alive": {"consumption": {"points": jnp.linspace(0.1, 5.0, 5)}},
353+
}
354+
period_to_regime_to_V_arr = model.solve(params=params, log_level="off")
355+
assert len(period_to_regime_to_V_arr) > 0
356+
357+
358+
def test_extra_param_names_rejected_on_fixed_points_grid():
359+
"""`extra_param_names` is meaningless when points are baked in at construction."""
360+
with pytest.raises(Exception, match="only valid when points are supplied"):
361+
IrregSpacedGrid(points=[1.0, 2.0, 3.0], extra_param_names=("foo",))
362+
363+
364+
def test_extra_param_names_empty_by_default():
365+
"""No `extra_param_names` keeps the template's grid entry to just `points`."""
366+
model = _make_action_grid_model(
367+
consumption_grid=IrregSpacedGrid(n_points=5),
368+
)
369+
alive_template = model._params_template["alive"]
370+
assert alive_template["consumption"] == {"points": "Float1D"}

0 commit comments

Comments
 (0)