|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import argparse |
| 6 | +from dataclasses import dataclass |
6 | 7 | from pathlib import Path |
7 | 8 | import sys |
8 | 9 |
|
|
18 | 19 | from src.core.circadian_predictive_coding import CircadianConfig |
19 | 20 |
|
20 | 21 |
|
| 22 | +@dataclass(frozen=True) |
| 23 | +class ProfileDefaults: |
| 24 | + """Typed defaults for benchmark profile presets.""" |
| 25 | + |
| 26 | + sample_count_phase_a: int |
| 27 | + sample_count_phase_b: int |
| 28 | + phase_b_train_fraction: float |
| 29 | + phase_a_epochs: int |
| 30 | + phase_b_epochs: int |
| 31 | + hidden_dim: int |
| 32 | + phase_a_noise_scale: float |
| 33 | + phase_b_noise_scale: float |
| 34 | + phase_b_rotation_degrees: float |
| 35 | + phase_b_translation_x: float |
| 36 | + phase_b_translation_y: float |
| 37 | + sleep_interval_phase_a: int |
| 38 | + sleep_interval_phase_b: int |
| 39 | + |
| 40 | + |
21 | 41 | def build_parser() -> argparse.ArgumentParser: |
22 | 42 | """Build CLI parser for continual shift benchmark.""" |
23 | 43 | parser = argparse.ArgumentParser( |
@@ -68,38 +88,38 @@ def main() -> None: |
68 | 88 | ) |
69 | 89 | ) |
70 | 90 | config = ContinualShiftConfig( |
71 | | - sample_count_phase_a=_resolve_optional( |
72 | | - args.sample_count_phase_a, profile_defaults["sample_count_phase_a"] |
| 91 | + sample_count_phase_a=_resolve_optional_int( |
| 92 | + args.sample_count_phase_a, profile_defaults.sample_count_phase_a |
73 | 93 | ), |
74 | | - sample_count_phase_b=_resolve_optional( |
75 | | - args.sample_count_phase_b, profile_defaults["sample_count_phase_b"] |
| 94 | + sample_count_phase_b=_resolve_optional_int( |
| 95 | + args.sample_count_phase_b, profile_defaults.sample_count_phase_b |
76 | 96 | ), |
77 | | - phase_b_train_fraction=_resolve_optional( |
78 | | - args.phase_b_train_fraction, profile_defaults["phase_b_train_fraction"] |
| 97 | + phase_b_train_fraction=_resolve_optional_float( |
| 98 | + args.phase_b_train_fraction, profile_defaults.phase_b_train_fraction |
79 | 99 | ), |
80 | | - phase_a_epochs=_resolve_optional(args.phase_a_epochs, profile_defaults["phase_a_epochs"]), |
81 | | - phase_b_epochs=_resolve_optional(args.phase_b_epochs, profile_defaults["phase_b_epochs"]), |
82 | | - hidden_dim=_resolve_optional(args.hidden_dim, profile_defaults["hidden_dim"]), |
83 | | - phase_a_noise_scale=_resolve_optional( |
84 | | - args.phase_a_noise_scale, profile_defaults["phase_a_noise_scale"] |
| 100 | + phase_a_epochs=_resolve_optional_int(args.phase_a_epochs, profile_defaults.phase_a_epochs), |
| 101 | + phase_b_epochs=_resolve_optional_int(args.phase_b_epochs, profile_defaults.phase_b_epochs), |
| 102 | + hidden_dim=_resolve_optional_int(args.hidden_dim, profile_defaults.hidden_dim), |
| 103 | + phase_a_noise_scale=_resolve_optional_float( |
| 104 | + args.phase_a_noise_scale, profile_defaults.phase_a_noise_scale |
85 | 105 | ), |
86 | | - phase_b_noise_scale=_resolve_optional( |
87 | | - args.phase_b_noise_scale, profile_defaults["phase_b_noise_scale"] |
| 106 | + phase_b_noise_scale=_resolve_optional_float( |
| 107 | + args.phase_b_noise_scale, profile_defaults.phase_b_noise_scale |
88 | 108 | ), |
89 | | - phase_b_rotation_degrees=_resolve_optional( |
90 | | - args.phase_b_rotation_degrees, profile_defaults["phase_b_rotation_degrees"] |
| 109 | + phase_b_rotation_degrees=_resolve_optional_float( |
| 110 | + args.phase_b_rotation_degrees, profile_defaults.phase_b_rotation_degrees |
91 | 111 | ), |
92 | | - phase_b_translation_x=_resolve_optional( |
93 | | - args.phase_b_translation_x, profile_defaults["phase_b_translation_x"] |
| 112 | + phase_b_translation_x=_resolve_optional_float( |
| 113 | + args.phase_b_translation_x, profile_defaults.phase_b_translation_x |
94 | 114 | ), |
95 | | - phase_b_translation_y=_resolve_optional( |
96 | | - args.phase_b_translation_y, profile_defaults["phase_b_translation_y"] |
| 115 | + phase_b_translation_y=_resolve_optional_float( |
| 116 | + args.phase_b_translation_y, profile_defaults.phase_b_translation_y |
97 | 117 | ), |
98 | | - circadian_sleep_interval_phase_a=_resolve_optional( |
99 | | - args.sleep_interval_phase_a, profile_defaults["sleep_interval_phase_a"] |
| 118 | + circadian_sleep_interval_phase_a=_resolve_optional_int( |
| 119 | + args.sleep_interval_phase_a, profile_defaults.sleep_interval_phase_a |
100 | 120 | ), |
101 | | - circadian_sleep_interval_phase_b=_resolve_optional( |
102 | | - args.sleep_interval_phase_b, profile_defaults["sleep_interval_phase_b"] |
| 121 | + circadian_sleep_interval_phase_b=_resolve_optional_int( |
| 122 | + args.sleep_interval_phase_b, profile_defaults.sleep_interval_phase_b |
103 | 123 | ), |
104 | 124 | circadian_config=circadian_config, |
105 | 125 | ) |
@@ -150,41 +170,47 @@ def _build_baseline_circadian_config() -> CircadianConfig: |
150 | 170 | return CircadianConfig() |
151 | 171 |
|
152 | 172 |
|
153 | | -def _build_profile_defaults(profile: str) -> dict[str, float | int]: |
| 173 | +def _build_profile_defaults(profile: str) -> ProfileDefaults: |
154 | 174 | if profile == "hardest-case": |
155 | | - return { |
156 | | - "sample_count_phase_a": 500, |
157 | | - "sample_count_phase_b": 500, |
158 | | - "phase_b_train_fraction": 0.08, |
159 | | - "phase_a_epochs": 90, |
160 | | - "phase_b_epochs": 120, |
161 | | - "hidden_dim": 8, |
162 | | - "phase_a_noise_scale": 0.8, |
163 | | - "phase_b_noise_scale": 1.2, |
164 | | - "phase_b_rotation_degrees": 44.0, |
165 | | - "phase_b_translation_x": 0.9, |
166 | | - "phase_b_translation_y": -0.7, |
167 | | - "sleep_interval_phase_a": 40, |
168 | | - "sleep_interval_phase_b": 8, |
169 | | - } |
170 | | - return { |
171 | | - "sample_count_phase_a": 500, |
172 | | - "sample_count_phase_b": 500, |
173 | | - "phase_b_train_fraction": 0.14, |
174 | | - "phase_a_epochs": 110, |
175 | | - "phase_b_epochs": 80, |
176 | | - "hidden_dim": 12, |
177 | | - "phase_a_noise_scale": 0.8, |
178 | | - "phase_b_noise_scale": 1.0, |
179 | | - "phase_b_rotation_degrees": 40.0, |
180 | | - "phase_b_translation_x": 0.9, |
181 | | - "phase_b_translation_y": -0.7, |
182 | | - "sleep_interval_phase_a": 40, |
183 | | - "sleep_interval_phase_b": 8, |
184 | | - } |
185 | | - |
186 | | - |
187 | | -def _resolve_optional(value: int | float | None, fallback: int | float) -> int | float: |
| 175 | + return ProfileDefaults( |
| 176 | + sample_count_phase_a=500, |
| 177 | + sample_count_phase_b=500, |
| 178 | + phase_b_train_fraction=0.08, |
| 179 | + phase_a_epochs=90, |
| 180 | + phase_b_epochs=120, |
| 181 | + hidden_dim=8, |
| 182 | + phase_a_noise_scale=0.8, |
| 183 | + phase_b_noise_scale=1.2, |
| 184 | + phase_b_rotation_degrees=44.0, |
| 185 | + phase_b_translation_x=0.9, |
| 186 | + phase_b_translation_y=-0.7, |
| 187 | + sleep_interval_phase_a=40, |
| 188 | + sleep_interval_phase_b=8, |
| 189 | + ) |
| 190 | + return ProfileDefaults( |
| 191 | + sample_count_phase_a=500, |
| 192 | + sample_count_phase_b=500, |
| 193 | + phase_b_train_fraction=0.14, |
| 194 | + phase_a_epochs=110, |
| 195 | + phase_b_epochs=80, |
| 196 | + hidden_dim=12, |
| 197 | + phase_a_noise_scale=0.8, |
| 198 | + phase_b_noise_scale=1.0, |
| 199 | + phase_b_rotation_degrees=40.0, |
| 200 | + phase_b_translation_x=0.9, |
| 201 | + phase_b_translation_y=-0.7, |
| 202 | + sleep_interval_phase_a=40, |
| 203 | + sleep_interval_phase_b=8, |
| 204 | + ) |
| 205 | + |
| 206 | + |
| 207 | +def _resolve_optional_int(value: int | None, fallback: int) -> int: |
| 208 | + if value is None: |
| 209 | + return fallback |
| 210 | + return value |
| 211 | + |
| 212 | + |
| 213 | +def _resolve_optional_float(value: float | None, fallback: float) -> float: |
188 | 214 | if value is None: |
189 | 215 | return fallback |
190 | 216 | return value |
|
0 commit comments