44import logging
55import threading
66from collections .abc import Mapping
7- from concurrent .futures import Future , ThreadPoolExecutor
87from pathlib import Path
98from types import MappingProxyType
109
@@ -97,7 +96,8 @@ class Model:
9796
9897 - `None`: purely lazy behaviour, no AOT.
9998 - First `simulate(...)` with `actual_n == n_subjects`: AOT-compiles all
100- simulate functions for that batch shape in parallel and caches them.
99+ simulate functions for that batch shape (blocks before solve runs)
100+ and caches them.
101101 - Subsequent `simulate(...)` with the same matching size: reuses the
102102 cached compiled programs.
103103 - `simulate(...)` with a mismatching size: warns once per size and falls
@@ -158,8 +158,8 @@ def __init__(
158158 already has a conflicting entry.
159159 n_subjects: Expected simulate batch size; if set, the first matching
160160 `simulate(...)` call AOT-compiles all simulate functions for
161- batch shape `n_subjects` in parallel. `None` keeps the purely
162- lazy behaviour.
161+ batch shape `n_subjects` before backward induction starts.
162+ `None` keeps the purely lazy behaviour.
163163
164164 """
165165 self .description = description
@@ -333,39 +333,9 @@ def _solve_compiled(
333333 )
334334 return period_to_regime_to_V_arr
335335
336- def _spawn_simulate_compile (
337- self ,
338- * ,
339- n_subjects : int ,
340- internal_params : InternalParams ,
341- max_compilation_workers : int | None ,
342- logger : logging .Logger ,
343- ) -> Future [MappingProxyType [RegimeName , InternalRegime ]]:
344- """Submit `compile_all_simulate_functions` to a single-thread executor.
345-
346- Caller decides whether to spawn (`n_subjects` set, batch shape
347- matches, no cache hit). The returned `Future` runs in parallel with
348- whatever the caller does next — typically `_solve_compiled(...)`.
349- """
350- executor = ThreadPoolExecutor (
351- max_workers = 1 , thread_name_prefix = "lcm-simulate-compile"
352- )
353- future = executor .submit (
354- compile_all_simulate_functions ,
355- internal_regimes = self .internal_regimes ,
356- internal_params = internal_params ,
357- ages = self .ages ,
358- n_subjects = n_subjects ,
359- max_compilation_workers = max_compilation_workers ,
360- logger = logger ,
361- )
362- executor .shutdown (wait = False )
363- return future
364-
365336 def _resolve_simulate_internal_regimes (
366337 self ,
367338 * ,
368- compile_future : Future [MappingProxyType [RegimeName , InternalRegime ]] | None ,
369339 actual_n_subjects : int ,
370340 log : logging .Logger ,
371341 ) -> MappingProxyType [RegimeName , InternalRegime ]:
@@ -377,11 +347,8 @@ def _resolve_simulate_internal_regimes(
377347 (purely lazy path).
378348 - `actual_n_subjects != n_subjects`: warn once per mismatching size,
379349 return the original `internal_regimes`.
380- - `actual_n_subjects == n_subjects`, `compile_future is not None`:
381- await it and cache the result.
382- - `actual_n_subjects == n_subjects`, `compile_future is None`: cache
383- must already hold the entry (caller spawned only on cache miss);
384- return the cached compiled regimes.
350+ - `actual_n_subjects == n_subjects`: return the cached compiled
351+ regimes (caller must have populated the cache before calling).
385352 """
386353 if self .n_subjects is None :
387354 return self .internal_regimes
@@ -398,11 +365,6 @@ def _resolve_simulate_internal_regimes(
398365 self .n_subjects ,
399366 )
400367 return self .internal_regimes
401- if compile_future is not None :
402- compiled = compile_future .result ()
403- with self ._simulate_compile_lock :
404- self ._simulate_compile_cache [self .n_subjects ] = compiled
405- return compiled
406368 with self ._simulate_compile_lock :
407369 return self ._simulate_compile_cache [self .n_subjects ]
408370
@@ -488,19 +450,20 @@ def simulate(
488450 log = get_logger (log_level = log_level )
489451 actual_n_subjects = len (next (iter (initial_conditions .values ())))
490452 n_subjects = self .n_subjects
491- compile_future : Future [MappingProxyType [RegimeName , InternalRegime ]] | None = (
492- None
493- )
494453 if n_subjects is not None and n_subjects == actual_n_subjects :
495454 with self ._simulate_compile_lock :
496455 needs_compile = n_subjects not in self ._simulate_compile_cache
497456 if needs_compile :
498- compile_future = self . _spawn_simulate_compile (
499- n_subjects = n_subjects ,
457+ compiled = compile_all_simulate_functions (
458+ internal_regimes = self . internal_regimes ,
500459 internal_params = internal_params ,
460+ ages = self .ages ,
461+ n_subjects = n_subjects ,
501462 max_compilation_workers = max_compilation_workers ,
502463 logger = log ,
503464 )
465+ with self ._simulate_compile_lock :
466+ self ._simulate_compile_cache [n_subjects ] = compiled
504467 if period_to_regime_to_V_arr is None :
505468 period_to_regime_to_V_arr = self ._solve_compiled (
506469 internal_params = internal_params ,
@@ -512,7 +475,6 @@ def simulate(
512475 max_compilation_workers = max_compilation_workers ,
513476 )
514477 simulate_internal_regimes = self ._resolve_simulate_internal_regimes (
515- compile_future = compile_future ,
516478 actual_n_subjects = actual_n_subjects ,
517479 log = log ,
518480 )
0 commit comments