Skip to content

Commit 338d254

Browse files
committed
Fix mypy typing for continual-shift profile defaults
1 parent 54c0247 commit 338d254

1 file changed

Lines changed: 83 additions & 57 deletions

File tree

scripts/run_continual_shift_benchmark.py

Lines changed: 83 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import argparse
6+
from dataclasses import dataclass
67
from pathlib import Path
78
import sys
89

@@ -18,6 +19,25 @@
1819
from src.core.circadian_predictive_coding import CircadianConfig
1920

2021

22+
@dataclass(frozen=True)
23+
class ProfileDefaults:
24+
"""Typed defaults for benchmark profile presets."""
25+
26+
sample_count_phase_a: int
27+
sample_count_phase_b: int
28+
phase_b_train_fraction: float
29+
phase_a_epochs: int
30+
phase_b_epochs: int
31+
hidden_dim: int
32+
phase_a_noise_scale: float
33+
phase_b_noise_scale: float
34+
phase_b_rotation_degrees: float
35+
phase_b_translation_x: float
36+
phase_b_translation_y: float
37+
sleep_interval_phase_a: int
38+
sleep_interval_phase_b: int
39+
40+
2141
def build_parser() -> argparse.ArgumentParser:
2242
"""Build CLI parser for continual shift benchmark."""
2343
parser = argparse.ArgumentParser(
@@ -68,38 +88,38 @@ def main() -> None:
6888
)
6989
)
7090
config = ContinualShiftConfig(
71-
sample_count_phase_a=_resolve_optional(
72-
args.sample_count_phase_a, profile_defaults["sample_count_phase_a"]
91+
sample_count_phase_a=_resolve_optional_int(
92+
args.sample_count_phase_a, profile_defaults.sample_count_phase_a
7393
),
74-
sample_count_phase_b=_resolve_optional(
75-
args.sample_count_phase_b, profile_defaults["sample_count_phase_b"]
94+
sample_count_phase_b=_resolve_optional_int(
95+
args.sample_count_phase_b, profile_defaults.sample_count_phase_b
7696
),
77-
phase_b_train_fraction=_resolve_optional(
78-
args.phase_b_train_fraction, profile_defaults["phase_b_train_fraction"]
97+
phase_b_train_fraction=_resolve_optional_float(
98+
args.phase_b_train_fraction, profile_defaults.phase_b_train_fraction
7999
),
80-
phase_a_epochs=_resolve_optional(args.phase_a_epochs, profile_defaults["phase_a_epochs"]),
81-
phase_b_epochs=_resolve_optional(args.phase_b_epochs, profile_defaults["phase_b_epochs"]),
82-
hidden_dim=_resolve_optional(args.hidden_dim, profile_defaults["hidden_dim"]),
83-
phase_a_noise_scale=_resolve_optional(
84-
args.phase_a_noise_scale, profile_defaults["phase_a_noise_scale"]
100+
phase_a_epochs=_resolve_optional_int(args.phase_a_epochs, profile_defaults.phase_a_epochs),
101+
phase_b_epochs=_resolve_optional_int(args.phase_b_epochs, profile_defaults.phase_b_epochs),
102+
hidden_dim=_resolve_optional_int(args.hidden_dim, profile_defaults.hidden_dim),
103+
phase_a_noise_scale=_resolve_optional_float(
104+
args.phase_a_noise_scale, profile_defaults.phase_a_noise_scale
85105
),
86-
phase_b_noise_scale=_resolve_optional(
87-
args.phase_b_noise_scale, profile_defaults["phase_b_noise_scale"]
106+
phase_b_noise_scale=_resolve_optional_float(
107+
args.phase_b_noise_scale, profile_defaults.phase_b_noise_scale
88108
),
89-
phase_b_rotation_degrees=_resolve_optional(
90-
args.phase_b_rotation_degrees, profile_defaults["phase_b_rotation_degrees"]
109+
phase_b_rotation_degrees=_resolve_optional_float(
110+
args.phase_b_rotation_degrees, profile_defaults.phase_b_rotation_degrees
91111
),
92-
phase_b_translation_x=_resolve_optional(
93-
args.phase_b_translation_x, profile_defaults["phase_b_translation_x"]
112+
phase_b_translation_x=_resolve_optional_float(
113+
args.phase_b_translation_x, profile_defaults.phase_b_translation_x
94114
),
95-
phase_b_translation_y=_resolve_optional(
96-
args.phase_b_translation_y, profile_defaults["phase_b_translation_y"]
115+
phase_b_translation_y=_resolve_optional_float(
116+
args.phase_b_translation_y, profile_defaults.phase_b_translation_y
97117
),
98-
circadian_sleep_interval_phase_a=_resolve_optional(
99-
args.sleep_interval_phase_a, profile_defaults["sleep_interval_phase_a"]
118+
circadian_sleep_interval_phase_a=_resolve_optional_int(
119+
args.sleep_interval_phase_a, profile_defaults.sleep_interval_phase_a
100120
),
101-
circadian_sleep_interval_phase_b=_resolve_optional(
102-
args.sleep_interval_phase_b, profile_defaults["sleep_interval_phase_b"]
121+
circadian_sleep_interval_phase_b=_resolve_optional_int(
122+
args.sleep_interval_phase_b, profile_defaults.sleep_interval_phase_b
103123
),
104124
circadian_config=circadian_config,
105125
)
@@ -150,41 +170,47 @@ def _build_baseline_circadian_config() -> CircadianConfig:
150170
return CircadianConfig()
151171

152172

153-
def _build_profile_defaults(profile: str) -> dict[str, float | int]:
173+
def _build_profile_defaults(profile: str) -> ProfileDefaults:
154174
if profile == "hardest-case":
155-
return {
156-
"sample_count_phase_a": 500,
157-
"sample_count_phase_b": 500,
158-
"phase_b_train_fraction": 0.08,
159-
"phase_a_epochs": 90,
160-
"phase_b_epochs": 120,
161-
"hidden_dim": 8,
162-
"phase_a_noise_scale": 0.8,
163-
"phase_b_noise_scale": 1.2,
164-
"phase_b_rotation_degrees": 44.0,
165-
"phase_b_translation_x": 0.9,
166-
"phase_b_translation_y": -0.7,
167-
"sleep_interval_phase_a": 40,
168-
"sleep_interval_phase_b": 8,
169-
}
170-
return {
171-
"sample_count_phase_a": 500,
172-
"sample_count_phase_b": 500,
173-
"phase_b_train_fraction": 0.14,
174-
"phase_a_epochs": 110,
175-
"phase_b_epochs": 80,
176-
"hidden_dim": 12,
177-
"phase_a_noise_scale": 0.8,
178-
"phase_b_noise_scale": 1.0,
179-
"phase_b_rotation_degrees": 40.0,
180-
"phase_b_translation_x": 0.9,
181-
"phase_b_translation_y": -0.7,
182-
"sleep_interval_phase_a": 40,
183-
"sleep_interval_phase_b": 8,
184-
}
185-
186-
187-
def _resolve_optional(value: int | float | None, fallback: int | float) -> int | float:
175+
return ProfileDefaults(
176+
sample_count_phase_a=500,
177+
sample_count_phase_b=500,
178+
phase_b_train_fraction=0.08,
179+
phase_a_epochs=90,
180+
phase_b_epochs=120,
181+
hidden_dim=8,
182+
phase_a_noise_scale=0.8,
183+
phase_b_noise_scale=1.2,
184+
phase_b_rotation_degrees=44.0,
185+
phase_b_translation_x=0.9,
186+
phase_b_translation_y=-0.7,
187+
sleep_interval_phase_a=40,
188+
sleep_interval_phase_b=8,
189+
)
190+
return ProfileDefaults(
191+
sample_count_phase_a=500,
192+
sample_count_phase_b=500,
193+
phase_b_train_fraction=0.14,
194+
phase_a_epochs=110,
195+
phase_b_epochs=80,
196+
hidden_dim=12,
197+
phase_a_noise_scale=0.8,
198+
phase_b_noise_scale=1.0,
199+
phase_b_rotation_degrees=40.0,
200+
phase_b_translation_x=0.9,
201+
phase_b_translation_y=-0.7,
202+
sleep_interval_phase_a=40,
203+
sleep_interval_phase_b=8,
204+
)
205+
206+
207+
def _resolve_optional_int(value: int | None, fallback: int) -> int:
208+
if value is None:
209+
return fallback
210+
return value
211+
212+
213+
def _resolve_optional_float(value: float | None, fallback: float) -> float:
188214
if value is None:
189215
return fallback
190216
return value

0 commit comments

Comments
 (0)