|
| 1 | +"""Run continual-shift benchmark across backprop, predictive coding, and circadian PC.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import argparse |
| 6 | +from pathlib import Path |
| 7 | +import sys |
| 8 | + |
| 9 | +REPO_ROOT = Path(__file__).resolve().parents[1] |
| 10 | +if str(REPO_ROOT) not in sys.path: |
| 11 | + sys.path.insert(0, str(REPO_ROOT)) |
| 12 | + |
| 13 | +from src.app.continual_shift_benchmark import ( |
| 14 | + ContinualShiftConfig, |
| 15 | + format_continual_shift_benchmark, |
| 16 | + run_continual_shift_benchmark, |
| 17 | +) |
| 18 | +from src.core.circadian_predictive_coding import CircadianConfig |
| 19 | + |
| 20 | + |
| 21 | +def build_parser() -> argparse.ArgumentParser: |
| 22 | + """Build CLI parser for continual shift benchmark.""" |
| 23 | + parser = argparse.ArgumentParser( |
| 24 | + description="Run phase-A/phase-B continual-shift benchmark for all three models." |
| 25 | + ) |
| 26 | + parser.add_argument("--seeds", type=str, default="3,7,11,19,23,31,37") |
| 27 | + parser.add_argument( |
| 28 | + "--profile", |
| 29 | + type=str, |
| 30 | + choices=["baseline", "strength-case"], |
| 31 | + default="strength-case", |
| 32 | + help="Circadian profile: baseline uses defaults, strength-case emphasizes replay/splits.", |
| 33 | + ) |
| 34 | + parser.add_argument("--sample-count-phase-a", type=int, default=500) |
| 35 | + parser.add_argument("--sample-count-phase-b", type=int, default=500) |
| 36 | + parser.add_argument("--phase-b-train-fraction", type=float, default=0.14) |
| 37 | + parser.add_argument("--phase-a-epochs", type=int, default=110) |
| 38 | + parser.add_argument("--phase-b-epochs", type=int, default=80) |
| 39 | + parser.add_argument("--hidden-dim", type=int, default=12) |
| 40 | + parser.add_argument("--phase-a-noise-scale", type=float, default=0.8) |
| 41 | + parser.add_argument("--phase-b-noise-scale", type=float, default=1.0) |
| 42 | + parser.add_argument("--phase-b-rotation-degrees", type=float, default=40.0) |
| 43 | + parser.add_argument("--phase-b-translation-x", type=float, default=0.9) |
| 44 | + parser.add_argument("--phase-b-translation-y", type=float, default=-0.7) |
| 45 | + parser.add_argument("--sleep-interval-phase-a", type=int, default=40) |
| 46 | + parser.add_argument("--sleep-interval-phase-b", type=int, default=8) |
| 47 | + parser.add_argument("--output-file", type=str, default="") |
| 48 | + return parser |
| 49 | + |
| 50 | + |
| 51 | +def main() -> None: |
| 52 | + """Run CLI entrypoint.""" |
| 53 | + parser = build_parser() |
| 54 | + args = parser.parse_args() |
| 55 | + |
| 56 | + seeds = _parse_int_list(args.seeds) |
| 57 | + circadian_config = ( |
| 58 | + _build_strength_case_circadian_config() |
| 59 | + if args.profile == "strength-case" |
| 60 | + else CircadianConfig() |
| 61 | + ) |
| 62 | + config = ContinualShiftConfig( |
| 63 | + sample_count_phase_a=args.sample_count_phase_a, |
| 64 | + sample_count_phase_b=args.sample_count_phase_b, |
| 65 | + phase_b_train_fraction=args.phase_b_train_fraction, |
| 66 | + phase_a_epochs=args.phase_a_epochs, |
| 67 | + phase_b_epochs=args.phase_b_epochs, |
| 68 | + hidden_dim=args.hidden_dim, |
| 69 | + phase_a_noise_scale=args.phase_a_noise_scale, |
| 70 | + phase_b_noise_scale=args.phase_b_noise_scale, |
| 71 | + phase_b_rotation_degrees=args.phase_b_rotation_degrees, |
| 72 | + phase_b_translation_x=args.phase_b_translation_x, |
| 73 | + phase_b_translation_y=args.phase_b_translation_y, |
| 74 | + circadian_sleep_interval_phase_a=args.sleep_interval_phase_a, |
| 75 | + circadian_sleep_interval_phase_b=args.sleep_interval_phase_b, |
| 76 | + circadian_config=circadian_config, |
| 77 | + ) |
| 78 | + |
| 79 | + result = run_continual_shift_benchmark(config=config, seeds=seeds) |
| 80 | + formatted = format_continual_shift_benchmark(result) |
| 81 | + print(formatted) |
| 82 | + if args.output_file: |
| 83 | + output_path = Path(args.output_file) |
| 84 | + output_path.write_text(formatted + "\n", encoding="utf-8") |
| 85 | + |
| 86 | + |
| 87 | +def _build_strength_case_circadian_config() -> CircadianConfig: |
| 88 | + """Build a practical circadian profile for retention/adaptation stress tests.""" |
| 89 | + return CircadianConfig( |
| 90 | + use_reward_modulated_learning=True, |
| 91 | + reward_scale_min=0.8, |
| 92 | + reward_scale_max=1.3, |
| 93 | + split_threshold=0.30, |
| 94 | + prune_threshold=0.04, |
| 95 | + max_split_per_sleep=1, |
| 96 | + max_prune_per_sleep=0, |
| 97 | + replay_steps=2, |
| 98 | + replay_memory_size=8, |
| 99 | + replay_learning_rate=0.03, |
| 100 | + replay_inference_steps=10, |
| 101 | + replay_inference_learning_rate=0.12, |
| 102 | + ) |
| 103 | + |
| 104 | + |
| 105 | +def _parse_int_list(raw_values: str) -> list[int]: |
| 106 | + items = [item.strip() for item in raw_values.split(",") if item.strip()] |
| 107 | + if not items: |
| 108 | + raise ValueError("Expected at least one integer seed.") |
| 109 | + return [int(item) for item in items] |
| 110 | + |
| 111 | + |
| 112 | +if __name__ == "__main__": |
| 113 | + main() |
0 commit comments