@@ -27,23 +27,26 @@ def build_parser() -> argparse.ArgumentParser:
2727 parser .add_argument (
2828 "--profile" ,
2929 type = str ,
30- choices = ["baseline" , "strength-case" ],
30+ choices = ["baseline" , "strength-case" , "hardest-case" ],
3131 default = "strength-case" ,
32- help = "Circadian profile: baseline uses defaults, strength-case emphasizes replay/splits." ,
32+ help = (
33+ "baseline: circadian defaults, strength-case: tuned moderate stress, "
34+ "hardest-case: aggressively difficult shift with tuned circadian policy."
35+ ),
3336 )
34- parser .add_argument ("--sample-count-phase-a" , type = int , default = 500 )
35- parser .add_argument ("--sample-count-phase-b" , type = int , default = 500 )
36- parser .add_argument ("--phase-b-train-fraction" , type = float , default = 0.14 )
37- parser .add_argument ("--phase-a-epochs" , type = int , default = 110 )
38- parser .add_argument ("--phase-b-epochs" , type = int , default = 80 )
39- parser .add_argument ("--hidden-dim" , type = int , default = 12 )
40- parser .add_argument ("--phase-a-noise-scale" , type = float , default = 0.8 )
41- parser .add_argument ("--phase-b-noise-scale" , type = float , default = 1.0 )
42- parser .add_argument ("--phase-b-rotation-degrees" , type = float , default = 40.0 )
43- parser .add_argument ("--phase-b-translation-x" , type = float , default = 0.9 )
44- parser .add_argument ("--phase-b-translation-y" , type = float , default = - 0.7 )
45- parser .add_argument ("--sleep-interval-phase-a" , type = int , default = 40 )
46- parser .add_argument ("--sleep-interval-phase-b" , type = int , default = 8 )
37+ parser .add_argument ("--sample-count-phase-a" , type = int , default = None )
38+ parser .add_argument ("--sample-count-phase-b" , type = int , default = None )
39+ parser .add_argument ("--phase-b-train-fraction" , type = float , default = None )
40+ parser .add_argument ("--phase-a-epochs" , type = int , default = None )
41+ parser .add_argument ("--phase-b-epochs" , type = int , default = None )
42+ parser .add_argument ("--hidden-dim" , type = int , default = None )
43+ parser .add_argument ("--phase-a-noise-scale" , type = float , default = None )
44+ parser .add_argument ("--phase-b-noise-scale" , type = float , default = None )
45+ parser .add_argument ("--phase-b-rotation-degrees" , type = float , default = None )
46+ parser .add_argument ("--phase-b-translation-x" , type = float , default = None )
47+ parser .add_argument ("--phase-b-translation-y" , type = float , default = None )
48+ parser .add_argument ("--sleep-interval-phase-a" , type = int , default = None )
49+ parser .add_argument ("--sleep-interval-phase-b" , type = int , default = None )
4750 parser .add_argument ("--output-file" , type = str , default = "" )
4851 return parser
4952
@@ -54,25 +57,50 @@ def main() -> None:
5457 args = parser .parse_args ()
5558
5659 seeds = _parse_int_list (args .seeds )
60+ profile_defaults = _build_profile_defaults (args .profile )
5761 circadian_config = (
58- _build_strength_case_circadian_config ()
59- if args .profile == "strength-case"
60- else CircadianConfig ()
62+ _build_baseline_circadian_config ()
63+ if args .profile == "baseline"
64+ else (
65+ _build_strength_case_circadian_config ()
66+ if args .profile == "strength-case"
67+ else _build_hardest_case_circadian_config ()
68+ )
6169 )
6270 config = ContinualShiftConfig (
63- sample_count_phase_a = args .sample_count_phase_a ,
64- sample_count_phase_b = args .sample_count_phase_b ,
65- phase_b_train_fraction = args .phase_b_train_fraction ,
66- phase_a_epochs = args .phase_a_epochs ,
67- phase_b_epochs = args .phase_b_epochs ,
68- hidden_dim = args .hidden_dim ,
69- phase_a_noise_scale = args .phase_a_noise_scale ,
70- phase_b_noise_scale = args .phase_b_noise_scale ,
71- phase_b_rotation_degrees = args .phase_b_rotation_degrees ,
72- phase_b_translation_x = args .phase_b_translation_x ,
73- phase_b_translation_y = args .phase_b_translation_y ,
74- circadian_sleep_interval_phase_a = args .sleep_interval_phase_a ,
75- circadian_sleep_interval_phase_b = args .sleep_interval_phase_b ,
71+ sample_count_phase_a = _resolve_optional (
72+ args .sample_count_phase_a , profile_defaults ["sample_count_phase_a" ]
73+ ),
74+ sample_count_phase_b = _resolve_optional (
75+ args .sample_count_phase_b , profile_defaults ["sample_count_phase_b" ]
76+ ),
77+ phase_b_train_fraction = _resolve_optional (
78+ args .phase_b_train_fraction , profile_defaults ["phase_b_train_fraction" ]
79+ ),
80+ phase_a_epochs = _resolve_optional (args .phase_a_epochs , profile_defaults ["phase_a_epochs" ]),
81+ phase_b_epochs = _resolve_optional (args .phase_b_epochs , profile_defaults ["phase_b_epochs" ]),
82+ hidden_dim = _resolve_optional (args .hidden_dim , profile_defaults ["hidden_dim" ]),
83+ phase_a_noise_scale = _resolve_optional (
84+ args .phase_a_noise_scale , profile_defaults ["phase_a_noise_scale" ]
85+ ),
86+ phase_b_noise_scale = _resolve_optional (
87+ args .phase_b_noise_scale , profile_defaults ["phase_b_noise_scale" ]
88+ ),
89+ phase_b_rotation_degrees = _resolve_optional (
90+ args .phase_b_rotation_degrees , profile_defaults ["phase_b_rotation_degrees" ]
91+ ),
92+ phase_b_translation_x = _resolve_optional (
93+ args .phase_b_translation_x , profile_defaults ["phase_b_translation_x" ]
94+ ),
95+ phase_b_translation_y = _resolve_optional (
96+ args .phase_b_translation_y , profile_defaults ["phase_b_translation_y" ]
97+ ),
98+ circadian_sleep_interval_phase_a = _resolve_optional (
99+ args .sleep_interval_phase_a , profile_defaults ["sleep_interval_phase_a" ]
100+ ),
101+ circadian_sleep_interval_phase_b = _resolve_optional (
102+ args .sleep_interval_phase_b , profile_defaults ["sleep_interval_phase_b" ]
103+ ),
76104 circadian_config = circadian_config ,
77105 )
78106
@@ -102,6 +130,66 @@ def _build_strength_case_circadian_config() -> CircadianConfig:
102130 )
103131
104132
133+ def _build_hardest_case_circadian_config () -> CircadianConfig :
134+ """Build circadian profile tuned for the hardest continual-shift setup."""
135+ return CircadianConfig (
136+ use_reward_modulated_learning = False ,
137+ split_threshold = 0.25 ,
138+ prune_threshold = 0.04 ,
139+ max_split_per_sleep = 1 ,
140+ max_prune_per_sleep = 0 ,
141+ replay_steps = 2 ,
142+ replay_memory_size = 10 ,
143+ replay_learning_rate = 0.04 ,
144+ replay_inference_steps = 12 ,
145+ replay_inference_learning_rate = 0.14 ,
146+ )
147+
148+
149+ def _build_baseline_circadian_config () -> CircadianConfig :
150+ return CircadianConfig ()
151+
152+
153+ def _build_profile_defaults (profile : str ) -> dict [str , float | int ]:
154+ if profile == "hardest-case" :
155+ return {
156+ "sample_count_phase_a" : 500 ,
157+ "sample_count_phase_b" : 500 ,
158+ "phase_b_train_fraction" : 0.08 ,
159+ "phase_a_epochs" : 90 ,
160+ "phase_b_epochs" : 120 ,
161+ "hidden_dim" : 8 ,
162+ "phase_a_noise_scale" : 0.8 ,
163+ "phase_b_noise_scale" : 1.2 ,
164+ "phase_b_rotation_degrees" : 44.0 ,
165+ "phase_b_translation_x" : 0.9 ,
166+ "phase_b_translation_y" : - 0.7 ,
167+ "sleep_interval_phase_a" : 40 ,
168+ "sleep_interval_phase_b" : 8 ,
169+ }
170+ return {
171+ "sample_count_phase_a" : 500 ,
172+ "sample_count_phase_b" : 500 ,
173+ "phase_b_train_fraction" : 0.14 ,
174+ "phase_a_epochs" : 110 ,
175+ "phase_b_epochs" : 80 ,
176+ "hidden_dim" : 12 ,
177+ "phase_a_noise_scale" : 0.8 ,
178+ "phase_b_noise_scale" : 1.0 ,
179+ "phase_b_rotation_degrees" : 40.0 ,
180+ "phase_b_translation_x" : 0.9 ,
181+ "phase_b_translation_y" : - 0.7 ,
182+ "sleep_interval_phase_a" : 40 ,
183+ "sleep_interval_phase_b" : 8 ,
184+ }
185+
186+
187+ def _resolve_optional (value : int | float | None , fallback : int | float ) -> int | float :
188+ if value is None :
189+ return fallback
190+ return value
191+
192+
105193def _parse_int_list (raw_values : str ) -> list [int ]:
106194 items = [item .strip () for item in raw_values .split ("," ) if item .strip ()]
107195 if not items :
0 commit comments