Skip to content

Commit b880524

Browse files
hmgaudeckerclaude
andauthored
regime_template: exempt next_<state> names from fixed_param extraction (#342)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 558dfea commit b880524

11 files changed

Lines changed: 448 additions & 95 deletions

File tree

.ai-instructions

AGENTS.md

Lines changed: 144 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,150 @@ initial_conditions = {
294294
- `model.n_periods` - Number of periods in the model (derived from `ages`)
295295
- `model.regime_names_to_ids` - Immutable mapping from regime names to integer indices
296296

297+
## Testing
298+
299+
### Test-Driven Development — always
300+
301+
**Always write the test first, watch it fail, then implement.** No exceptions for new
302+
behavior or bug fixes. Tests are not an afterthought, they are the spec.
303+
304+
The cycle:
305+
306+
1. **Red.** Write a failing test that asserts the desired behavior in user-facing terms.
307+
Run it. Confirm it fails for the *right* reason (the missing behavior — not a typo,
308+
not an import error).
309+
1. **Green.** Write the smallest amount of code that makes the test pass.
310+
1. **Refactor.** Clean up while keeping the test green.
311+
312+
Apply per case:
313+
314+
- **New feature** → red-green-refactor.
315+
- **Bug fix** → reproduce as a failing test before writing the fix. The test then
316+
prevents regression.
317+
- **Refactor (no behavior change)** → existing tests are the spec. Keep them green
318+
before, during, and after. No new test needed if behavior is unchanged; if you find a
319+
behavior gap, fill it with a new test *before* refactoring.
320+
321+
### Test docstrings — describe behavior, not history
322+
323+
Test docstrings state what *should* be true, in user-facing terms. Pretend the reader
324+
has never seen the PR. They should not need to.
325+
326+
```python
327+
# Good — behavior, in plain language
328+
def test_simulate_with_chained_transitions_yields_expected_next_wealth():
329+
"""`next_wealth_t = wealth_t - c_t + 0.1 * next_aime_t` holds in simulation."""
330+
331+
332+
# Bad — rehearses the prior bug or implementation history
333+
def test_solve_resolves_chain_via_dags():
334+
"""Before the fix, `_resolve_fixed_params` raised
335+
`InvalidParamsError: Missing required parameter: ...` because
336+
`create_regime_params_template` classified ..."""
337+
```
338+
339+
Rule of thumb: **would the docstring still make sense in 9 months without the PR
340+
context?** If not, rewrite it.
341+
342+
### Concrete-value assertions
343+
344+
Assert *what* the result is, not just that it didn't crash.
345+
346+
```python
347+
# Good — analytical value with explicit tolerance
348+
np.testing.assert_allclose(curr["wealth"], expected_next_wealth, atol=1e-6)
349+
350+
# Bad — passes whether the math is right or not
351+
assert not jnp.any(jnp.isnan(V_arr))
352+
assert df["wealth"].notna().all()
353+
```
354+
355+
`not isnan` and `no exception raised` belong in CI smoke tests, not in the unit tests
356+
for the feature itself.
357+
358+
### Mechanics
359+
360+
- Use plain pytest functions, never test classes (`class TestFoo`)
361+
- Use `@pytest.mark.parametrize` for test variations
362+
363+
## Docstring Style
364+
365+
Docstrings and inline comments describe the code's *current* state in user-facing terms.
366+
The 9-month-without-PR-context reader is the audience: a docstring that survives that
367+
test stays useful; one that rehearses the diff or the prior implementation rots
368+
immediately.
369+
370+
This applies to **all** docstrings and comments — source and tests. For tests
371+
specifically, see also "Test docstrings — describe behavior, not history" above.
372+
373+
### Describe state, not history
374+
375+
State what is true now. Don't reference prior designs, removed code, or what was
376+
changed. Words like "earlier", "previously", "now", "formerly", "the old", "before the
377+
fix" are red flags.
378+
379+
```python
380+
# Good — forward-looking constraint
381+
class _DiagnosticRow:
382+
"""Metadata captured during the backward-induction loop.
383+
384+
Holds only Python-scalar metadata — no device-array references —
385+
so every (regime, period) row stays at a few bytes regardless of
386+
grid size.
387+
"""
388+
389+
390+
# Bad — rehearses prior design
391+
class _DiagnosticRow:
392+
"""Metadata captured during the backward-induction loop.
393+
394+
Holds only Python-scalar metadata. The earlier design captured
395+
state_action_space and a closure directly on each row, which
396+
pinned every period's V template in device memory until the
397+
post-loop flush.
398+
"""
399+
```
400+
401+
### No PR numbers, no model-specific magic numbers
402+
403+
PR references (`#334 removed the host stalls`, `the bug was fixed in #42`) rot as the
404+
codebase evolves and provide no useful signal to a reader who isn't already in context.
405+
Magic numbers tied to a specific model size or hardware
406+
(`~2 MB at production grid sizes`, `fits on a 16 GB device`) imply a fixed scale that's
407+
only true on whichever model/box the comment was written against. State the qualitative
408+
dependency instead.
409+
410+
```python
411+
# Good — qualitative dependency
412+
# Frees per-period intermediate buffers (V_arr-shaped, so
413+
# model-dependent) so they don't stack up across the loop.
414+
415+
# Bad — PR reference + magic number
416+
# Frees per-period intermediate buffers (~2 MB each at production
417+
# grid sizes) so we don't re-introduce the host stalls that #334
418+
# removed.
419+
```
420+
421+
### Bulleted lists for enumerated cases
422+
423+
When describing a fixed set of cases (log levels, regime kinds, parameter types,
424+
dispatch strategies), use one bullet per case rather than running prose. Bullets scan;
425+
prose hides cases.
426+
427+
```python
428+
# Good — scannable
429+
# Gate falls out of the public log level:
430+
# - `"off"` ⇒ nothing (skips even the NaN fail-fast)
431+
# - `"warning"` / `"progress"` ⇒ NaN/Inf only
432+
# - `"debug"` ⇒ adds the min/max/mean trio
433+
434+
435+
# Bad — buried in prose
436+
# Gate falls out of the public log level: `"off"` ⇒ nothing,
437+
# `"warning"` / `"progress"` ⇒ NaN/Inf only, `"debug"` ⇒ adds the
438+
# min/max/mean trio. `"off"` skips even the NaN fail-fast.
439+
```
440+
297441
## Development Notes
298442

299443
### JAX Integration
@@ -401,11 +545,6 @@ Code structure should be self-evident from function names and ordering.
401545
display math, and `[text](url)` for links. Never use rST-style ``` `` code `` ```,
402546
`:math:`, `:func:`, or `` `link <url>`_ ``.
403547

404-
### Testing Style
405-
406-
- Use plain pytest functions, never test classes (`class TestFoo`)
407-
- Use `@pytest.mark.parametrize` for test variations
408-
409548
### Plotting
410549

411550
- Always use **plotly** for visualizations, never matplotlib. Use `plotly.graph_objects`

pixi.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/lcm/params/regime_template.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,18 @@
1717
)
1818

1919

20-
def create_regime_params_template(
21-
regime: Regime,
22-
) -> RegimeParamsTemplate:
20+
def create_regime_params_template(regime: Regime) -> RegimeParamsTemplate:
2321
"""Create parameter template from a regime specification.
2422
25-
Discover parameters from function signatures via `dags.tree`. Parameters are
26-
function arguments that are not states, actions, other regime functions, or
27-
special variables (period, age, E_next_V).
23+
Discover parameters from function signatures via `dags.tree`. Parameters
24+
are function arguments that are not states, actions, regime functions,
25+
`next_<state>` outputs, or special variables (`period`, `age`, `E_next_V`).
2826
2927
For `SolveSimulateFunctionPair` entries, the template contains the **union**
3028
of both variants' parameters so the user can provide a single flat params
3129
dict that satisfies both phases.
3230
33-
Grids with runtime-supplied values (IrregSpacedGrid without points,
31+
Grids with runtime-supplied values (`IrregSpacedGrid` without points,
3432
`_ShockGrid` without full shock_params) add entries to the template under
3533
pseudo-function keys matching the state or action name.
3634
@@ -41,8 +39,15 @@ def create_regime_params_template(
4139
The regime parameter template with type annotations as values.
4240
4341
"""
44-
H_variables = {*regime.functions, "period", "age", "E_next_V"}
45-
variables = H_variables | set(regime.actions) | set(regime.states)
42+
variables = {
43+
*set(regime.states),
44+
*set(regime.actions),
45+
*regime.functions,
46+
*(f"next_{name}" for name in regime.states),
47+
"period",
48+
"age",
49+
"E_next_V",
50+
}
4651

4752
function_params: dict[FunctionName, dict[str, str]] = {}
4853

src/lcm/regime_building/next_state.py

Lines changed: 69 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import jax
77
import pandas as pd
88
from dags import concatenate_functions, with_signature
9-
from dags.tree import qname_from_tree_path, tree_path_from_qname
9+
from dags.tree import qname_from_tree_path
1010
from jax import Array
1111

1212
from lcm.grids import Grid
@@ -69,6 +69,17 @@ def get_next_state_function_for_simulation(
6969
) -> NextStateSimulationFunction:
7070
"""Get function that computes the next states during the simulation.
7171
72+
Builds one DAG per target regime using unqualified `next_<state>` keys, mirroring
73+
the per-target structure of {func}`get_next_state_function_for_solution`. This
74+
lets a transition function or auxiliary regime function consume another
75+
transition's `next_<state>` output via plain name resolution within the same
76+
target's DAG. The combined function returns a nested mapping keyed by target
77+
regime name, with each inner dict using unqualified `next_<state>` keys.
78+
79+
Stochastic-transition wrappers expose `key_<target>__next_<state>` and
80+
`weight_<target>__next_<state>` as external arguments so callers can pass a
81+
distinct random key and pre-computed weight per target.
82+
7283
Args:
7384
transitions: Nested mapping of target regime names to transition functions.
7485
functions: Immutable mapping of auxiliary functions of a regime.
@@ -78,26 +89,31 @@ def get_next_state_function_for_simulation(
7889
7990
Returns:
8091
Function that computes the next states. Depends on states and actions of the
81-
current period, and the regime parameters ("params"). If target is "simulate",
82-
the function also depends on the dictionary of random keys ("keys"), which
83-
corresponds to the names of stochastic next functions.
92+
current period, and the regime parameters ("params"). The function also
93+
depends on the dictionary of random keys ("keys") for stochastic transitions.
94+
Returns `{target_regime_name: {next_<state>: array}}`.
8495
8596
"""
86-
flat_transitions = flatten_regime_namespace(transitions)
87-
88-
# For the simulation target, we need to extend the functions dictionary with
89-
# stochastic next states functions and their weights.
90-
extended_transitions = _extend_transitions_for_simulation(
91-
all_grids=all_grids,
92-
flat_transitions=flat_transitions,
93-
variable_info=variable_info,
94-
stochastic_transition_names=stochastic_transition_names,
95-
)
96-
functions_to_concatenate = extended_transitions | dict(functions)
97+
per_target_funcs: dict[RegimeName, Callable[..., dict[str, Array]]] = {}
98+
for target, target_transitions in transitions.items():
99+
extended = _extend_target_transitions_for_simulation(
100+
target=target,
101+
target_transitions=target_transitions,
102+
all_grids=all_grids,
103+
variable_info=variable_info,
104+
stochastic_transition_names=stochastic_transition_names,
105+
)
106+
per_target_funcs[target] = concatenate_functions(
107+
functions=dict(extended) | dict(functions),
108+
targets=list(extended.keys()),
109+
return_type="dict",
110+
enforce_signature=False,
111+
set_annotations=True,
112+
)
97113

98114
return concatenate_functions(
99-
functions=functions_to_concatenate,
100-
targets=list(flat_transitions.keys()),
115+
functions=per_target_funcs,
116+
targets=list(per_target_funcs.keys()),
101117
return_type="dict",
102118
enforce_signature=False,
103119
set_annotations=True,
@@ -137,64 +153,59 @@ def get_next_stochastic_weights_function(
137153
)
138154

139155

140-
def _extend_transitions_for_simulation(
156+
def _extend_target_transitions_for_simulation(
141157
*,
158+
target: RegimeName,
159+
target_transitions: MappingProxyType[TransitionFunctionName, Callable[..., Array]],
142160
all_grids: MappingProxyType[RegimeName, MappingProxyType[StateOrActionName, Grid]],
143-
flat_transitions: FunctionsMapping,
144161
variable_info: pd.DataFrame,
145162
stochastic_transition_names: frozenset[TransitionFunctionName],
146163
) -> dict[TransitionFunctionName, Callable[..., Array]]:
147-
"""Extend the functions dictionary for the simulation target.
164+
"""Replace stochastic transitions for one target with realisation wrappers.
165+
166+
Deterministic transitions are passed through unchanged. Stochastic transitions
167+
are replaced by wrappers that draw a realisation from a precomputed weight
168+
vector and a random key. The wrapper's external argument names use
169+
target-qualified form (`key_<target>__<next_state>`,
170+
`weight_<target>__<next_state>`) so multi-target callers can supply distinct
171+
random keys per target. The dict key keeps the unqualified `next_<state>` so
172+
other transitions or regime functions in the same target's DAG can resolve
173+
it by name.
148174
149175
Args:
176+
target: Target regime name.
177+
target_transitions: Mapping of unqualified `next_<state>` transition names
178+
to functions, restricted to one target regime.
150179
all_grids: Immutable mapping of regime names to Grid spec objects.
151-
flat_transitions: Flattened mapping of transition names to functions.
152180
variable_info: Variable info of the current regime.
153181
stochastic_transition_names: Frozenset of stochastic transition function names.
154182
155183
Returns:
156-
Extended functions dictionary.
184+
Extended transitions dictionary keyed by unqualified `next_<state>` names.
157185
158186
"""
159187
shock_names: set[ShockName] = set(variable_info.query("is_shock").index.to_list())
160188
flat_grids = flatten_regime_namespace(all_grids)
161-
discrete_stochastic_targets = [
162-
func_name
163-
for func_name in flat_transitions
164-
if tree_path_from_qname(func_name)[-1] in stochastic_transition_names
165-
and tree_path_from_qname(func_name)[-1].removeprefix("next_") not in shock_names
166-
]
167-
continuous_stochastic_targets = [
168-
func_name
169-
for func_name in flat_transitions
170-
if tree_path_from_qname(func_name)[-1] in stochastic_transition_names
171-
and tree_path_from_qname(func_name)[-1].removeprefix("next_") in shock_names
172-
]
173-
# Handle stochastic next states functions
174-
# ----------------------------------------------------------------------------------
175-
# We generate stochastic next states functions that simulate the next state given
176-
# a random key (think of a seed) and the weights corresponding to the labels of the
177-
# stochastic variable. The weights are computed using the stochastic weight
178-
# functions, which we add the to functions dict. `dags.concatenate_functions` then
179-
# generates a function that computes the weights and simulates the next state in
180-
# one go.
181-
# ----------------------------------------------------------------------------------
182-
discrete_stochastic_next = {
183-
name: _create_discrete_stochastic_next_func(
184-
name=name, labels=flat_grids[name.replace("next_", "")].to_jax()
185-
)
186-
for name in discrete_stochastic_targets
187-
}
188-
continuous_stochastic_next = {
189-
name: _create_continuous_stochastic_next_func(name=name, flat_grids=flat_grids)
190-
for name in continuous_stochastic_targets
191-
}
192-
193-
# Overwrite regime transitions with generated stochastic next states functions
194-
# ----------------------------------------------------------------------------------
195-
return (
196-
dict(flat_transitions) | discrete_stochastic_next | continuous_stochastic_next
189+
extended: dict[TransitionFunctionName, Callable[..., Array]] = dict(
190+
target_transitions
197191
)
192+
for next_state_name in target_transitions:
193+
if next_state_name not in stochastic_transition_names:
194+
continue
195+
qname = qname_from_tree_path((target, next_state_name))
196+
raw_state_name = next_state_name.removeprefix("next_")
197+
if raw_state_name in shock_names:
198+
extended[next_state_name] = _create_continuous_stochastic_next_func(
199+
name=qname, flat_grids=flat_grids
200+
)
201+
else:
202+
extended[next_state_name] = _create_discrete_stochastic_next_func(
203+
name=qname,
204+
labels=flat_grids[
205+
qname_from_tree_path((target, raw_state_name))
206+
].to_jax(),
207+
)
208+
return extended
198209

199210

200211
def _create_discrete_stochastic_next_func(

0 commit comments

Comments
 (0)