Skip to content

Commit 44e8522

Browse files
hmgaudeckerclaude
andcommitted
Redesign the solver seam as an ABC (GridSearch + DCEGM stub).
Replace the `type(solver)`-keyed builder registry with a polymorphic `Solver` ABC. The engine now calls `solver.validate(context)` then `solver.build_period_kernels(context)` — no `SOLVER_KERNEL_BUILDERS` dict, no `BruteForce | DCEGM` union, no standalone DC-EGM guard. - `_lcm/solution/contract.py` (new): the `Solver` ABC (abstract `build_period_kernels`, default no-op `validate`), `SolverBuildContext`, and `SolverKernels`. An engine leaf — imports nothing that reaches `lcm.solvers`, so the façade can re-export it without an import cycle. - `_lcm/solution/solvers.py` (new): `GridSearch(Solver)` (the relocated grid-search builder, with function-local `jax`/`get_max_Q_over_a` imports) and `DCEGM(Solver)` (the published config; `validate` raises the not-yet-available guard, so a regime requesting it is rejected at model build). - `lcm/solvers.py` → thin re-export façade; `registry.py` deleted; the `processing` dispatch and `Regime.solver` field updated. - Rename the default solver `BruteForce` → `GridSearch` (more descriptive; alpha permits the break). Faithful to dcegm-solver-seam-abc-design.md (ABC, not Protocol). Layer 2 (generic KernelResult) does not apply here — the stub seam has no EGM fork. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent 91cc590 commit 44e8522

8 files changed

Lines changed: 379 additions & 364 deletions

File tree

src/_lcm/regime_building/processing.py

Lines changed: 17 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
collect_stochastic_state_transitions,
4646
)
4747
from _lcm.regime_building.V import VInterpolationInfo, create_v_interpolation_info
48-
from _lcm.solution.registry import SOLVER_KERNEL_BUILDERS, SolverBuildContext
48+
from _lcm.solution.contract import SolverBuildContext
4949
from _lcm.state_action_space import create_state_action_space
5050
from _lcm.typing import (
5151
ArgmaxQOverAFunction,
@@ -79,7 +79,7 @@
7979
from lcm.ages import AgeGrid
8080
from lcm.exceptions import ModelInitializationError
8181
from lcm.regime import Regime as UserRegime
82-
from lcm.solvers import DCEGM, BruteForce
82+
from lcm.solvers import Solver
8383
from lcm.transition import (
8484
MarkovTransition,
8585
)
@@ -114,8 +114,6 @@ def process_regimes(
114114
The processed canonical regimes.
115115
116116
"""
117-
_fail_if_dcegm_solver_requested(user_regimes)
118-
119117
# The canonical specs hold every law in target-granular form, resolved per
120118
# phase: the simulate slice additionally holds every carried-only state
121119
# and its law of motion, so the canonical mapping carries the law toward
@@ -240,29 +238,6 @@ def process_regimes(
240238
return ensure_containers_are_immutable(canonical_regimes)
241239

242240

243-
def _fail_if_dcegm_solver_requested(
244-
user_regimes: Mapping[RegimeName, FinalizedUserRegime],
245-
) -> None:
246-
"""Reject the not-yet-available DC-EGM solver at model build.
247-
248-
The `DCEGM` configuration is published so a model can name the solver and
249-
its parameters, but the solver engine is not yet wired in; a regime that
250-
requests it cannot be solved. `BruteForce` is the only available solver.
251-
"""
252-
dcegm_regimes = sorted(
253-
name
254-
for name, user_regime in user_regimes.items()
255-
if isinstance(user_regime.solver, DCEGM)
256-
)
257-
if dcegm_regimes:
258-
msg = (
259-
"The DC-EGM solver is not yet available. Regime(s) "
260-
f"{dcegm_regimes} request `solver=DCEGM(...)`; use `BruteForce()` "
261-
"(the default) until DC-EGM is wired in."
262-
)
263-
raise NotImplementedError(msg)
264-
265-
266241
def _build_solution_phase(
267242
*,
268243
spec: PhasedRegimeSpec,
@@ -279,7 +254,7 @@ def _build_solution_phase(
279254
ages: AgeGrid,
280255
enable_jit: bool,
281256
has_taste_shocks: bool,
282-
solver: BruteForce | DCEGM,
257+
solver: Solver,
283258
) -> SolutionPhase:
284259
"""Build all compiled functions for the backward-induction (solve) phase.
285260
@@ -301,8 +276,8 @@ def _build_solution_phase(
301276
enable_jit: Whether to jit the internal functions.
302277
has_taste_shocks: Whether the regime declares EV1 taste shocks on its
303278
discrete actions.
304-
solver: The regime's solver configuration; selects the per-period
305-
kernel builder dispatched through `SOLVER_KERNEL_BUILDERS`.
279+
solver: The regime's solver; the engine calls `validate` then
280+
`build_period_kernels` on it to obtain the per-period kernels.
306281
307282
Returns:
308283
Complete solve functions container.
@@ -372,20 +347,19 @@ def _build_solution_phase(
372347
enable_jit=enable_jit,
373348
)
374349

375-
# Dispatch the per-period kernel build on the regime's solver
376-
# configuration. `BruteForce` builds the max-Q-over-a grid-search kernels;
377-
# other solvers register their own builders in `SOLVER_KERNEL_BUILDERS`.
378-
solver_kernel_builder = SOLVER_KERNEL_BUILDERS[type(solver)]
379-
solver_kernels = solver_kernel_builder(
380-
solver=solver,
381-
context=SolverBuildContext(
382-
state_action_space=state_action_space,
383-
Q_and_F_functions=Q_and_F_functions,
384-
grids=all_grids[regime_name],
385-
enable_jit=enable_jit,
386-
has_taste_shocks=has_taste_shocks,
387-
),
350+
# Dispatch the per-period kernel build polymorphically on the regime's
351+
# solver: `validate` rejects out-of-scope configurations at build time,
352+
# then `build_period_kernels` returns the per-period kernels. `GridSearch`
353+
# builds the max-Q-over-a grid-search kernels.
354+
context = SolverBuildContext(
355+
state_action_space=state_action_space,
356+
Q_and_F_functions=Q_and_F_functions,
357+
grids=all_grids[regime_name],
358+
enable_jit=enable_jit,
359+
has_taste_shocks=has_taste_shocks,
388360
)
361+
solver.validate(context=context)
362+
solver_kernels = solver.build_period_kernels(context=context)
389363
max_Q_over_a = solver_kernels.max_Q_over_a
390364

391365
return SolutionPhase(

src/_lcm/solution/contract.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""The solver contract: what every regime solver provides to the engine.
2+
3+
A regime's `solver` field selects its backward-induction algorithm. The engine
4+
dispatches polymorphically on the solver instance — `solver.validate(context)`
5+
then `solver.build_period_kernels(context)` — with no switch on solver type.
6+
Add a solver by subclassing `Solver` and implementing `build_period_kernels`;
7+
override `validate` for a build-time model-contract check (the default is a
8+
no-op). `SolverBuildContext` carries everything a solver may read to build one
9+
regime's kernels; `SolverKernels` is what it hands back.
10+
11+
This module is an engine leaf: it imports only `_lcm.engine` / `_lcm.grids` /
12+
`_lcm.typing` (none of which reach `lcm.solvers`), so the public solver façade
13+
can re-export it without forming an import cycle.
14+
"""
15+
16+
from abc import ABC, abstractmethod
17+
from dataclasses import dataclass
18+
from types import MappingProxyType
19+
20+
from _lcm.engine import StateActionSpace
21+
from _lcm.grids import Grid
22+
from _lcm.typing import MaxQOverAFunction, QAndFFunction, StateOrActionName
23+
24+
25+
@dataclass(frozen=True, kw_only=True)
26+
class SolverBuildContext:
27+
"""Everything a solver may read to build one regime's kernels.
28+
29+
Bundled so the solver method signature stays stable as solvers with
30+
different needs are added; each solver reads only the fields it uses.
31+
"""
32+
33+
state_action_space: StateActionSpace
34+
"""The regime's state-action space."""
35+
36+
Q_and_F_functions: MappingProxyType[int, QAndFFunction]
37+
"""Immutable mapping of period to Q-and-F closures."""
38+
39+
grids: MappingProxyType[StateOrActionName, Grid]
40+
"""Immutable mapping of the regime's variable names to grid objects."""
41+
42+
enable_jit: bool
43+
"""Whether to JIT-compile the kernels."""
44+
45+
has_taste_shocks: bool
46+
"""Whether the regime declares EV1 taste shocks on its discrete actions."""
47+
48+
49+
@dataclass(frozen=True, kw_only=True)
50+
class SolverKernels:
51+
"""Per-period solve kernels produced by a solver."""
52+
53+
max_Q_over_a: MappingProxyType[int, MaxQOverAFunction]
54+
"""Immutable mapping of period to max-Q-over-actions kernels.
55+
56+
Empty for solvers that replace the grid search with their own kernels.
57+
"""
58+
59+
60+
class Solver(ABC):
61+
"""Base class for regime solvers — the polymorphic dispatch target.
62+
63+
The engine calls `validate` then `build_period_kernels` on the instance,
64+
matching the engine's own polymorphism (`Grid(ABC)`, the stochastic
65+
processes). Subclasses are frozen dataclasses carrying the solver's
66+
configuration.
67+
"""
68+
69+
@abstractmethod
70+
def build_period_kernels(self, *, context: SolverBuildContext) -> SolverKernels:
71+
"""Build the regime's per-period solve kernels."""
72+
73+
def validate(self, *, context: SolverBuildContext) -> None: # noqa: B027
74+
"""Check the regime is in scope for this solver. Default: no-op."""

src/_lcm/solution/registry.py

Lines changed: 0 additions & 111 deletions
This file was deleted.

0 commit comments

Comments
 (0)