@categorical: require ScalarInt field annotations + runtime values#350
Open
hmgaudecker wants to merge 3 commits intomainfrom
Open
@categorical: require ScalarInt field annotations + runtime values#350hmgaudecker wants to merge 3 commits intomainfrom
hmgaudecker wants to merge 3 commits intomainfrom
Conversation
… values (#349) The decorator used to read `__annotations__` only for field names and never validate them — every consumer wrote `field: int` even though codes flowed through JAX as `jnp.int32` everywhere downstream. The mismatch was a consistent lie ty couldn't catch, and downstream `jnp.int32(...)` wraps papered over the dtype gap at runtime. This change closes the gap at all three layers in lockstep: * **Annotation gate.** `@categorical` requires every field to be annotated `ScalarInt` (from `lcm.typing`); other annotations raise the new `CategoricalDefinitionError` at decoration time, naming the offending fields and pointing at the import. * **Runtime values.** Class- and instance-level attribute access return 0-d `jnp.int32` scalars. The decorator assigns `field(default=i, init=False)` (Python int placeholders for `dataclass(frozen=True)`'s mutable-default check) then overrides the class attributes with `jnp.int32(i)` via `type.__setattr__` post-decoration; `init=False` keeps instance `__dict__` empty so attribute lookup falls through to the class scalar. * **Validator.** `validate_category_class` now checks `isinstance(value, jax.Array) and value.shape == () and jnp.issubdtype(value.dtype, jnp.integer)`; the consecutiveness check coerces via `int(v)`. Hashability fix: JAX 0-d arrays aren't hashable, so the two pylcm dict-inversion sites (`simulation/simulate.py`, `initial_conditions.py`) and `DiscreteGrid.codes` now coerce to Python `int` at the boundary where hashable keys are needed. New helper `lcm.invert_regime_ids` lives in `lcm.utils.containers` and centralises the inversion pattern; downstream consumers can call it instead of hand-rolling `{int(v): k for k, v in ...}`. `RegimeNamesToIds` tightens to `MappingProxyType[RegimeName, ScalarInt]`. Sweep: every `@categorical`-decorated class in `src/lcm_examples`, `tests/`, `tests/test_models/`, and the user-guide notebooks/docs moves from `field: int` to `field: ScalarInt`. The `Effort` example in `mahler_yum_2024/_model.py` switches to the `make_dataclass + type.__setattr__` shim so its 40 `ScalarInt` defaults bypass `dataclass`'s mutable-default check. `tests/conftest.py` and `tests/test_grids.py` gain a `_make_dc` helper for hand-crafted test dataclasses with `ScalarInt` values. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
3 tasks
hmgaudecker
added a commit
that referenced
this pull request
May 11, 2026
PR #350 tightens `@categorical` to require `ScalarInt` field annotations. The benchmarks-cuda12 env still pulls aca-model from the pre-#350 sha (`9ac2043`), which annotates fields as `int` and fails the new decoration-time gate during `setup_cache`. Point the pin at the aca-model #350 cascade branch (`feature/categorical-scalarint`, head `b807b28`) until both PRs land. After aca-model PR#10 and pylcm#350 merge, revert to a `main`-tracking sha. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
PR #350 tightens `@categorical` to require `ScalarInt` field annotations. The benchmarks-cuda12 env still pulls aca-model from the pre-#350 sha (`9ac2043`), which annotates fields as `int` and fails the new decoration-time gate during `setup_cache`. Point the pin at the aca-model #350 cascade branch (`feature/categorical-scalarint`, head `b807b28`) until both PRs land. After aca-model PR#10 and pylcm#350 merge, revert to a `main`-tracking sha. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
4a05a82 to
6e187a7
Compare
aca-model's `get_benchmark_params` returns three values (likely a moments tuple alongside the two pre-existing entries); the benchmark file was stuck on the older 2-tuple shape. aca-model's own `test_benchmark.py` runs against the up-to-date in-tree helper, so the drift only surfaces here once pylcm's benchmark CI pulls a fresh aca-model. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
4 tasks
Benchmark comparison (main → HEAD)Comparing
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
@categoricalso the annotation contract, runtime values, and validator all agree onScalarInt. Fields annotated otherwise raise the newCategoricalDefinitionErrorat decoration time, and class- / instance-level attribute access (LaborSupply.work,LaborSupply().work) now return 0-djnp.int32scalars instead of Python ints.lcm.invert_regime_ids(inlcm.utils.containers) so the two pylcm dict-inversion sites can coerceScalarIntkeys to Pythonintwithout spelling out the comprehension.DiscreteGrid.codeslikewise coerces at the boundary where hashable tuples are needed; the JAX-side surface (to_jax()) is unchanged.@categorical-decorated class acrosssrc/lcm_examples/,tests/, and the user-guide notebooks/docs toScalarInt. TheEffortexample inmahler_yum_2024/_model.pyswitches to amake_dataclass + type.__setattr__shim so its 40ScalarIntdefaults bypassdataclass's mutable-default check.Closes #349.
Test plan
pixi run -e tests-cpu tests— 957 passed, 5 skipped.pixi run -e type-checking ty— clean.prek run --all-files— clean.🤖 Generated with Claude Code