@categorical: require ScalarInt field annotations + runtime values#350

Open
hmgaudecker wants to merge 3 commits into main from feat/categorical-scalarint

Conversation

@hmgaudecker
Member

Summary

  • Tightens @categorical so the annotation contract, runtime values, and validator all agree on ScalarInt. Fields annotated otherwise raise the new CategoricalDefinitionError at decoration time, and class- and instance-level attribute access (LaborSupply.work, LaborSupply().work) now returns 0-d jnp.int32 scalars instead of Python ints.
  • Adds lcm.invert_regime_ids (in lcm.utils.containers) so the two pylcm dict-inversion sites can coerce ScalarInt keys to Python int without spelling out the comprehension. DiscreteGrid.codes likewise coerces at the boundary where hashable tuples are needed; the JAX-side surface (to_jax()) is unchanged.
  • Sweeps every @categorical-decorated class across src/lcm_examples/, tests/, and the user-guide notebooks/docs to ScalarInt. The Effort example in mahler_yum_2024/_model.py switches to a make_dataclass + type.__setattr__ shim so its 40 ScalarInt defaults bypass dataclass's mutable-default check.

Closes #349.

Test plan

  • pixi run -e tests-cpu tests — 957 passed, 5 skipped.
  • pixi run -e type-checking ty — clean.
  • prek run --all-files — clean.
  • Downstream cascade in aca-model verified locally (216 aca-model tests + 407-test workspace suite green with this pylcm branch as the submodule pointer); separate aca-model PR will follow.

🤖 Generated with Claude Code

… values (#349)

The decorator used to read `__annotations__` only for field names and
never validate them — every consumer wrote `field: int` even though
codes flowed through JAX as `jnp.int32` everywhere downstream. The
mismatch was a consistent lie ty couldn't catch, and downstream
`jnp.int32(...)` wraps papered over the dtype gap at runtime.

This change closes the gap at all three layers in lockstep:

* **Annotation gate.** `@categorical` requires every field to be
  annotated `ScalarInt` (from `lcm.typing`); other annotations raise
  the new `CategoricalDefinitionError` at decoration time, naming the
  offending fields and pointing at the import.
* **Runtime values.** Class- and instance-level attribute access
  returns 0-d `jnp.int32` scalars. The decorator assigns
  `field(default=i, init=False)` (Python int placeholders for
  `dataclass(frozen=True)`'s mutable-default check) then overrides
  the class attributes with `jnp.int32(i)` via `type.__setattr__`
  post-decoration; `init=False` keeps instance `__dict__` empty so
  attribute lookup falls through to the class scalar.
* **Validator.** `validate_category_class` now checks
  `isinstance(value, jax.Array) and value.shape == () and
  jnp.issubdtype(value.dtype, jnp.integer)`; the consecutiveness
  check coerces via `int(v)`.
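
The annotation gate and the placeholder-then-override step can be sketched in plain Python. This is a minimal illustration under stated assumptions, not pylcm's implementation: `categorical_sketch`, the local `ScalarInt` marker, the local `CategoricalDefinitionError`, and `Scalar0D` (a stdlib stand-in for `jnp.int32(i)`, used so the snippet runs without JAX) are all hypothetical names.

```python
from dataclasses import dataclass, field


class ScalarInt:
    """Hypothetical marker standing in for `lcm.typing.ScalarInt`."""


class CategoricalDefinitionError(TypeError):
    """Local stand-in; the real error's base class is not stated in the PR."""


class Scalar0D:
    """Stand-in for `jnp.int32(i)`: converts to int, but is not hashable."""

    __hash__ = None

    def __init__(self, value):
        self._value = value

    def __int__(self):
        return self._value


def categorical_sketch(cls):
    names = list(cls.__annotations__)
    # Annotation gate: reject any field not annotated as ScalarInt.
    bad = [n for n in names if cls.__annotations__[n] is not ScalarInt]
    if bad:
        raise CategoricalDefinitionError(
            f"{cls.__name__}: annotate {bad} as ScalarInt"
        )
    # Python-int placeholders satisfy dataclass's mutable-default check;
    # init=False means __init__ never writes them onto the instance.
    for i, name in enumerate(names):
        setattr(cls, name, field(default=i, init=False))
    frozen_cls = dataclass(frozen=True)(cls)
    # Swap in the scalar values on the class after decoration.
    for i, name in enumerate(names):
        type.__setattr__(frozen_cls, name, Scalar0D(i))
    return frozen_cls


@categorical_sketch
class LaborSupply:
    work: ScalarInt
    retire: ScalarInt


instance = LaborSupply()
```

Because the instance `__dict__` stays empty, `instance.retire` resolves to the same class-level scalar as `LaborSupply.retire`.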

Hashability fix: JAX 0-d arrays aren't hashable, so the two pylcm
dict-inversion sites (`simulation/simulate.py`, `initial_conditions.py`)
and `DiscreteGrid.codes` now coerce to Python `int` at the boundary
where hashable keys are needed. New helper
`lcm.invert_regime_ids` lives in `lcm.utils.containers` and centralises
the inversion pattern; downstream consumers can call it instead of
hand-rolling `{int(v): k for k, v in ...}`. `RegimeNamesToIds`
tightens to `MappingProxyType[RegimeName, ScalarInt]`.
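
The inversion pattern can be sketched as follows. The signature and return type of the real `lcm.invert_regime_ids` are not shown in this PR text, so `invert_regime_ids_sketch` is a hypothetical name; `Scalar0D` is a stdlib stand-in for an unhashable 0-d JAX scalar, and the regime names are made up for illustration.

```python
class Scalar0D:
    """Stand-in for a 0-d `jnp.int32` scalar: int-convertible, unhashable."""

    __hash__ = None

    def __init__(self, value):
        self._value = value

    def __int__(self):
        return self._value


def invert_regime_ids_sketch(regime_names_to_ids):
    """Return {id: name}, coercing each scalar id to a hashable Python int."""
    return {int(v): k for k, v in regime_names_to_ids.items()}


# Hypothetical regime names, for illustration only.
regime_names_to_ids = {"working": Scalar0D(0), "retired": Scalar0D(1)}

# Inverting without coercion fails because the scalars cannot be dict keys.
naive_inversion_failed = False
try:
    {v: k for k, v in regime_names_to_ids.items()}
except TypeError:
    naive_inversion_failed = True

ids_to_names = invert_regime_ids_sketch(regime_names_to_ids)
```

Coercing only at the point where a hashable key is required leaves the values themselves untouched everywhere else.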

Sweep: every `@categorical`-decorated class in `src/lcm_examples`,
`tests/`, `tests/test_models/`, and the user-guide notebooks/docs
moves from `field: int` to `field: ScalarInt`. The `Effort` example
in `mahler_yum_2024/_model.py` switches to the `make_dataclass +
type.__setattr__` shim so its 40 `ScalarInt` defaults bypass
`dataclass`'s mutable-default check. `tests/conftest.py` and
`tests/test_grids.py` gain a `_make_dc` helper for hand-crafted
test dataclasses with `ScalarInt` values.
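
A rough sketch of what a `make_dataclass + type.__setattr__` shim for 40 generated fields might look like. The field names (`level_0` ... `level_39`) are hypothetical, `ScalarInt` is a local marker rather than the `lcm.typing` alias, and `Scalar0D` stands in for `jnp.int32(i)` so the snippet runs without JAX; the actual `Effort` definition in `mahler_yum_2024/_model.py` may differ.

```python
from dataclasses import field, make_dataclass


class ScalarInt:
    """Hypothetical marker standing in for `lcm.typing.ScalarInt`."""


class Scalar0D:
    """Stand-in for `jnp.int32(i)`: int-convertible, unhashable."""

    __hash__ = None

    def __init__(self, value):
        self._value = value

    def __int__(self):
        return self._value


# Hypothetical field names; the real example's names are not given here.
names = [f"level_{i}" for i in range(40)]

# Build the frozen dataclass with Python-int placeholder defaults.
Effort = make_dataclass(
    "Effort",
    [(name, ScalarInt, field(default=i, init=False)) for i, name in enumerate(names)],
    frozen=True,
)

# Override the class attributes with scalar values after class creation,
# so the mutable-default check never sees them.
for i, name in enumerate(names):
    type.__setattr__(Effort, name, Scalar0D(i))
```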

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>


hmgaudecker added a commit that referenced this pull request May 11, 2026
PR #350 tightens `@categorical` to require `ScalarInt` field
annotations. The benchmarks-cuda12 env still pulls aca-model from the
pre-#350 sha (`9ac2043`), which annotates fields as `int` and fails
the new decoration-time gate during `setup_cache`.

Point the pin at the aca-model #350 cascade branch
(`feature/categorical-scalarint`, head `b807b28`) until both PRs land.
After aca-model PR#10 and pylcm#350 merge, revert to a `main`-tracking
sha.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@hmgaudecker force-pushed the feat/categorical-scalarint branch from 4a05a82 to 6e187a7 (May 11, 2026 13:14)
aca-model's `get_benchmark_params` returns three values (likely a
moments tuple alongside the two pre-existing entries); the benchmark
file was stuck on the older 2-tuple shape. aca-model's own
`test_benchmark.py` runs against the up-to-date in-tree helper, so the
drift only surfaces here once pylcm's benchmark CI pulls a fresh
aca-model.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@github-actions

Benchmark comparison (main → HEAD)

Comparing 99a5e31d (main) → 3f1d1e41 (HEAD)

| Benchmark | Statistic | before | after | Ratio | Alert |
|---|---|---|---|---|---|
| aca-baseline | execution time | 27.474 s | 27.841 s | 1.01 | |
| | peak GPU mem | 509 MB | 579 MB | 1.14 | |
| | compilation time | 299.51 s | 294.35 s | 0.98 | |
| | peak CPU mem | 7.65 GB | 7.45 GB | 0.97 | |
| Mahler-Yum | execution time | 4.712 s | 4.825 s | 1.02 | |
| | peak GPU mem | 529 MB | 529 MB | 1.00 | |
| | compilation time | 14.59 s | 14.78 s | 1.01 | |
| | peak CPU mem | 1.68 GB | 1.67 GB | 0.99 | |
| Precautionary Savings - Solve | execution time | 50.8 ms | 58.5 ms | 1.15 | |
| | peak GPU mem | 101 MB | 101 MB | 1.00 | |
| | compilation time | 2.71 s | 2.71 s | 1.00 | |
| | peak CPU mem | 1.13 GB | 1.13 GB | 1.00 | |
| Precautionary Savings - Simulate | execution time | 126.7 ms | 121.6 ms | 0.96 | |
| | peak GPU mem | 344 MB | 344 MB | 1.00 | |
| | compilation time | 4.90 s | 4.83 s | 0.99 | |
| | peak CPU mem | 1.31 GB | 1.31 GB | 1.00 | |
| Precautionary Savings - Solve & Simulate | execution time | 145.2 ms | 162.4 ms | 1.12 | |
| | peak GPU mem | 578 MB | 578 MB | 1.00 | |
| | compilation time | 7.02 s | 7.31 s | 1.04 | |
| | peak CPU mem | 1.28 GB | 1.28 GB | 1.01 | |
| Precautionary Savings - Solve & Simulate (irreg) | execution time | 283.3 ms | 284.1 ms | 1.00 | |
| | peak GPU mem | 2.19 GB | 2.19 GB | 1.00 | |
| | compilation time | 7.58 s | 7.56 s | 1.00 | |
| | peak CPU mem | 1.34 GB | 1.34 GB | 1.00 | |


Development

Successfully merging this pull request may close these issues.

@categorical should require ScalarInt field annotations
