Skip to content

Commit 88ba7c2

Browse files
hmgaudecker and claude authored
solve_brute: defer NaN/stats diagnostics to post-solve (async) (#334)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 5a290a2 commit 88ba7c2

2 files changed

Lines changed: 227 additions & 59 deletions

File tree

src/lcm/solution/solve_brute.py

Lines changed: 227 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -4,21 +4,20 @@
44
import time
55
from collections.abc import Callable, Hashable
66
from concurrent.futures import ThreadPoolExecutor, as_completed
7+
from dataclasses import dataclass
78
from types import MappingProxyType
89

910
import jax
1011
import jax.numpy as jnp
1112

1213
from lcm.ages import AgeGrid
1314
from lcm.interfaces import InternalRegime
14-
from lcm.typing import FloatND, InternalParams, RegimeName
15+
from lcm.typing import FlatRegimeParams, FloatND, InternalParams, RegimeName
1516
from lcm.utils.error_handling import validate_V
1617
from lcm.utils.logging import (
1718
format_duration,
18-
log_nan_in_V,
1919
log_period_header,
2020
log_period_timing,
21-
log_V_stats,
2221
)
2322

2423

@@ -71,6 +70,28 @@ def solve(
7170

7271
solution: dict[int, MappingProxyType[RegimeName, FloatND]] = {}
7372

73+
# Async diagnostics accumulators: every `jnp.any(isnan)`,
74+
# `jnp.any(isinf)` (and the debug min/max/mean trio) lives here as
75+
# a device-side scalar during the hot loop. No host sync happens
76+
# until the single flush in `_emit_deferred_diagnostics` post-loop.
77+
# This replaces the pre-existing synchronous `log_nan_in_V` +
78+
# `log_V_stats` + `validate_V` triple, which forced one host
79+
# transfer per (regime, period) — ~n_regimes * n_periods stalls
80+
# per solve, a meaningful throughput tax in MSM-style loops.
81+
# Both gates fall out of the public log level: `"off"` ⇒ nothing,
82+
# `"warning"` / `"progress"` ⇒ NaN/Inf only, `"debug"` ⇒ adds the
83+
# min/max/mean trio. `"off"` skips even the NaN fail-fast — that
84+
# is the documented contract of `"off"` (suppress all output) and
85+
# is what makes the level useful for tight estimation loops.
86+
diagnostics_enabled = logger.isEnabledFor(logging.WARNING)
87+
stats_enabled = logger.isEnabledFor(logging.DEBUG)
88+
diagnostic_rows: list[_DiagnosticRow] = []
89+
diagnostic_min: list[FloatND] = []
90+
diagnostic_max: list[FloatND] = []
91+
diagnostic_mean: list[FloatND] = []
92+
diagnostic_any_nan: list[FloatND] = []
93+
diagnostic_any_inf: list[FloatND] = []
94+
7495
logger.info("Starting solution")
7596
total_start = time.monotonic()
7697

@@ -110,36 +131,36 @@ def solve(
110131
age=ages.values[period],
111132
)
112133

113-
log_nan_in_V(
114-
logger=logger,
115-
regime_name=name,
116-
age=ages.values[period],
117-
V_arr=V_arr,
118-
)
119-
log_V_stats(logger=logger, regime_name=name, V_arr=V_arr)
120-
121-
# Include sibling regimes already solved this period (and the
122-
# current regime's V_arr, even though it is NaN-bearing — users
123-
# debugging the snapshot want to see all of it).
124-
partial = MappingProxyType(
125-
{
126-
**solution,
127-
period: MappingProxyType({**period_solution, name: V_arr}),
128-
}
129-
)
130-
validate_V(
131-
V_arr=V_arr,
132-
age=float(ages.values[period]),
133-
regime_name=name,
134-
partial_solution=partial,
135-
compute_intermediates=internal_regime.solve_functions.compute_intermediates.get(
136-
period
137-
),
138-
state_action_space=state_action_space,
139-
next_regime_to_V_arr=next_regime_to_V_arr,
140-
internal_params=internal_params[name],
141-
period=period,
142-
)
134+
# Async reductions: gated on log level. `"off"` skips
135+
# everything — no kernel launches, no host syncs, no
136+
# NaN fail-fast. `"warning"` / `"progress"` launches the
137+
# two cheap isnan/isinf reductions; `"debug"` adds the
138+
# min/max/mean trio. Each extra full-V read is a
139+
# memory-bandwidth tax on the larger models, so the
140+
# default keeps it to two reductions per (regime, period).
141+
if diagnostics_enabled:
142+
if stats_enabled:
143+
diagnostic_min.append(jnp.min(V_arr))
144+
diagnostic_max.append(jnp.max(V_arr))
145+
diagnostic_mean.append(jnp.mean(V_arr))
146+
diagnostic_any_nan.append(jnp.any(jnp.isnan(V_arr)))
147+
diagnostic_any_inf.append(jnp.any(jnp.isinf(V_arr)))
148+
diagnostic_rows.append(
149+
_DiagnosticRow(
150+
regime_name=name,
151+
period=period,
152+
age=float(ages.values[period]),
153+
state_action_space=state_action_space,
154+
next_regime_to_V_arr=next_regime_to_V_arr,
155+
regime_params=internal_params[name],
156+
compute_intermediates=(
157+
internal_regime.solve_functions.compute_intermediates.get(
158+
period
159+
)
160+
),
161+
)
162+
)
163+
143164
period_solution[name] = V_arr
144165

145166
# Maintain consistent pytree structure: keep all regime keys,
@@ -155,6 +176,24 @@ def solve(
155176
elapsed = time.monotonic() - period_start
156177
log_period_timing(logger=logger, elapsed=elapsed)
157178

179+
# One flush of the GPU kernel queue: ship the stacked reductions
180+
# to host in two transfers (isnan / isinf) by default, plus three
181+
# more (min / max / mean) when debug stats were enabled. Skipped
182+
# entirely at `log_level="off"` — nothing was accumulated.
183+
if diagnostics_enabled:
184+
_emit_deferred_diagnostics(
185+
logger=logger,
186+
diagnostic_rows=diagnostic_rows,
187+
reductions=_StackedReductions(
188+
mins=jnp.stack(diagnostic_min) if diagnostic_min else None,
189+
maxs=jnp.stack(diagnostic_max) if diagnostic_max else None,
190+
means=jnp.stack(diagnostic_mean) if diagnostic_mean else None,
191+
any_nan=jnp.stack(diagnostic_any_nan),
192+
any_inf=jnp.stack(diagnostic_any_inf),
193+
),
194+
solution=MappingProxyType(solution),
195+
)
196+
158197
total_elapsed = time.monotonic() - total_start
159198
logger.info("Solution complete (%s)", format_duration(seconds=total_elapsed))
160199

@@ -338,3 +377,158 @@ def _get_regime_V_shapes(
338377
)
339378
shapes[name] = tuple(len(v) for v in state_action_space.states.values())
340379
return shapes
380+
381+
382+
@dataclass(frozen=True)
383+
class _DiagnosticRow:
384+
"""Metadata captured during the backward-induction loop.
385+
386+
Stored refs only — no device work — so appending these rows inside
387+
the hot loop costs essentially nothing. The expensive part (NaN
388+
diagnostic enrichment via `compute_intermediates`) runs at most
389+
once per solve, on the first offending row found after the single
390+
post-loop host flush.
391+
"""
392+
393+
regime_name: RegimeName
394+
"""Name of the regime whose V-array this row summarises."""
395+
period: int
396+
"""Period index in the backward-induction loop."""
397+
age: float
398+
"""Age corresponding to `period` (pulled off `AgeGrid.values`)."""
399+
state_action_space: object
400+
"""Typed as `object` to avoid a heavy import cycle; consumers know
401+
the actual runtime type from the `max_Q_over_a` signature."""
402+
next_regime_to_V_arr: MappingProxyType[RegimeName, FloatND]
403+
"""Incoming next-period V-arrays, passed through unchanged to
404+
`compute_intermediates` when a NaN is detected."""
405+
regime_params: FlatRegimeParams
406+
"""Flat regime parameters used at this (regime, period)."""
407+
compute_intermediates: Callable | None
408+
"""Optional closure that recomputes U / F / E[V] / Q for NaN
409+
diagnostic enrichment. `None` when the regime has no
410+
compute-intermediates closure (e.g. terminal periods)."""
411+
412+
413+
@dataclass(frozen=True)
414+
class _StackedReductions:
415+
"""Per-stat JAX arrays stacked across all diagnostic rows; still on device.
416+
417+
`mins` / `maxs` / `means` are `None` when the solve ran with a log
418+
level below `debug` — the GPU wasn't asked to compute those
419+
statistics so there's nothing to stack.
420+
"""
421+
422+
mins: FloatND | None
423+
"""Per-row min of V, or `None` below debug log level."""
424+
maxs: FloatND | None
425+
"""Per-row max of V, or `None` below debug log level."""
426+
means: FloatND | None
427+
"""Per-row mean of V, or `None` below debug log level."""
428+
any_nan: FloatND
429+
"""Per-row boolean flag: any NaN in V at this (regime, period)."""
430+
any_inf: FloatND
431+
"""Per-row boolean flag: any Inf in V at this (regime, period)."""
432+
433+
434+
def _emit_deferred_diagnostics(
    *,
    logger: logging.Logger,
    diagnostic_rows: list[_DiagnosticRow],
    reductions: _StackedReductions,
    solution: MappingProxyType[int, MappingProxyType[RegimeName, FloatND]],
) -> None:
    """Pull the deferred diagnostics to host, then raise / warn / log.

    By default this performs two host transfers (the NaN and Inf flag
    stacks); three more follow (min / max / mean) when debug statistics
    were collected. The `.tolist()` calls are the blocking points.

    Order of operations:

    1. NaN check: raises before any statistics lines are emitted.
    2. Inf check: emits warnings only.
    3. Per-row statistics, at debug log level only.

    Args:
        logger: Logger receiving the Inf warnings and debug statistics.
        diagnostic_rows: Metadata rows, aligned one-to-one with the entries
            of every stacked reduction.
        reductions: Stacked per-row reductions.
        solution: Solved V-arrays by period and regime, forwarded to the
            NaN check.

    """
    # Both flag stacks are transferred before either check runs.
    nan_flags = reductions.any_nan.tolist()
    inf_flags = reductions.any_inf.tolist()

    _raise_if_nan(
        diagnostic_rows=diagnostic_rows,
        any_nan_per_row=nan_flags,
        solution=solution,
    )
    _warn_if_inf(
        logger=logger,
        diagnostic_rows=diagnostic_rows,
        any_inf_per_row=inf_flags,
    )

    if not logger.isEnabledFor(logging.DEBUG):
        return

    stacked_min = reductions.mins
    stacked_max = reductions.maxs
    stacked_mean = reductions.means
    # Spelled out one by one so type checkers can narrow away `None`.
    if stacked_min is None or stacked_max is None or stacked_mean is None:
        return

    per_row_stats = zip(
        diagnostic_rows,
        stacked_min.tolist(),
        stacked_max.tolist(),
        stacked_mean.tolist(),
        strict=True,
    )
    for row, lowest, highest, average in per_row_stats:
        logger.debug(
            " %s age %s V min=%.3g max=%.3g mean=%.3g",
            row.regime_name,
            row.age,
            lowest,
            highest,
            average,
        )
487+
488+
489+
def _raise_if_nan(
490+
*,
491+
diagnostic_rows: list[_DiagnosticRow],
492+
any_nan_per_row: list, # list[bool]
493+
solution: MappingProxyType[int, MappingProxyType[RegimeName, FloatND]],
494+
) -> None:
495+
"""Find the first NaN-bearing (regime, period) and raise."""
496+
for row, flag in zip(diagnostic_rows, any_nan_per_row, strict=True):
497+
if flag:
498+
_raise_at(row=row, solution=solution)
499+
500+
501+
def _raise_at(
    *,
    row: _DiagnosticRow,
    solution: MappingProxyType[int, MappingProxyType[RegimeName, FloatND]],
) -> None:
    """Run the enriched NaN diagnostic for one offending row.

    Looks up the row's V-array in `solution` and forwards it, together with
    everything stored on the row, to `validate_V`. The whole `solution` is
    passed as the partial solution.

    Args:
        row: Metadata of the offending (regime, period).
        solution: Solved V-arrays by period and regime.

    """
    validate_V(
        V_arr=solution[row.period][row.regime_name],
        regime_name=row.regime_name,
        period=row.period,
        age=row.age,
        partial_solution=solution,
        internal_params=row.regime_params,
        next_regime_to_V_arr=row.next_regime_to_V_arr,
        state_action_space=row.state_action_space,  # ty: ignore[invalid-argument-type]
        compute_intermediates=row.compute_intermediates,
    )
519+
520+
521+
def _warn_if_inf(
522+
*,
523+
logger: logging.Logger,
524+
diagnostic_rows: list[_DiagnosticRow],
525+
any_inf_per_row: list, # list[bool]
526+
) -> None:
527+
"""Emit a warning per (regime, period) with Inf values."""
528+
for row, flag in zip(diagnostic_rows, any_inf_per_row, strict=True):
529+
if flag:
530+
logger.warning(
531+
"Inf in V_arr for regime '%s' at age %s",
532+
row.regime_name,
533+
row.age,
534+
)

src/lcm/utils/logging.py

Lines changed: 0 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -76,32 +76,6 @@ def log_nan_in_V(
7676
logger.warning("NaN/Inf in V_arr for regime '%s' at age %s", regime_name, age)
7777

7878

79-
def log_V_stats(
80-
*,
81-
logger: logging.Logger,
82-
regime_name: str,
83-
V_arr: FloatND,
84-
) -> None:
85-
"""Log min/max/mean statistics of a value function array at debug level.
86-
87-
Args:
88-
logger: Logger instance.
89-
regime_name: Name of the regime.
90-
V_arr: Value function array.
91-
92-
"""
93-
if not logger.isEnabledFor(logging.DEBUG):
94-
return
95-
96-
logger.debug(
97-
" - %s: V min=%.3g max=%.3g mean=%.3g",
98-
regime_name,
99-
float(jnp.min(V_arr)),
100-
float(jnp.max(V_arr)),
101-
float(jnp.mean(V_arr)),
102-
)
103-
104-
10579
def log_period_header(
10680
*,
10781
logger: logging.Logger,

0 commit comments

Comments
 (0)