Commit c1aa68d

Merge branch 'distributed' of https://github.com/OpenSourceEconomics/pylcm into distributed
2 parents e3bd7e4 + 1b2baa3 commit c1aa68d

50 files changed

Lines changed: 3411 additions & 579 deletions

.ai-instructions

AGENTS.md

Lines changed: 144 additions & 5 deletions
@@ -294,6 +294,150 @@ initial_conditions = {
 - `model.n_periods` - Number of periods in the model (derived from `ages`)
 - `model.regime_names_to_ids` - Immutable mapping from regime names to integer indices
 
+## Testing
+
+### Test-Driven Development — always
+
+**Always write the test first, watch it fail, then implement.** No exceptions for new
+behavior or bug fixes. Tests are not an afterthought, they are the spec.
+
+The cycle:
+
+1. **Red.** Write a failing test that asserts the desired behavior in user-facing terms.
+   Run it. Confirm it fails for the *right* reason (the missing behavior — not a typo,
+   not an import error).
+1. **Green.** Write the smallest amount of code that makes the test pass.
+1. **Refactor.** Clean up while keeping the test green.
+
+Apply per case:
+
+- **New feature** → red-green-refactor.
+- **Bug fix** → reproduce as a failing test before writing the fix. The test then
+  prevents regression.
+- **Refactor (no behavior change)** → existing tests are the spec. Keep them green
+  before, during, and after. No new test needed if behavior is unchanged; if you find a
+  behavior gap, fill it with a new test *before* refactoring.
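To illustrate the red-green cycle described above, here is a minimal sketch. The helper `cap_consumption` and both tests are hypothetical and do not come from pylcm. In the red phase the tests would be run against a stub (for instance one that returns `consumption` unchanged), so the first test fails because the behavior is missing, not because of a typo or an import error.

```python
# Hypothetical example: neither the helper nor the tests exist in pylcm.


def cap_consumption(consumption: float, wealth: float) -> float:
    """Return consumption, limited to the available wealth."""
    # Green step: the smallest implementation that satisfies the tests below.
    return min(consumption, wealth)


def test_consumption_never_exceeds_wealth():
    """Consumption above available wealth is capped at wealth."""
    assert cap_consumption(consumption=12.0, wealth=10.0) == 10.0


def test_affordable_consumption_is_unchanged():
    """Consumption within available wealth passes through unchanged."""
    assert cap_consumption(consumption=4.0, wealth=10.0) == 4.0
```

Both docstrings state what should be true in user-facing terms, so they would still read correctly after any later refactor of `cap_consumption`.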
+
+### Test docstrings — describe behavior, not history
+
+Test docstrings state what *should* be true, in user-facing terms. Pretend the reader
+has never seen the PR. They should not need to.
+
+```python
+# Good — behavior, in plain language
+def test_simulate_with_chained_transitions_yields_expected_next_wealth():
+    """`next_wealth_t = wealth_t - c_t + 0.1 * next_aime_t` holds in simulation."""
+
+
+# Bad — rehearses the prior bug or implementation history
+def test_solve_resolves_chain_via_dags():
+    """Before the fix, `_resolve_fixed_params` raised
+    `InvalidParamsError: Missing required parameter: ...` because
+    `create_regime_params_template` classified ..."""
+```
+
+Rule of thumb: **would the docstring still make sense in 9 months without the PR
+context?** If not, rewrite it.
+
+### Concrete-value assertions
+
+Assert *what* the result is, not just that it didn't crash.
+
+```python
+# Good — analytical value with explicit tolerance
+np.testing.assert_allclose(curr["wealth"], expected_next_wealth, atol=1e-6)
+
+# Bad — passes whether the math is right or not
+assert not jnp.any(jnp.isnan(V_arr))
+assert df["wealth"].notna().all()
+```
+
+`not isnan` and `no exception raised` belong in CI smoke tests, not in the unit tests
+for the feature itself.
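The difference between the two kinds of assertion can be sketched with the standard library alone. The function `next_wealth` below is a hypothetical stand-in that borrows the law of motion from the docstring example above; it is not pylcm code. A sign error in the last term slips past the `isnan` check but not past the hand-computed value.

```python
import math


def next_wealth(wealth: float, consumption: float, next_aime: float) -> float:
    """Hypothetical law of motion: wealth - consumption + 0.1 * next_aime."""
    return wealth - consumption + 0.1 * next_aime


def buggy_next_wealth(wealth: float, consumption: float, next_aime: float) -> float:
    """Same as above, with a deliberate sign error in the last term."""
    return wealth - consumption - 0.1 * next_aime


def test_next_wealth_matches_hand_computed_value():
    """`next_wealth = wealth - consumption + 0.1 * next_aime`."""
    # Expected value worked out by hand: 100 - 30 + 0.1 * 50 = 75.
    got = next_wealth(wealth=100.0, consumption=30.0, next_aime=50.0)
    assert math.isclose(got, 75.0, abs_tol=1e-6)


# The weak check cannot tell the two implementations apart.
weak_check_passes_for_bug = not math.isnan(buggy_next_wealth(100.0, 30.0, 50.0))
# The concrete-value check can.
concrete_check_passes_for_bug = math.isclose(
    buggy_next_wealth(100.0, 30.0, 50.0), 75.0, abs_tol=1e-6
)
```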
+
+### Mechanics
+
+- Use plain pytest functions, never test classes (`class TestFoo`)
+- Use `@pytest.mark.parametrize` for test variations
+
+## Docstring Style
+
+Docstrings and inline comments describe the code's *current* state in user-facing terms.
+The 9-month-without-PR-context reader is the audience: a docstring that survives that
+test stays useful; one that rehearses the diff or the prior implementation rots
+immediately.
+
+This applies to **all** docstrings and comments — source and tests. For tests
+specifically, see also "Test docstrings — describe behavior, not history" above.
+
+### Describe state, not history
+
+State what is true now. Don't reference prior designs, removed code, or what was
+changed. Words like "earlier", "previously", "now", "formerly", "the old", "before the
+fix" are red flags.
+
+```python
+# Good — forward-looking constraint
+class _DiagnosticRow:
+    """Metadata captured during the backward-induction loop.
+
+    Holds only Python-scalar metadata — no device-array references —
+    so every (regime, period) row stays at a few bytes regardless of
+    grid size.
+    """
+
+
+# Bad — rehearses prior design
+class _DiagnosticRow:
+    """Metadata captured during the backward-induction loop.
+
+    Holds only Python-scalar metadata. The earlier design captured
+    state_action_space and a closure directly on each row, which
+    pinned every period's V template in device memory until the
+    post-loop flush.
+    """
+```
+
+### No PR numbers, no model-specific magic numbers
+
+PR references (`#334 removed the host stalls`, `the bug was fixed in #42`) rot as the
+codebase evolves and provide no useful signal to a reader who isn't already in context.
+Magic numbers tied to a specific model size or hardware
+(`~2 MB at production grid sizes`, `fits on a 16 GB device`) imply a fixed scale that's
+only true on whichever model/box the comment was written against. State the qualitative
+dependency instead.
+
+```python
+# Good — qualitative dependency
+# Frees per-period intermediate buffers (V_arr-shaped, so
+# model-dependent) so they don't stack up across the loop.
+
+# Bad — PR reference + magic number
+# Frees per-period intermediate buffers (~2 MB each at production
+# grid sizes) so we don't re-introduce the host stalls that #334
+# removed.
+```
+
+### Bulleted lists for enumerated cases
+
+When describing a fixed set of cases (log levels, regime kinds, parameter types,
+dispatch strategies), use one bullet per case rather than running prose. Bullets scan;
+prose hides cases.
+
+```python
+# Good — scannable
+# Gate falls out of the public log level:
+# - `"off"` ⇒ nothing (skips even the NaN fail-fast)
+# - `"warning"` / `"progress"` ⇒ NaN/Inf only
+# - `"debug"` ⇒ adds the min/max/mean trio
+
+
+# Bad — buried in prose
+# Gate falls out of the public log level: `"off"` ⇒ nothing,
+# `"warning"` / `"progress"` ⇒ NaN/Inf only, `"debug"` ⇒ adds the
+# min/max/mean trio. `"off"` skips even the NaN fail-fast.
+```
+
 ## Development Notes
 
 ### JAX Integration
@@ -401,11 +545,6 @@ Code structure should be self-evident from function names and ordering.
 display math, and `[text](url)` for links. Never use rST-style ``` `` code `` ```,
 `:math:`, `:func:`, or `` `link <url>`_ ``.
 
-### Testing Style
-
-- Use plain pytest functions, never test classes (`class TestFoo`)
-- Use `@pytest.mark.parametrize` for test variations
-
 ### Plotting
 
 - Always use **plotly** for visualizations, never matplotlib. Use `plotly.graph_objects`

benchmarks/bench_aca_baseline.py

Lines changed: 18 additions & 2 deletions
@@ -47,14 +47,30 @@
 
 
 def _build() -> tuple[object, object, object]:
-    """Build the aca-baseline model, params, and initial conditions."""
+    """Build the aca-baseline model, params, and initial conditions.
+
+    aca_model and lcm imports are deferred to the function body — ASV's
+    forkserver runs `preimport` to discover benchmarks across every
+    `bench_*.py` module before forking workers. Importing JAX at module
+    top loads the multithreaded XLA backend into the forkserver; every
+    subsequent `os.fork()` inherits a corrupted CUDA context and the
+    first device op in the worker aborts with
+    `CUDA_ERROR_NOT_INITIALIZED`. Per-call imports keep JAX out of the
+    forkserver and confine it to the worker process.
+    """
+    from aca_model.agent.preferences import BenchmarkPrefType
     from aca_model.benchmark import (
         create_benchmark_model,
         get_benchmark_initial_conditions,
         get_benchmark_params,
     )
 
-    model = create_benchmark_model()
+    from lcm import DiscreteGrid
+
+    model = create_benchmark_model(
+        n_subjects=_N_SUBJECTS,
+        pref_type_grid=DiscreteGrid(BenchmarkPrefType),
+    )
     _, model_params = get_benchmark_params(model=model)
     initial_conditions = get_benchmark_initial_conditions(
         model=model, n_subjects=_N_SUBJECTS, seed=0
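
The deferred-import pattern that the new docstring describes can be sketched in isolation. This is an illustration only: the standard-library `json` module stands in for a heavy dependency such as JAX, and `build_config` is a hypothetical name, not part of the benchmark module.

```python
import sys


def build_config() -> dict[str, int]:
    """Build a small config; the import runs at call time, not at module import."""
    # `json` is a stand-in for a heavy dependency (JAX in the benchmark).
    # Importing here means a process that only imports this module to
    # discover functions never executes the import statement itself.
    import json

    return json.loads('{"n_subjects": 4}')


config = build_config()
# After the first call the module is cached in `sys.modules`, so repeated
# calls pay only a dictionary lookup for the import statement.
json_is_loaded = "json" in sys.modules
```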

pixi.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 19 additions & 2 deletions
@@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" }
 tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" }
 type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" }
 [tool.pixi.feature.benchmarks.pypi-dependencies]
-aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "134286108b7445f3e17e8824bcdd1739a98b6089" }
+aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "9ac20430f499a8b1cdb056af85bc2a26e850bad2" }
 [tool.pixi.feature.cuda12]
 platforms = [ "linux-64" ]
 system-requirements = { cuda = "12" }
@@ -242,6 +242,15 @@ per-file-ignores."tests/*" = [
   "S301", # Use of pickle
   "SLF001", # Private member access
 ]
+per-file-ignores."tests/test_dtypes.py" = [
+  "ARG001", # Unused function argument (x64_enabled / x64_disabled fixtures)
+]
+per-file-ignores."tests/test_explicit_dtype_filter.py" = [
+  "ARG001", # Unused function argument (x64_disabled fixture)
+]
+per-file-ignores."tests/test_float_dtype_invariants.py" = [
+  "ARG001", # Unused function argument (x64_disabled fixture)
+]
 per-file-ignores."tests/test_next_state.py" = [
   "ARG001", # Unused function argument
   "ARG005", # Unused lambda argument
@@ -294,7 +303,15 @@ ini_options.addopts = [
   "--dist",
   "loadfile",
 ]
-ini_options.filterwarnings = []
+ini_options.filterwarnings = [
+  # JAX emits this UserWarning when user code asks for a dtype wider
+  # than the active x64 setting allows. Under `--precision=32` it
+  # surfaces every stray `jnp.int64` / `jnp.float64` / `dtype=int64`
+  # literal in src/ — the only files that legitimately trigger it are
+  # the dtype-invariant test modules, which opt out via a local
+  # `pytestmark` filter.
+  "error:Explicitly requested dtype.*:UserWarning",
+]
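
The filter string follows pytest's `action:message:category` form. Its effect can be approximated with the standard-library `warnings` module, as in the sketch below. The function names and the warning text are illustrative assumptions, not JAX's exact message or pylcm code.

```python
import warnings
from collections.abc import Callable


def request_wide_dtype() -> None:
    """Emit a warning shaped like the one the filter targets (text is illustrative)."""
    warnings.warn(
        "Explicitly requested dtype int64 is not available", UserWarning, stacklevel=2
    )


def emit_unrelated_warning() -> None:
    """Emit a UserWarning that the filter should leave alone."""
    warnings.warn("some unrelated notice", UserWarning, stacklevel=2)


def is_escalated(emit: Callable[[], None]) -> bool:
    """Return True if `emit` raises under an error filter like the one in the config."""
    with warnings.catch_warnings():
        # Baseline: silence everything so only the targeted filter matters.
        warnings.simplefilter("ignore")
        # Later filters take precedence; `message` is a regex matched at the start.
        warnings.filterwarnings(
            "error", message="Explicitly requested dtype.*", category=UserWarning
        )
        try:
            emit()
        except UserWarning:
            return True
        return False
```

Under these assumptions, `is_escalated(request_wide_dtype)` is true while `is_escalated(emit_unrelated_warning)` is false, which mirrors how the configured filter fails a test only for the dtype warning.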
 ini_options.markers = [
   "illustrative: Tests are designed for illustrative purposes",
   "gpu: Tests that require a GPU (skipped on CPU-only machines)",

src/lcm/ages.py

Lines changed: 8 additions & 6 deletions
@@ -10,7 +10,7 @@
 import jax.numpy as jnp
 
 from lcm.exceptions import GridInitializationError, format_messages
-from lcm.typing import Age, Float1D, Int1D
+from lcm.typing import Float1D, Int1D
 
 STEP_UNITS: MappingProxyType[str, Fraction] = MappingProxyType(
     {
@@ -129,7 +129,7 @@ def exact_step_size(self) -> int | Fraction | None:
         """
         return self._exact_step_size
 
-    def period_to_age(self, period: int) -> Age:
+    def period_to_age(self, period: int) -> int | float:
         """Convert a period index to the corresponding age.
 
         Args:
@@ -151,7 +151,7 @@ def period_to_age(self, period: int) -> Age:
             return int(self._values[period])
         return float(self._values[period])
 
-    def age_to_period(self, age: Age) -> int:
+    def age_to_period(self, age: float) -> int:
         """Convert an age to the corresponding period index.
 
         Args:
@@ -172,12 +172,14 @@ def age_to_period(self, age: Age) -> int:
             raise ValueError(msg) from None
 
     @functools.cached_property
-    def _age_to_period_map(self) -> dict[Age, int]:
+    def _age_to_period_map(self) -> dict[int | float, int]:
         if self._is_integer:
             return {int(v): i for i, v in enumerate(self._exact_values)}
         return {float(v): i for i, v in enumerate(self._exact_values)}
 
-    def get_periods_where(self, predicate: Callable[[Age], bool]) -> tuple[int, ...]:
+    def get_periods_where(
+        self, predicate: Callable[[int | float], bool]
+    ) -> tuple[int, ...]:
         """Get period indices where predicate is True.
 
         Args:
@@ -187,7 +189,7 @@ def get_periods_where(self, predicate: Callable[[Age], bool]) -> tuple[int, ...]
             Tuple of period indices where predicate(age) is True.
 
         """
-        _convert: Callable[[object], Age] = int if self._is_integer else float  # ty: ignore[invalid-assignment]
+        _convert: Callable[[object], int | float] = int if self._is_integer else float  # ty: ignore[invalid-assignment]
         return tuple(
             period
             for period in range(self.n_periods)
