|
| 1 | +"""End-to-end benchmark for the aca baseline model (benchmark-sized grids). |
| 2 | +
|
| 3 | +Uses `aca_model.benchmark.create_benchmark_model()` — the full 18-regime |
| 4 | +aca baseline with tiny continuous grids (`BENCHMARK_GRID_CONFIG`) and a |
| 5 | +2-type `BenchmarkPrefType` (half the compile + execution volume of the |
| 6 | +production 3-type `PrefType`). The kernel exercised here keeps the |
| 7 | +expensive parts of aca-baseline's cost structure (compile pipeline |
| 8 | +over 19 regimes, DAG resolution, pref_type batching) while shrinking |
| 9 | +per-call numerical work so the benchmark fits in an asv invocation. |
| 10 | +
|
| 11 | +Requires the `aca_model` package to be importable. Use the |
| 12 | +`benchmarks-cuda12` pixi environment, which pulls aca-model from its |
| 13 | +public git URL. Inside the aca-dev monorepo the editable path install |
| 14 | +takes precedence. Benchmark params are loaded from a frozen pickle |
| 15 | +shipped in aca-model — no aca-data pipeline run required. |
| 16 | +
|
| 17 | +ASV wiring notes: |
| 18 | +
|
| 19 | +- We use `setup_cache` with a cloudpickle-bytes wrapper. ASV's cache |
| 20 | + machinery serialises `setup_cache`'s return value through stdlib |
| 21 | + `pickle`, which can't handle the `MappingProxyType` leaves or |
| 22 | + user-defined callables inside a pylcm `Model`. `setup_cache` returns |
| 23 | + `cloudpickle.dumps(...)` of the `(model, params, initial_conditions)` |
| 24 | + triple; each method's `setup(cache)` calls `cloudpickle.loads` and |
| 25 | + runs a warm simulate. This amortises Python-level model construction |
| 26 | + across `time_execution`, `peakmem_execution`, and |
| 27 | + `track_compilation_time` (~60-120 s saved per ASV run). JAX |
| 28 | + compilation is still per-method — the JIT cache is process-local — |
| 29 | + but the persistent XLA disk cache keeps second and third compiles |
| 30 | + fast. |
| 31 | +- `AcaBaselineGpuPeakMem` runs in a separate subprocess via `_gpu_mem` |
| 32 | + that does not go through ASV's `setup_cache` pipeline. It calls |
| 33 | + `setup_for_gpu_measurement()` (rebuild fresh, no warm-up) then |
| 34 | + `time_execution()` to measure cold peak memory. Both methods |
| 35 | + accept `cache=None` so the same callable serves ASV (cache passed |
| 36 | + in) and the subprocess (cache omitted). |
| 37 | +""" |
| 38 | + |
| 39 | +import gc |
| 40 | +import time |
| 41 | + |
| 42 | +import cloudpickle |
| 43 | + |
| 44 | +from benchmarks import _gpu_mem |
| 45 | + |
| 46 | +_N_SUBJECTS = 1000 |
| 47 | + |
| 48 | + |
| 49 | +def _build() -> tuple[object, object, object]: |
| 50 | + """Build the aca-baseline model, params, and initial conditions.""" |
| 51 | + from aca_model.benchmark import ( |
| 52 | + create_benchmark_model, |
| 53 | + get_benchmark_initial_conditions, |
| 54 | + get_benchmark_params, |
| 55 | + ) |
| 56 | + |
| 57 | + model = create_benchmark_model() |
| 58 | + _, model_params = get_benchmark_params() |
| 59 | + initial_conditions = get_benchmark_initial_conditions( |
| 60 | + model=model, n_subjects=_N_SUBJECTS, seed=0 |
| 61 | + ) |
| 62 | + return model, model_params, initial_conditions |
| 63 | + |
| 64 | + |
| 65 | +class AcaBaseline: |
| 66 | + timeout = 3600 |
| 67 | + # Pin every ASV sample knob to 1 so setup runs once per subprocess |
| 68 | + # and one warm call is timed. `timeout=3600` gives headroom for the |
| 69 | + # cold compile that happens inside setup(cache). |
| 70 | + rounds = 1 |
| 71 | + repeat = 1 |
| 72 | + number = 1 |
| 73 | + warmup_time = 0 |
| 74 | + |
| 75 | + def setup_cache(self) -> bytes: |
| 76 | + # Build once per ASV benchmark class run and hand the result to |
| 77 | + # every method via ASV's setup_cache mechanism. ASV pickles the |
| 78 | + # return value with stdlib `pickle`, which can't handle the |
| 79 | + # `MappingProxyType` leaves or user callables inside a pylcm |
| 80 | + # `Model` — so wrap the triple in cloudpickle bytes. ASV then |
| 81 | + # ships plain bytes; each method's setup(cache) reconstructs. |
| 82 | + return cloudpickle.dumps(_build()) |
| 83 | + |
| 84 | + def setup(self, cache: bytes) -> None: |
| 85 | + self.model, self.model_params, self.initial_conditions = cloudpickle.loads( |
| 86 | + cache |
| 87 | + ) |
| 88 | + # Warm-trigger compilation so time_execution runs on a hot kernel. |
| 89 | + start = time.perf_counter() |
| 90 | + self.model.simulate( |
| 91 | + params=self.model_params, |
| 92 | + initial_conditions=self.initial_conditions, |
| 93 | + period_to_regime_to_V_arr=None, |
| 94 | + log_level="off", |
| 95 | + check_initial_conditions=False, |
| 96 | + ) |
| 97 | + self._compile_time = time.perf_counter() - start |
| 98 | + |
| 99 | + def setup_for_gpu_measurement(self) -> None: |
| 100 | + # Called by the _gpu_mem subprocess; bypasses ASV's setup_cache |
| 101 | + # pipeline so the subprocess can measure cold peak memory |
| 102 | + # (build + compile + run, no warm-up). |
| 103 | + self.model, self.model_params, self.initial_conditions = _build() |
| 104 | + |
| 105 | + def time_execution(self, cache: bytes | None = None) -> None: |
| 106 | + self.model.simulate( |
| 107 | + params=self.model_params, |
| 108 | + initial_conditions=self.initial_conditions, |
| 109 | + period_to_regime_to_V_arr=None, |
| 110 | + log_level="off", |
| 111 | + check_initial_conditions=False, |
| 112 | + ) |
| 113 | + |
| 114 | + def peakmem_execution(self, cache: bytes | None = None) -> None: |
| 115 | + self.model.simulate( |
| 116 | + params=self.model_params, |
| 117 | + initial_conditions=self.initial_conditions, |
| 118 | + period_to_regime_to_V_arr=None, |
| 119 | + log_level="off", |
| 120 | + check_initial_conditions=False, |
| 121 | + ) |
| 122 | + |
| 123 | + def teardown(self, cache: bytes | None = None) -> None: |
| 124 | + import jax |
| 125 | + |
| 126 | + jax.clear_caches() |
| 127 | + gc.collect() |
| 128 | + |
| 129 | + def track_compilation_time(self, cache: bytes | None = None) -> float: |
| 130 | + return self._compile_time |
| 131 | + |
| 132 | + track_compilation_time.unit = "seconds" |
| 133 | + |
| 134 | + |
| 135 | +class AcaBaselineGpuPeakMem(_gpu_mem.GpuPeakMem): |
| 136 | + bench_module = "benchmarks.bench_aca_baseline" |
| 137 | + bench_class = "AcaBaseline" |
| 138 | + timeout = 3600 |
0 commit comments