Skip to content

Commit 54c0247

Browse files
committed
Add hardest continual-shift benchmark profile and results
1 parent ed569aa commit 54c0247

5 files changed

Lines changed: 143 additions & 31 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ for versioning even while in research-stage development.
5151
- `scripts/run_continual_shift_benchmark.py`
5252
- `tests/test_continual_shift_benchmark.py`
5353
- shifted/rotated dataset support in `src/infra/datasets.py`
54+
- new `hardest-case` profile in continual-shift CLI for a stronger stress scenario
5455

5556
### Changed
5657

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,12 @@ Continual shift stress test (retention vs adaptation):
170170
python scripts/run_continual_shift_benchmark.py --profile strength-case --seeds 3,7,11,19,23,31,37
171171
```
172172

173+
Hardest continual-shift stress test (small starting capacity + heavy drift):
174+
175+
```powershell
176+
python scripts/run_continual_shift_benchmark.py --profile hardest-case --seeds 3,7,11,19,23,31,37
177+
```
178+
173179
ResNet benchmark (all 3 models):
174180

175181
```powershell
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
Continual Shift Benchmark
2+
-------------------------
3+
Phase A trains on base distribution; phase B trains on shifted/rotated distribution.
4+
Seeds: [3, 7, 11, 19, 23, 31, 37]
5+
Setup: hidden_dim=8, phaseA_epochs=90, phaseB_epochs=120, phaseA_noise=0.80, phaseB_noise=1.20
6+
Phase B transform: rotation=44.0 deg, translation=(0.90, -0.70)
7+
Phase B train fraction: 0.08
8+
9+
Backprop: A_pre=0.975+/-0.014, A_post=0.888+/-0.081, B_post=0.889+/-0.031, retention=0.911+/-0.083, balanced=0.889+/-0.052
10+
Predictive coding: A_pre=0.979+/-0.013, A_post=0.958+/-0.018, B_post=0.874+/-0.028, retention=0.978+/-0.021, balanced=0.916+/-0.015
11+
Circadian predictive coding: A_pre=0.973+/-0.015, A_post=0.967+/-0.017, B_post=0.878+/-0.031, retention=0.994+/-0.021, balanced=0.922+/-0.014, sleep_events=6.29, splits=6.29, prunes=0.00, hidden_end=14.29

scripts/run_continual_shift_benchmark.py

Lines changed: 119 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,26 @@ def build_parser() -> argparse.ArgumentParser:
2727
parser.add_argument(
2828
"--profile",
2929
type=str,
30-
choices=["baseline", "strength-case"],
30+
choices=["baseline", "strength-case", "hardest-case"],
3131
default="strength-case",
32-
help="Circadian profile: baseline uses defaults, strength-case emphasizes replay/splits.",
32+
help=(
33+
"baseline: circadian defaults, strength-case: tuned moderate stress, "
34+
"hardest-case: aggressively difficult shift with tuned circadian policy."
35+
),
3336
)
34-
parser.add_argument("--sample-count-phase-a", type=int, default=500)
35-
parser.add_argument("--sample-count-phase-b", type=int, default=500)
36-
parser.add_argument("--phase-b-train-fraction", type=float, default=0.14)
37-
parser.add_argument("--phase-a-epochs", type=int, default=110)
38-
parser.add_argument("--phase-b-epochs", type=int, default=80)
39-
parser.add_argument("--hidden-dim", type=int, default=12)
40-
parser.add_argument("--phase-a-noise-scale", type=float, default=0.8)
41-
parser.add_argument("--phase-b-noise-scale", type=float, default=1.0)
42-
parser.add_argument("--phase-b-rotation-degrees", type=float, default=40.0)
43-
parser.add_argument("--phase-b-translation-x", type=float, default=0.9)
44-
parser.add_argument("--phase-b-translation-y", type=float, default=-0.7)
45-
parser.add_argument("--sleep-interval-phase-a", type=int, default=40)
46-
parser.add_argument("--sleep-interval-phase-b", type=int, default=8)
37+
parser.add_argument("--sample-count-phase-a", type=int, default=None)
38+
parser.add_argument("--sample-count-phase-b", type=int, default=None)
39+
parser.add_argument("--phase-b-train-fraction", type=float, default=None)
40+
parser.add_argument("--phase-a-epochs", type=int, default=None)
41+
parser.add_argument("--phase-b-epochs", type=int, default=None)
42+
parser.add_argument("--hidden-dim", type=int, default=None)
43+
parser.add_argument("--phase-a-noise-scale", type=float, default=None)
44+
parser.add_argument("--phase-b-noise-scale", type=float, default=None)
45+
parser.add_argument("--phase-b-rotation-degrees", type=float, default=None)
46+
parser.add_argument("--phase-b-translation-x", type=float, default=None)
47+
parser.add_argument("--phase-b-translation-y", type=float, default=None)
48+
parser.add_argument("--sleep-interval-phase-a", type=int, default=None)
49+
parser.add_argument("--sleep-interval-phase-b", type=int, default=None)
4750
parser.add_argument("--output-file", type=str, default="")
4851
return parser
4952

@@ -54,25 +57,50 @@ def main() -> None:
5457
args = parser.parse_args()
5558

5659
seeds = _parse_int_list(args.seeds)
60+
profile_defaults = _build_profile_defaults(args.profile)
5761
circadian_config = (
58-
_build_strength_case_circadian_config()
59-
if args.profile == "strength-case"
60-
else CircadianConfig()
62+
_build_baseline_circadian_config()
63+
if args.profile == "baseline"
64+
else (
65+
_build_strength_case_circadian_config()
66+
if args.profile == "strength-case"
67+
else _build_hardest_case_circadian_config()
68+
)
6169
)
6270
config = ContinualShiftConfig(
63-
sample_count_phase_a=args.sample_count_phase_a,
64-
sample_count_phase_b=args.sample_count_phase_b,
65-
phase_b_train_fraction=args.phase_b_train_fraction,
66-
phase_a_epochs=args.phase_a_epochs,
67-
phase_b_epochs=args.phase_b_epochs,
68-
hidden_dim=args.hidden_dim,
69-
phase_a_noise_scale=args.phase_a_noise_scale,
70-
phase_b_noise_scale=args.phase_b_noise_scale,
71-
phase_b_rotation_degrees=args.phase_b_rotation_degrees,
72-
phase_b_translation_x=args.phase_b_translation_x,
73-
phase_b_translation_y=args.phase_b_translation_y,
74-
circadian_sleep_interval_phase_a=args.sleep_interval_phase_a,
75-
circadian_sleep_interval_phase_b=args.sleep_interval_phase_b,
71+
sample_count_phase_a=_resolve_optional(
72+
args.sample_count_phase_a, profile_defaults["sample_count_phase_a"]
73+
),
74+
sample_count_phase_b=_resolve_optional(
75+
args.sample_count_phase_b, profile_defaults["sample_count_phase_b"]
76+
),
77+
phase_b_train_fraction=_resolve_optional(
78+
args.phase_b_train_fraction, profile_defaults["phase_b_train_fraction"]
79+
),
80+
phase_a_epochs=_resolve_optional(args.phase_a_epochs, profile_defaults["phase_a_epochs"]),
81+
phase_b_epochs=_resolve_optional(args.phase_b_epochs, profile_defaults["phase_b_epochs"]),
82+
hidden_dim=_resolve_optional(args.hidden_dim, profile_defaults["hidden_dim"]),
83+
phase_a_noise_scale=_resolve_optional(
84+
args.phase_a_noise_scale, profile_defaults["phase_a_noise_scale"]
85+
),
86+
phase_b_noise_scale=_resolve_optional(
87+
args.phase_b_noise_scale, profile_defaults["phase_b_noise_scale"]
88+
),
89+
phase_b_rotation_degrees=_resolve_optional(
90+
args.phase_b_rotation_degrees, profile_defaults["phase_b_rotation_degrees"]
91+
),
92+
phase_b_translation_x=_resolve_optional(
93+
args.phase_b_translation_x, profile_defaults["phase_b_translation_x"]
94+
),
95+
phase_b_translation_y=_resolve_optional(
96+
args.phase_b_translation_y, profile_defaults["phase_b_translation_y"]
97+
),
98+
circadian_sleep_interval_phase_a=_resolve_optional(
99+
args.sleep_interval_phase_a, profile_defaults["sleep_interval_phase_a"]
100+
),
101+
circadian_sleep_interval_phase_b=_resolve_optional(
102+
args.sleep_interval_phase_b, profile_defaults["sleep_interval_phase_b"]
103+
),
76104
circadian_config=circadian_config,
77105
)
78106

@@ -102,6 +130,66 @@ def _build_strength_case_circadian_config() -> CircadianConfig:
102130
)
103131

104132

133+
def _build_hardest_case_circadian_config() -> CircadianConfig:
134+
"""Build circadian profile tuned for the hardest continual-shift setup."""
135+
return CircadianConfig(
136+
use_reward_modulated_learning=False,
137+
split_threshold=0.25,
138+
prune_threshold=0.04,
139+
max_split_per_sleep=1,
140+
max_prune_per_sleep=0,
141+
replay_steps=2,
142+
replay_memory_size=10,
143+
replay_learning_rate=0.04,
144+
replay_inference_steps=12,
145+
replay_inference_learning_rate=0.14,
146+
)
147+
148+
149+
def _build_baseline_circadian_config() -> CircadianConfig:
150+
return CircadianConfig()
151+
152+
153+
def _build_profile_defaults(profile: str) -> dict[str, float | int]:
154+
if profile == "hardest-case":
155+
return {
156+
"sample_count_phase_a": 500,
157+
"sample_count_phase_b": 500,
158+
"phase_b_train_fraction": 0.08,
159+
"phase_a_epochs": 90,
160+
"phase_b_epochs": 120,
161+
"hidden_dim": 8,
162+
"phase_a_noise_scale": 0.8,
163+
"phase_b_noise_scale": 1.2,
164+
"phase_b_rotation_degrees": 44.0,
165+
"phase_b_translation_x": 0.9,
166+
"phase_b_translation_y": -0.7,
167+
"sleep_interval_phase_a": 40,
168+
"sleep_interval_phase_b": 8,
169+
}
170+
return {
171+
"sample_count_phase_a": 500,
172+
"sample_count_phase_b": 500,
173+
"phase_b_train_fraction": 0.14,
174+
"phase_a_epochs": 110,
175+
"phase_b_epochs": 80,
176+
"hidden_dim": 12,
177+
"phase_a_noise_scale": 0.8,
178+
"phase_b_noise_scale": 1.0,
179+
"phase_b_rotation_degrees": 40.0,
180+
"phase_b_translation_x": 0.9,
181+
"phase_b_translation_y": -0.7,
182+
"sleep_interval_phase_a": 40,
183+
"sleep_interval_phase_b": 8,
184+
}
185+
186+
187+
def _resolve_optional(value: int | float | None, fallback: int | float) -> int | float:
188+
if value is None:
189+
return fallback
190+
return value
191+
192+
105193
def _parse_int_list(raw_values: str) -> list[int]:
106194
items = [item.strip() for item in raw_values.split(",") if item.strip()]
107195
if not items:

src/app/continual_shift_benchmark.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,12 @@ def format_continual_shift_benchmark(result: ContinualShiftBenchmarkResult) -> s
177177
"-------------------------",
178178
"Phase A trains on base distribution; phase B trains on shifted/rotated distribution.",
179179
f"Seeds: {result.seeds}",
180+
(
181+
"Setup: "
182+
f"hidden_dim={config.hidden_dim}, "
183+
f"phaseA_epochs={config.phase_a_epochs}, phaseB_epochs={config.phase_b_epochs}, "
184+
f"phaseA_noise={config.phase_a_noise_scale:.2f}, phaseB_noise={config.phase_b_noise_scale:.2f}"
185+
),
180186
(
181187
"Phase B transform: "
182188
f"rotation={config.phase_b_rotation_degrees:.1f} deg, "

0 commit comments

Comments
 (0)