Skip to content

Commit ed569aa

Browse files
committed
Add continual-shift benchmark for retention vs adaptation
1 parent 77c0d07 commit ed569aa

7 files changed

Lines changed: 809 additions & 3 deletions

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ for versioning even while in research-stage development.
4646
- `.github/CODEOWNERS`
4747
- `docs/model-card.md`
4848
- `docs/figures/README.md`
49+
- Continual-shift comparison benchmark for retention vs adaptation:
50+
- `src/app/continual_shift_benchmark.py`
51+
- `scripts/run_continual_shift_benchmark.py`
52+
- `tests/test_continual_shift_benchmark.py`
53+
- shifted/rotated dataset support in `src/infra/datasets.py`
4954

5055
### Changed
5156

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,12 @@ Toy baseline with review-driven circadian controls:
164164
python predictive_coding_experiment.py --adaptive-sleep-trigger --adaptive-sleep-budget --reward-modulated-learning --reward-scale-min 0.8 --reward-scale-max 1.4
165165
```
166166

167+
Continual shift stress test (retention vs adaptation):
168+
169+
```powershell
170+
python scripts/run_continual_shift_benchmark.py --profile strength-case --seeds 3,7,11,19,23,31,37
171+
```
172+
167173
ResNet benchmark (all 3 models):
168174

169175
```powershell
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
Continual Shift Benchmark
2+
-------------------------
3+
Phase A trains on base distribution; phase B trains on shifted/rotated distribution.
4+
Seeds: [3, 7, 11, 19, 23, 31, 37]
5+
Phase B transform: rotation=40.0 deg, translation=(0.90, -0.70)
6+
Phase B train fraction: 0.14
7+
8+
Backprop: A_pre=0.975+/-0.012, A_post=0.960+/-0.022, B_post=0.933+/-0.026, retention=0.985+/-0.026, balanced=0.946+/-0.016
9+
Predictive coding: A_pre=0.970+/-0.016, A_post=0.967+/-0.016, B_post=0.927+/-0.024, retention=0.997+/-0.020, balanced=0.947+/-0.008
10+
Circadian predictive coding: A_pre=0.975+/-0.015, A_post=0.968+/-0.009, B_post=0.930+/-0.023, retention=0.993+/-0.008, balanced=0.949+/-0.010, sleep_events=5.00, splits=5.00, prunes=0.00, hidden_end=17.00
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
"""Run continual-shift benchmark across backprop, predictive coding, and circadian PC."""
2+
3+
from __future__ import annotations
4+
5+
import argparse
6+
from pathlib import Path
7+
import sys
8+
9+
REPO_ROOT = Path(__file__).resolve().parents[1]
10+
if str(REPO_ROOT) not in sys.path:
11+
sys.path.insert(0, str(REPO_ROOT))
12+
13+
from src.app.continual_shift_benchmark import (
14+
ContinualShiftConfig,
15+
format_continual_shift_benchmark,
16+
run_continual_shift_benchmark,
17+
)
18+
from src.core.circadian_predictive_coding import CircadianConfig
19+
20+
21+
def build_parser() -> argparse.ArgumentParser:
22+
"""Build CLI parser for continual shift benchmark."""
23+
parser = argparse.ArgumentParser(
24+
description="Run phase-A/phase-B continual-shift benchmark for all three models."
25+
)
26+
parser.add_argument("--seeds", type=str, default="3,7,11,19,23,31,37")
27+
parser.add_argument(
28+
"--profile",
29+
type=str,
30+
choices=["baseline", "strength-case"],
31+
default="strength-case",
32+
help="Circadian profile: baseline uses defaults, strength-case emphasizes replay/splits.",
33+
)
34+
parser.add_argument("--sample-count-phase-a", type=int, default=500)
35+
parser.add_argument("--sample-count-phase-b", type=int, default=500)
36+
parser.add_argument("--phase-b-train-fraction", type=float, default=0.14)
37+
parser.add_argument("--phase-a-epochs", type=int, default=110)
38+
parser.add_argument("--phase-b-epochs", type=int, default=80)
39+
parser.add_argument("--hidden-dim", type=int, default=12)
40+
parser.add_argument("--phase-a-noise-scale", type=float, default=0.8)
41+
parser.add_argument("--phase-b-noise-scale", type=float, default=1.0)
42+
parser.add_argument("--phase-b-rotation-degrees", type=float, default=40.0)
43+
parser.add_argument("--phase-b-translation-x", type=float, default=0.9)
44+
parser.add_argument("--phase-b-translation-y", type=float, default=-0.7)
45+
parser.add_argument("--sleep-interval-phase-a", type=int, default=40)
46+
parser.add_argument("--sleep-interval-phase-b", type=int, default=8)
47+
parser.add_argument("--output-file", type=str, default="")
48+
return parser
49+
50+
51+
def main() -> None:
52+
"""Run CLI entrypoint."""
53+
parser = build_parser()
54+
args = parser.parse_args()
55+
56+
seeds = _parse_int_list(args.seeds)
57+
circadian_config = (
58+
_build_strength_case_circadian_config()
59+
if args.profile == "strength-case"
60+
else CircadianConfig()
61+
)
62+
config = ContinualShiftConfig(
63+
sample_count_phase_a=args.sample_count_phase_a,
64+
sample_count_phase_b=args.sample_count_phase_b,
65+
phase_b_train_fraction=args.phase_b_train_fraction,
66+
phase_a_epochs=args.phase_a_epochs,
67+
phase_b_epochs=args.phase_b_epochs,
68+
hidden_dim=args.hidden_dim,
69+
phase_a_noise_scale=args.phase_a_noise_scale,
70+
phase_b_noise_scale=args.phase_b_noise_scale,
71+
phase_b_rotation_degrees=args.phase_b_rotation_degrees,
72+
phase_b_translation_x=args.phase_b_translation_x,
73+
phase_b_translation_y=args.phase_b_translation_y,
74+
circadian_sleep_interval_phase_a=args.sleep_interval_phase_a,
75+
circadian_sleep_interval_phase_b=args.sleep_interval_phase_b,
76+
circadian_config=circadian_config,
77+
)
78+
79+
result = run_continual_shift_benchmark(config=config, seeds=seeds)
80+
formatted = format_continual_shift_benchmark(result)
81+
print(formatted)
82+
if args.output_file:
83+
output_path = Path(args.output_file)
84+
output_path.write_text(formatted + "\n", encoding="utf-8")
85+
86+
87+
def _build_strength_case_circadian_config() -> CircadianConfig:
88+
"""Build a practical circadian profile for retention/adaptation stress tests."""
89+
return CircadianConfig(
90+
use_reward_modulated_learning=True,
91+
reward_scale_min=0.8,
92+
reward_scale_max=1.3,
93+
split_threshold=0.30,
94+
prune_threshold=0.04,
95+
max_split_per_sleep=1,
96+
max_prune_per_sleep=0,
97+
replay_steps=2,
98+
replay_memory_size=8,
99+
replay_learning_rate=0.03,
100+
replay_inference_steps=10,
101+
replay_inference_learning_rate=0.12,
102+
)
103+
104+
105+
def _parse_int_list(raw_values: str) -> list[int]:
106+
items = [item.strip() for item in raw_values.split(",") if item.strip()]
107+
if not items:
108+
raise ValueError("Expected at least one integer seed.")
109+
return [int(item) for item in items]
110+
111+
112+
if __name__ == "__main__":
113+
main()

0 commit comments

Comments
 (0)