Skip to content

Commit c609ccc

Browse files
committed
Merge branch 'feat/canonical-float-dtype' into distributed
2 parents 218cd43 + 7547ac3 commit c609ccc

1 file changed

Lines changed: 12 additions & 50 deletions

File tree

src/lcm/model.py

Lines changed: 12 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import logging
55
import threading
66
from collections.abc import Mapping
7-
from concurrent.futures import Future, ThreadPoolExecutor
87
from pathlib import Path
98
from types import MappingProxyType
109

@@ -97,7 +96,8 @@ class Model:
9796
9897
- `None`: purely lazy behaviour, no AOT.
9998
- First `simulate(...)` with `actual_n == n_subjects`: AOT-compiles all
100-
simulate functions for that batch shape in parallel and caches them.
99+
simulate functions for that batch shape (blocks before solve runs)
100+
and caches them.
101101
- Subsequent `simulate(...)` with the same matching size: reuses the
102102
cached compiled programs.
103103
- `simulate(...)` with a mismatching size: warns once per size and falls
@@ -158,8 +158,8 @@ def __init__(
158158
already has a conflicting entry.
159159
n_subjects: Expected simulate batch size; if set, the first matching
160160
`simulate(...)` call AOT-compiles all simulate functions for
161-
batch shape `n_subjects` in parallel. `None` keeps the purely
162-
lazy behaviour.
161+
batch shape `n_subjects` before backward induction starts.
162+
`None` keeps the purely lazy behaviour.
163163
164164
"""
165165
self.description = description
@@ -333,39 +333,9 @@ def _solve_compiled(
333333
)
334334
return period_to_regime_to_V_arr
335335

336-
def _spawn_simulate_compile(
337-
self,
338-
*,
339-
n_subjects: int,
340-
internal_params: InternalParams,
341-
max_compilation_workers: int | None,
342-
logger: logging.Logger,
343-
) -> Future[MappingProxyType[RegimeName, InternalRegime]]:
344-
"""Submit `compile_all_simulate_functions` to a single-thread executor.
345-
346-
Caller decides whether to spawn (`n_subjects` set, batch shape
347-
matches, no cache hit). The returned `Future` runs in parallel with
348-
whatever the caller does next — typically `_solve_compiled(...)`.
349-
"""
350-
executor = ThreadPoolExecutor(
351-
max_workers=1, thread_name_prefix="lcm-simulate-compile"
352-
)
353-
future = executor.submit(
354-
compile_all_simulate_functions,
355-
internal_regimes=self.internal_regimes,
356-
internal_params=internal_params,
357-
ages=self.ages,
358-
n_subjects=n_subjects,
359-
max_compilation_workers=max_compilation_workers,
360-
logger=logger,
361-
)
362-
executor.shutdown(wait=False)
363-
return future
364-
365336
def _resolve_simulate_internal_regimes(
366337
self,
367338
*,
368-
compile_future: Future[MappingProxyType[RegimeName, InternalRegime]] | None,
369339
actual_n_subjects: int,
370340
log: logging.Logger,
371341
) -> MappingProxyType[RegimeName, InternalRegime]:
@@ -377,11 +347,8 @@ def _resolve_simulate_internal_regimes(
377347
(purely lazy path).
378348
- `actual_n_subjects != n_subjects`: warn once per mismatching size,
379349
return the original `internal_regimes`.
380-
- `actual_n_subjects == n_subjects`, `compile_future is not None`:
381-
await it and cache the result.
382-
- `actual_n_subjects == n_subjects`, `compile_future is None`: cache
383-
must already hold the entry (caller spawned only on cache miss);
384-
return the cached compiled regimes.
350+
- `actual_n_subjects == n_subjects`: return the cached compiled
351+
regimes (caller must have populated the cache before calling).
385352
"""
386353
if self.n_subjects is None:
387354
return self.internal_regimes
@@ -398,11 +365,6 @@ def _resolve_simulate_internal_regimes(
398365
self.n_subjects,
399366
)
400367
return self.internal_regimes
401-
if compile_future is not None:
402-
compiled = compile_future.result()
403-
with self._simulate_compile_lock:
404-
self._simulate_compile_cache[self.n_subjects] = compiled
405-
return compiled
406368
with self._simulate_compile_lock:
407369
return self._simulate_compile_cache[self.n_subjects]
408370

@@ -488,19 +450,20 @@ def simulate(
488450
log = get_logger(log_level=log_level)
489451
actual_n_subjects = len(next(iter(initial_conditions.values())))
490452
n_subjects = self.n_subjects
491-
compile_future: Future[MappingProxyType[RegimeName, InternalRegime]] | None = (
492-
None
493-
)
494453
if n_subjects is not None and n_subjects == actual_n_subjects:
495454
with self._simulate_compile_lock:
496455
needs_compile = n_subjects not in self._simulate_compile_cache
497456
if needs_compile:
498-
compile_future = self._spawn_simulate_compile(
499-
n_subjects=n_subjects,
457+
compiled = compile_all_simulate_functions(
458+
internal_regimes=self.internal_regimes,
500459
internal_params=internal_params,
460+
ages=self.ages,
461+
n_subjects=n_subjects,
501462
max_compilation_workers=max_compilation_workers,
502463
logger=log,
503464
)
465+
with self._simulate_compile_lock:
466+
self._simulate_compile_cache[n_subjects] = compiled
504467
if period_to_regime_to_V_arr is None:
505468
period_to_regime_to_V_arr = self._solve_compiled(
506469
internal_params=internal_params,
@@ -512,7 +475,6 @@ def simulate(
512475
max_compilation_workers=max_compilation_workers,
513476
)
514477
simulate_internal_regimes = self._resolve_simulate_internal_regimes(
515-
compile_future=compile_future,
516478
actual_n_subjects=actual_n_subjects,
517479
log=log,
518480
)

0 commit comments

Comments
 (0)