Skip to content

Commit 558dfea

Browse files
hmgaudeckerclaudepre-commit-ci[bot]
authored
Runtime-supplied points on action grids; tighten grid + regime validators (#338)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8f2a4cf commit 558dfea

22 files changed

Lines changed: 1372 additions & 87 deletions

.github/workflows/main.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
- uses: actions/checkout@v6
2929
- uses: prefix-dev/setup-pixi@v0.9.5
3030
with:
31-
pixi-version: v0.66.0
31+
pixi-version: v0.67.2
3232
cache: true
3333
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
3434
environments: tests-cpu
@@ -59,7 +59,7 @@ jobs:
5959
- uses: actions/checkout@v6
6060
- uses: prefix-dev/setup-pixi@v0.9.5
6161
with:
62-
pixi-version: v0.66.0
62+
pixi-version: v0.67.2
6363
cache: true
6464
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
6565
environments: type-checking
@@ -82,7 +82,7 @@ jobs:
8282
- uses: actions/checkout@v6
8383
- uses: prefix-dev/setup-pixi@v0.9.5
8484
with:
85-
pixi-version: v0.66.0
85+
pixi-version: v0.67.2
8686
cache: true
8787
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
8888
environments: tests-cuda12
@@ -101,7 +101,7 @@ jobs:
101101
- uses: actions/checkout@v6
102102
- uses: prefix-dev/setup-pixi@v0.9.5
103103
with:
104-
pixi-version: v0.66.0
104+
pixi-version: v0.67.2
105105
cache: true
106106
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
107107
environments: tests-cuda12
@@ -116,7 +116,7 @@ jobs:
116116
# - uses: actions/checkout@v6
117117
# - uses: prefix-dev/setup-pixi@v0.9.5
118118
# with:
119-
# pixi-version: v0.66.0
119+
# pixi-version: v0.67.2
120120
# cache: true
121121
# cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
122122
# environments: tests-cpu

.gitignore

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,6 @@ docs/_build/
3131
.pixi/
3232
node_modules/
3333

34-
# pytask
35-
.pytask.sqlite3
36-
3734
# Python
3835
__pycache__/
3936
*.py[cod]

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@ repos:
5050
hooks:
5151
- id: yamllint
5252
- repo: https://github.com/python-jsonschema/check-jsonschema
53-
rev: 0.37.1
53+
rev: 0.37.2
5454
hooks:
5555
- id: check-github-workflows
5656
- repo: https://github.com/astral-sh/ruff-pre-commit
57-
rev: v0.15.11
57+
rev: v0.15.12
5858
hooks:
5959
- id: ruff-check
6060
args:
@@ -86,7 +86,7 @@ repos:
8686
args:
8787
- --wrap
8888
- '88'
89-
files: (AGENTS\.md|CLAUDE\.md|README\.md|modules/.*\.md|profiles/.*\.md)
89+
files: (AGENTS\.md|CLAUDE\.md|README\.md)
9090
- id: mdformat
9191
additional_dependencies:
9292
- mdformat-myst

AGENTS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
@.ai-instructions/profiles/tier-a.md @.ai-instructions/modules/jax.md
22
@.ai-instructions/modules/pandas.md @.ai-instructions/modules/plotting.md
3+
@.ai-instructions/modules/dags.md
34

45
# PyLCM
56

benchmarks/bench_aca_baseline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _build() -> tuple[object, object, object]:
5555
)
5656

5757
model = create_benchmark_model()
58-
_, model_params = get_benchmark_params()
58+
_, model_params = get_benchmark_params(model=model)
5959
initial_conditions = get_benchmark_initial_conditions(
6060
model=model, n_subjects=_N_SUBJECTS, seed=0
6161
)

pixi.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" }
9898
tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" }
9999
type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" }
100100
[tool.pixi.feature.benchmarks.pypi-dependencies]
101-
aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "adc8a19328608781a5cb2a65ab2d93d580163aae" }
101+
aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "134286108b7445f3e17e8824bcdd1739a98b6089" }
102102
[tool.pixi.feature.cuda12]
103103
platforms = [ "linux-64" ]
104104
system-requirements = { cuda = "12" }

src/lcm/grids/continuous.py

Lines changed: 62 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,22 @@ def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array:
116116
class LogSpacedGrid(UniformContinuousGrid):
117117
"""A logarithmically spaced grid of continuous values.
118118
119+
Requires `start > 0`.
120+
119121
Example:
120122
--------
121123
Let `start = 1`, `stop = 100`, and `n_points = 3`. The grid is `[1, 10, 100]`.
122124
123125
"""
124126

127+
def __post_init__(self) -> None:
128+
_validate_continuous_grid(
129+
start=self.start,
130+
stop=self.stop,
131+
n_points=self.n_points,
132+
requires_positive_start=True,
133+
)
134+
125135
def to_jax(self) -> Float1D:
126136
"""Convert the grid to a Jax array."""
127137
return grid_coordinates.logspace(
@@ -188,9 +198,23 @@ def pass_points_at_runtime(self) -> bool:
188198
return self.points is None
189199

190200
def to_jax(self) -> Float1D:
191-
"""Convert the grid to a Jax array."""
201+
"""Convert the grid to a Jax array.
202+
203+
Raises `GridInitializationError` for runtime-supplied grids
204+
(`pass_points_at_runtime=True`). To get the substituted points,
205+
call `internal_regime.state_action_space(regime_params=...)` and
206+
read from `.states[name]` or `.continuous_actions[name]`.
207+
"""
192208
if self.points is None:
193-
return jnp.full(self.n_points, jnp.nan)
209+
raise GridInitializationError(
210+
f"IrregSpacedGrid declared with n_points={self.n_points} and "
211+
f"no points; values are supplied at runtime via "
212+
f"params['<regime>']['<grid_name>']['points']. To get the "
213+
f"substituted points, call "
214+
f"`internal_regime.state_action_space(regime_params=...)` and "
215+
f"read from `.states[name]` or `.continuous_actions[name]`. "
216+
f"Use `.n_points` if only the shape is needed."
217+
)
194218
return jnp.asarray(self.points)
195219

196220
@overload
@@ -213,13 +237,16 @@ def _validate_continuous_grid(
213237
start: float,
214238
stop: float,
215239
n_points: int,
240+
requires_positive_start: bool = False,
216241
) -> None:
217242
"""Validate the continuous grid parameters.
218243
219244
Args:
220245
start: The start value of the grid.
221246
stop: The stop value of the grid.
222247
n_points: The number of points in the grid.
248+
requires_positive_start: If True, also require `start > 0` (used by
249+
log-spaced grids since `log(x)` is undefined for `x <= 0`).
223250
224251
Raises:
225252
GridInitializationError: If the grid parameters are invalid.
@@ -235,6 +262,15 @@ def _validate_continuous_grid(
235262
if not valid_stop_type:
236263
error_messages.append("stop must be a scalar int or float value")
237264

265+
# Reject NaN/inf early — `start >= stop` returns False for NaN, so an
266+
# un-finite start would otherwise pass silently and produce a broken grid.
267+
if valid_start_type and not jnp.isfinite(start):
268+
error_messages.append(f"start must be finite, got {start}")
269+
valid_start_type = False
270+
if valid_stop_type and not jnp.isfinite(stop):
271+
error_messages.append(f"stop must be finite, got {stop}")
272+
valid_stop_type = False
273+
238274
if not isinstance(n_points, int) or n_points < 1:
239275
error_messages.append(
240276
f"n_points must be an int greater than 0 but is {n_points}",
@@ -243,6 +279,12 @@ def _validate_continuous_grid(
243279
if valid_start_type and valid_stop_type and start >= stop:
244280
error_messages.append("start must be less than stop")
245281

282+
if valid_start_type and requires_positive_start and start <= 0:
283+
error_messages.append(
284+
f"start must be > 0 for a log-spaced grid (got {start}); "
285+
f"`log(x)` is undefined for `x <= 0`."
286+
)
287+
246288
if error_messages:
247289
msg = format_messages(error_messages)
248290
raise GridInitializationError(msg)
@@ -275,15 +317,24 @@ def _validate_irreg_spaced_grid(points: Sequence[float] | Float1D) -> None:
275317
f"Non-numeric elements found at indices: {non_numeric}"
276318
)
277319
else:
278-
# Check that points are in ascending order
279-
for i in range(len(points) - 1):
280-
if points[i] >= points[i + 1]:
281-
error_messages.append(
282-
"Points must be in strictly ascending order. "
283-
f"Found points[{i}]={points[i]} >= "
284-
f"points[{i + 1}]={points[i + 1]}"
285-
)
286-
break
320+
# Reject NaN/inf — comparisons with NaN are False, so the
321+
# ascending-order check below would silently let them through.
322+
non_finite = [(i, p) for i, p in enumerate(points) if not jnp.isfinite(p)]
323+
if non_finite:
324+
error_messages.append(
325+
f"All elements of points must be finite. "
326+
f"Non-finite elements found at: {non_finite}"
327+
)
328+
else:
329+
# Check that points are in strictly ascending order
330+
for i in range(len(points) - 1):
331+
if points[i] >= points[i + 1]:
332+
error_messages.append(
333+
"Points must be in strictly ascending order. "
334+
f"Found points[{i}]={points[i]} >= "
335+
f"points[{i + 1}]={points[i + 1]}"
336+
)
337+
break
287338

288339
if error_messages:
289340
msg = format_messages(error_messages)

src/lcm/interfaces.py

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -242,12 +242,12 @@ class InternalRegime:
242242
"""Flat resolved fixed params for this regime, used by to_dataframe targets."""
243243

244244
def state_action_space(self, regime_params: FlatRegimeParams) -> StateActionSpace:
245-
"""Return the state-action space with runtime state grids filled in.
245+
"""Return the state-action space with runtime grids filled in.
246246
247-
For IrregSpacedGrid with runtime-supplied points, the grid points come from
248-
params as `{state_name}__points`. For _ShockGrid with runtime-supplied params,
249-
the grid points are computed from shock params in the params dict or
250-
resolved_fixed_params.
247+
For IrregSpacedGrid (state or continuous action) with runtime-supplied
248+
points, the grid points come from params as `{name}__points`. For
249+
`_ShockGrid` with runtime-supplied params, the grid points are computed
250+
from shock params in the params dict or `resolved_fixed_params`.
251251
252252
Args:
253253
regime_params: Flat regime parameters supplied at runtime.
@@ -257,35 +257,68 @@ def state_action_space(self, regime_params: FlatRegimeParams) -> StateActionSpac
257257
258258
"""
259259
all_params = {**self.resolved_fixed_params, **regime_params}
260-
replacements: dict[str, ContinuousState | DiscreteState] = {}
261-
for state_name, spec in self.grids.items():
262-
if state_name not in self._base_state_action_space.states:
260+
state_replacements: dict[str, ContinuousState | DiscreteState] = {}
261+
action_replacements: dict[str, ContinuousAction] = {}
262+
for name, spec in self.grids.items():
263+
in_states = name in self._base_state_action_space.states
264+
in_continuous_actions = (
265+
name in self._base_state_action_space.continuous_actions
266+
)
267+
if not (in_states or in_continuous_actions):
263268
continue
264269
if isinstance(spec, IrregSpacedGrid) and spec.pass_points_at_runtime:
265-
points_key = f"{state_name}__points"
270+
points_key = f"{name}__points"
266271
if points_key not in all_params:
267272
continue
268-
replacements[state_name] = cast(
269-
"ContinuousState", all_params[points_key]
270-
)
271-
elif isinstance(spec, _ShockGrid) and spec.params_to_pass_at_runtime:
273+
if in_states:
274+
state_replacements[name] = cast(
275+
"ContinuousState", all_params[points_key]
276+
)
277+
else:
278+
action_replacements[name] = cast(
279+
"ContinuousAction", all_params[points_key]
280+
)
281+
# `_ShockGrid` is state-only by construction (intrinsic
282+
# transitions, forbidden as actions per AGENTS.md). The
283+
# `in_states` gate makes that invariant explicit — a
284+
# `_ShockGrid` reaching the action branch would be a model
285+
# bug, not something this method should silently substitute.
286+
elif (
287+
in_states
288+
and isinstance(spec, _ShockGrid)
289+
and spec.params_to_pass_at_runtime
290+
):
272291
all_present = all(
273-
f"{state_name}__{p}" in all_params
274-
for p in spec.params_to_pass_at_runtime
292+
f"{name}__{p}" in all_params for p in spec.params_to_pass_at_runtime
275293
)
276294
if not all_present:
277295
continue
278296
shock_kw: dict[str, float] = dict(spec.params)
279297
for p in spec.params_to_pass_at_runtime:
280-
shock_kw[p] = cast("float", all_params[f"{state_name}__{p}"])
281-
replacements[state_name] = spec.compute_gridpoints(**shock_kw)
298+
shock_kw[p] = cast("float", all_params[f"{name}__{p}"])
299+
state_replacements[name] = spec.compute_gridpoints(**shock_kw)
282300

283-
if not replacements:
301+
if not state_replacements and not action_replacements:
284302
return self._base_state_action_space
285303

286-
new_states = dict(self._base_state_action_space.states) | replacements
304+
new_states = (
305+
MappingProxyType(
306+
dict(self._base_state_action_space.states) | state_replacements
307+
)
308+
if state_replacements
309+
else None
310+
)
311+
new_continuous_actions = (
312+
MappingProxyType(
313+
dict(self._base_state_action_space.continuous_actions)
314+
| action_replacements
315+
)
316+
if action_replacements
317+
else None
318+
)
287319
return self._base_state_action_space.replace(
288-
states=MappingProxyType(new_states)
320+
states=new_states,
321+
continuous_actions=new_continuous_actions,
289322
)
290323

291324

0 commit comments

Comments
 (0)