44import time
55from collections .abc import Callable , Hashable
66from concurrent .futures import ThreadPoolExecutor , as_completed
7+ from dataclasses import dataclass
78from types import MappingProxyType
89
910import jax
1011import jax .numpy as jnp
1112
1213from lcm .ages import AgeGrid
1314from lcm .interfaces import InternalRegime
14- from lcm .typing import FloatND , InternalParams , RegimeName
15+ from lcm .typing import FlatRegimeParams , FloatND , InternalParams , RegimeName
1516from lcm .utils .error_handling import validate_V
1617from lcm .utils .logging import (
1718 format_duration ,
18- log_nan_in_V ,
1919 log_period_header ,
2020 log_period_timing ,
21- log_V_stats ,
2221)
2322
2423
@@ -71,6 +70,28 @@ def solve(
7170
7271 solution : dict [int , MappingProxyType [RegimeName , FloatND ]] = {}
7372
73+ # Async diagnostics accumulators: every `jnp.any(isnan)`,
74+ # `jnp.any(isinf)` (and the debug min/max/mean trio) lives here as
75+ # a device-side scalar during the hot loop. No host sync happens
76+ # until the single flush in `_emit_deferred_diagnostics` post-loop.
77+ # This replaces the pre-existing synchronous `log_nan_in_V` +
78+ # `log_V_stats` + `validate_V` triple, which forced one host
79+ # transfer per (regime, period) — ~n_regimes * n_periods stalls
80+ # per solve, a meaningful throughput tax in MSM-style loops.
81+ # Both gates fall out of the public log level: `"off"` ⇒ nothing,
82+ # `"warning"` / `"progress"` ⇒ NaN/Inf only, `"debug"` ⇒ adds the
83+ # min/max/mean trio. `"off"` skips even the NaN fail-fast — that
84+ # is the documented contract of `"off"` (suppress all output) and
85+ # is what makes the level useful for tight estimation loops.
86+ diagnostics_enabled = logger .isEnabledFor (logging .WARNING )
87+ stats_enabled = logger .isEnabledFor (logging .DEBUG )
88+ diagnostic_rows : list [_DiagnosticRow ] = []
89+ diagnostic_min : list [FloatND ] = []
90+ diagnostic_max : list [FloatND ] = []
91+ diagnostic_mean : list [FloatND ] = []
92+ diagnostic_any_nan : list [FloatND ] = []
93+ diagnostic_any_inf : list [FloatND ] = []
94+
7495 logger .info ("Starting solution" )
7596 total_start = time .monotonic ()
7697
@@ -110,36 +131,36 @@ def solve(
110131 age = ages .values [period ],
111132 )
112133
113- log_nan_in_V (
114- logger = logger ,
115- regime_name = name ,
116- age = ages . values [ period ],
117- V_arr = V_arr ,
118- )
119- log_V_stats ( logger = logger , regime_name = name , V_arr = V_arr )
120-
121- # Include sibling regimes already solved this period (and the
122- # current regime's V_arr, even though it is NaN-bearing — users
123- # debugging the snapshot want to see all of it).
124- partial = MappingProxyType (
125- {
126- ** solution ,
127- period : MappingProxyType ({ ** period_solution , name : V_arr }),
128- }
129- )
130- validate_V (
131- V_arr = V_arr ,
132- age = float ( ages . values [ period ]) ,
133- regime_name = name ,
134- partial_solution = partial ,
135- compute_intermediates = internal_regime . solve_functions . compute_intermediates . get (
136- period
137- ),
138- state_action_space = state_action_space ,
139- next_regime_to_V_arr = next_regime_to_V_arr ,
140- internal_params = internal_params [ name ],
141- period = period ,
142- )
134+ # Async reductions: gated on log level. `"off"` skips
135+ # everything — no kernel launches, no host syncs, no
136+ # NaN fail-fast. `"warning"` / `"progress"` launches the
137+ # two cheap isnan/isinf reductions; `"debug"` adds the
138+ # min/max/mean trio. Each extra full-V read is a
139+ # memory-bandwidth tax on the larger models, so the
140+ # default keeps it to two reductions per (regime, period).
141+ if diagnostics_enabled :
142+ if stats_enabled :
143+ diagnostic_min . append ( jnp . min ( V_arr ))
144+ diagnostic_max . append ( jnp . max ( V_arr ))
145+ diagnostic_mean . append ( jnp . mean ( V_arr ))
146+ diagnostic_any_nan . append ( jnp . any ( jnp . isnan ( V_arr )))
147+ diagnostic_any_inf . append ( jnp . any ( jnp . isinf ( V_arr )))
148+ diagnostic_rows . append (
149+ _DiagnosticRow (
150+ regime_name = name ,
151+ period = period ,
152+ age = float ( ages . values [ period ]) ,
153+ state_action_space = state_action_space ,
154+ next_regime_to_V_arr = next_regime_to_V_arr ,
155+ regime_params = internal_params [ name ] ,
156+ compute_intermediates = (
157+ internal_regime . solve_functions . compute_intermediates . get (
158+ period
159+ )
160+ ) ,
161+ )
162+ )
163+
143164 period_solution [name ] = V_arr
144165
145166 # Maintain consistent pytree structure: keep all regime keys,
@@ -155,6 +176,24 @@ def solve(
155176 elapsed = time .monotonic () - period_start
156177 log_period_timing (logger = logger , elapsed = elapsed )
157178
179+ # One flush of the GPU kernel queue: ship the stacked reductions
180+ # to host in two transfers (isnan / isinf) by default, plus three
181+ # more (min / max / mean) when debug stats were enabled. Skipped
182+ # entirely at `log_level="off"` — nothing was accumulated.
183+ if diagnostics_enabled :
184+ _emit_deferred_diagnostics (
185+ logger = logger ,
186+ diagnostic_rows = diagnostic_rows ,
187+ reductions = _StackedReductions (
188+ mins = jnp .stack (diagnostic_min ) if diagnostic_min else None ,
189+ maxs = jnp .stack (diagnostic_max ) if diagnostic_max else None ,
190+ means = jnp .stack (diagnostic_mean ) if diagnostic_mean else None ,
191+ any_nan = jnp .stack (diagnostic_any_nan ),
192+ any_inf = jnp .stack (diagnostic_any_inf ),
193+ ),
194+ solution = MappingProxyType (solution ),
195+ )
196+
158197 total_elapsed = time .monotonic () - total_start
159198 logger .info ("Solution complete (%s)" , format_duration (seconds = total_elapsed ))
160199
@@ -338,3 +377,158 @@ def _get_regime_V_shapes(
338377 )
339378 shapes [name ] = tuple (len (v ) for v in state_action_space .states .values ())
340379 return shapes
380+
381+
382+ @dataclass (frozen = True )
383+ class _DiagnosticRow :
384+ """Metadata captured during the backward-induction loop.
385+
386+ Stored refs only — no device work — so appending these rows inside
387+ the hot loop costs essentially nothing. The expensive part (NaN
388+ diagnostic enrichment via `compute_intermediates`) runs at most
389+ once per solve, on the first offending row found after the single
390+ post-loop host flush.
391+ """
392+
393+ regime_name : RegimeName
394+ """Name of the regime whose V-array this row summarises."""
395+ period : int
396+ """Period index in the backward-induction loop."""
397+ age : float
398+ """Age corresponding to `period` (pulled off `AgeGrid.values`)."""
399+ state_action_space : object
400+ """Typed as `object` to avoid a heavy import cycle; consumers know
401+ the actual runtime type from the `max_Q_over_a` signature."""
402+ next_regime_to_V_arr : MappingProxyType [RegimeName , FloatND ]
403+ """Incoming next-period V-arrays, passed through unchanged to
404+ `compute_intermediates` when a NaN is detected."""
405+ regime_params : FlatRegimeParams
406+ """Flat regime parameters used at this (regime, period)."""
407+ compute_intermediates : Callable | None
408+ """Optional closure that recomputes U / F / E[V] / Q for NaN
409+ diagnostic enrichment. `None` when the regime has no
410+ compute-intermediates closure (e.g. terminal periods)."""
411+
412+
413+ @dataclass (frozen = True )
414+ class _StackedReductions :
415+ """Per-stat JAX arrays stacked across all diagnostic rows; still on device.
416+
417+ `mins` / `maxs` / `means` are `None` when the solve ran with a log
418+ level below `debug` — the GPU wasn't asked to compute those
419+ statistics so there's nothing to stack.
420+ """
421+
422+ mins : FloatND | None
423+ """Per-row min of V, or `None` below debug log level."""
424+ maxs : FloatND | None
425+ """Per-row max of V, or `None` below debug log level."""
426+ means : FloatND | None
427+ """Per-row mean of V, or `None` below debug log level."""
428+ any_nan : FloatND
429+ """Per-row boolean flag: any NaN in V at this (regime, period)."""
430+ any_inf : FloatND
431+ """Per-row boolean flag: any Inf in V at this (regime, period)."""
432+
433+
def _emit_deferred_diagnostics(
    *,
    logger: logging.Logger,
    diagnostic_rows: list[_DiagnosticRow],
    reductions: _StackedReductions,
    solution: MappingProxyType[int, MappingProxyType[RegimeName, FloatND]],
) -> None:
    """Pull the stacked reductions to the host and act on them.

    The work happens in a fixed order:

    1. NaN flags: the first flagged row is handed to the enriched NaN
       diagnostic, which is expected to raise, so no further output is
       produced for a NaN-bearing solve.
    2. Inf flags: one warning per flagged row.
    3. Min / max / mean: one debug line per row, only when the logger
       is enabled for `DEBUG` and all three stat stacks are present.

    The `.tolist()` calls are the points at which values are
    transferred to Python: two by default (NaN / Inf flags), plus three
    more when the debug statistics are emitted.

    Args:
        logger: Logger that receives the Inf warnings and debug stats.
        diagnostic_rows: Metadata rows, aligned by index with the
            entries of every array in `reductions`.
        reductions: Stacked per-row reduction results.
        solution: Value-function arrays by period and regime name.

    Raises:
        ValueError: If a stat stack and `diagnostic_rows` differ in
            length (from `zip(..., strict=True)`).
    """
    nan_flags = reductions.any_nan.tolist()
    inf_flags = reductions.any_inf.tolist()

    # NaN first: an offending row aborts before anything else is logged.
    _raise_if_nan(
        diagnostic_rows=diagnostic_rows,
        any_nan_per_row=nan_flags,
        solution=solution,
    )
    _warn_if_inf(
        logger=logger,
        diagnostic_rows=diagnostic_rows,
        any_inf_per_row=inf_flags,
    )

    if not logger.isEnabledFor(logging.DEBUG):
        return

    stacked_min = reductions.mins
    stacked_max = reductions.maxs
    stacked_mean = reductions.means
    if stacked_min is None or stacked_max is None or stacked_mean is None:
        # The debug statistics were never accumulated for this solve.
        return

    per_row_stats = zip(
        diagnostic_rows,
        stacked_min.tolist(),
        stacked_max.tolist(),
        stacked_mean.tolist(),
        strict=True,
    )
    for row, lowest, highest, average in per_row_stats:
        logger.debug(
            " %s age %s V min=%.3g max=%.3g mean=%.3g",
            row.regime_name,
            row.age,
            lowest,
            highest,
            average,
        )
487+
488+
def _raise_if_nan(
    *,
    diagnostic_rows: list[_DiagnosticRow],
    any_nan_per_row: list,  # list[bool]
    solution: MappingProxyType[int, MappingProxyType[RegimeName, FloatND]],
) -> None:
    """Hand every NaN-flagged row, in recording order, to `_raise_at`.

    `_raise_at` is expected to raise for the first flagged row, which
    ends the scan there. Rows and flags are paired by position.

    Args:
        diagnostic_rows: Metadata rows in the order they were recorded.
        any_nan_per_row: One truthy/falsy NaN flag per row.
        solution: Value-function arrays by period and regime name.

    Raises:
        ValueError: If `diagnostic_rows` and `any_nan_per_row` differ in
            length (from `zip(..., strict=True)`).
    """
    for candidate, has_nan in zip(diagnostic_rows, any_nan_per_row, strict=True):
        if not has_nan:
            continue
        _raise_at(row=candidate, solution=solution)
499+
500+
def _raise_at(
    *,
    row: _DiagnosticRow,
    solution: MappingProxyType[int, MappingProxyType[RegimeName, FloatND]],
) -> None:
    """Run `validate_V` for the (regime, period) described by `row`.

    The V-array is looked up in `solution` by the row's period and
    regime name; every other argument is taken from the row itself. The
    whole `solution` mapping is passed as `partial_solution`.

    Args:
        row: Metadata of the offending (regime, period).
        solution: Value-function arrays by period and regime name.
    """
    offending_period = row.period
    offending_regime = row.regime_name
    validate_V(
        V_arr=solution[offending_period][offending_regime],
        age=row.age,
        regime_name=offending_regime,
        partial_solution=solution,
        compute_intermediates=row.compute_intermediates,
        state_action_space=row.state_action_space,  # ty: ignore[invalid-argument-type]
        next_regime_to_V_arr=row.next_regime_to_V_arr,
        internal_params=row.regime_params,
        period=offending_period,
    )
519+
520+
521+ def _warn_if_inf (
522+ * ,
523+ logger : logging .Logger ,
524+ diagnostic_rows : list [_DiagnosticRow ],
525+ any_inf_per_row : list , # list[bool]
526+ ) -> None :
527+ """Emit a warning per (regime, period) with Inf values."""
528+ for row , flag in zip (diagnostic_rows , any_inf_per_row , strict = True ):
529+ if flag :
530+ logger .warning (
531+ "Inf in V_arr for regime '%s' at age %s" ,
532+ row .regime_name ,
533+ row .age ,
534+ )
0 commit comments