|
16 | 16 |
|
17 | 17 | SCRIPT_DIR = Path(__file__).resolve().parent |
18 | 18 | RUNNER_PATH = SCRIPT_DIR / "run_household_projection.py" |
| 19 | +DEFAULT_SUPPORT_AUGMENTATION_START_YEAR = 2075 |
| 20 | +SUPPORT_AUGMENTATION_VALUE_FLAGS = { |
| 21 | + "--support-augmentation-profile", |
| 22 | + "--support-augmentation-target-year", |
| 23 | + "--support-augmentation-start-year", |
| 24 | + "--support-augmentation-top-n-targets", |
| 25 | + "--support-augmentation-donors-per-target", |
| 26 | + "--support-augmentation-max-distance", |
| 27 | + "--support-augmentation-clone-weight-scale", |
| 28 | + "--support-augmentation-blueprint-base-weight-scale", |
| 29 | +} |
| 30 | +SUPPORT_AUGMENTATION_BOOLEAN_FLAGS = { |
| 31 | + "--support-augmentation-align-to-run-year", |
| 32 | + "--support-augmentation-sanitize-worker-non-target-income", |
| 33 | + "--support-augmentation-sanitize-clone-non-target-income", |
| 34 | +} |
19 | 35 |
|
20 | 36 |
|
21 | 37 | def parse_years(spec: str) -> list[int]: |
@@ -81,6 +97,53 @@ def validate_forwarded_args(forwarded_args: list[str]) -> None: |
81 | 97 | ) |
82 | 98 |
|
83 | 99 |
|
| 100 | +def _option_value(args: list[str], flag: str) -> str | None: |
| 101 | + if flag not in args: |
| 102 | + return None |
| 103 | + index = args.index(flag) |
| 104 | + if index + 1 >= len(args): |
| 105 | + raise ValueError(f"{flag} requires a value") |
| 106 | + return args[index + 1] |
| 107 | + |
| 108 | + |
| 109 | +def _has_support_augmentation_profile(args: list[str]) -> bool: |
| 110 | + return "--support-augmentation-profile" in args |
| 111 | + |
| 112 | + |
| 113 | +def _support_augmentation_start_year(args: list[str]) -> int: |
| 114 | + raw_value = _option_value(args, "--support-augmentation-start-year") |
| 115 | + if raw_value is None: |
| 116 | + return DEFAULT_SUPPORT_AUGMENTATION_START_YEAR |
| 117 | + return int(raw_value) |
| 118 | + |
| 119 | + |
| 120 | +def _strip_support_augmentation_args(args: list[str]) -> list[str]: |
| 121 | + stripped: list[str] = [] |
| 122 | + index = 0 |
| 123 | + while index < len(args): |
| 124 | + arg = args[index] |
| 125 | + if arg in SUPPORT_AUGMENTATION_VALUE_FLAGS: |
| 126 | + if index + 1 >= len(args): |
| 127 | + raise ValueError(f"{arg} requires a value") |
| 128 | + index += 2 |
| 129 | + continue |
| 130 | + if arg in SUPPORT_AUGMENTATION_BOOLEAN_FLAGS: |
| 131 | + index += 1 |
| 132 | + continue |
| 133 | + stripped.append(arg) |
| 134 | + index += 1 |
| 135 | + return stripped |
| 136 | + |
| 137 | + |
| 138 | +def forwarded_args_for_year(year: int, forwarded_args: list[str]) -> list[str]: |
| 139 | + """Return runner args with late-year support disabled before activation.""" |
| 140 | + if not _has_support_augmentation_profile(forwarded_args): |
| 141 | + return list(forwarded_args) |
| 142 | + if year >= _support_augmentation_start_year(forwarded_args): |
| 143 | + return list(forwarded_args) |
| 144 | + return _strip_support_augmentation_args(forwarded_args) |
| 145 | + |
| 146 | + |
84 | 147 | def year_output_dir(root: Path, year: int) -> Path: |
85 | 148 | return root / ".parallel_tmp" / str(year) |
86 | 149 |
|
@@ -123,7 +186,7 @@ def run_year( |
123 | 186 | "--output-dir", |
124 | 187 | str(output_dir), |
125 | 188 | "--save-h5", |
126 | | - *forwarded_args, |
| 189 | + *forwarded_args_for_year(year, forwarded_args), |
127 | 190 | ] |
128 | 191 |
|
129 | 192 | with log_path.open("w", encoding="utf-8") as log_file: |
@@ -168,6 +231,44 @@ def _json_clone(value): |
168 | 231 | return json.loads(json.dumps(value)) |
169 | 232 |
|
170 | 233 |
|
| 234 | +def _normalize_support_augmentation_contract(value): |
| 235 | + if value is None: |
| 236 | + return None |
| 237 | + normalized = _json_clone(value) |
| 238 | + if normalized.get("target_year_strategy") == "run_year": |
| 239 | + normalized.pop("target_year", None) |
| 240 | + normalized.pop("report_file", None) |
| 241 | + normalized.pop("report_summary", None) |
| 242 | + return normalized |
| 243 | + |
| 244 | + |
| 245 | +def _support_augmentation_activation_start(value) -> int | None: |
| 246 | + if not isinstance(value, dict): |
| 247 | + return None |
| 248 | + raw_value = value.get("activation_start_year") |
| 249 | + if raw_value is None: |
| 250 | + return None |
| 251 | + return int(raw_value) |
| 252 | + |
| 253 | + |
| 254 | +def support_augmentation_contracts_compatible(left, right, *, year: int) -> bool: |
| 255 | + if _normalize_support_augmentation_contract( |
| 256 | + left |
| 257 | + ) == _normalize_support_augmentation_contract(right): |
| 258 | + return True |
| 259 | + if left is None and right is not None: |
| 260 | + activation_year = _support_augmentation_activation_start(right) |
| 261 | + return activation_year is not None and year >= activation_year |
| 262 | + if left is not None and right is None: |
| 263 | + activation_year = _support_augmentation_activation_start(left) |
| 264 | + return activation_year is not None and year < activation_year |
| 265 | + return False |
| 266 | + |
| 267 | + |
| 268 | +def merge_support_augmentation_contract(left, right): |
| 269 | + return _json_clone(left if left is not None else right) |
| 270 | + |
| 271 | + |
171 | 272 | def manifest_contract(manifest: dict) -> dict: |
172 | 273 | tax_assumption = _json_clone(manifest.get("tax_assumption")) |
173 | 274 | if isinstance(tax_assumption, dict): |
@@ -209,6 +310,22 @@ def merge_outputs( |
209 | 310 | manifest_seed = temp_contract |
210 | 311 | else: |
211 | 312 | for key, value in manifest_seed.items(): |
| 313 | + if key == "support_augmentation": |
| 314 | + support_augmentation = temp_contract.get(key) |
| 315 | + if not support_augmentation_contracts_compatible( |
| 316 | + value, |
| 317 | + support_augmentation, |
| 318 | + year=year, |
| 319 | + ): |
| 320 | + raise ValueError( |
| 321 | + f"Temp manifest mismatch for {key} in year {year}: " |
| 322 | + f"{support_augmentation} != {value}" |
| 323 | + ) |
| 324 | + manifest_seed[key] = merge_support_augmentation_contract( |
| 325 | + value, |
| 326 | + support_augmentation, |
| 327 | + ) |
| 328 | + continue |
212 | 329 | if temp_contract.get(key) != value: |
213 | 330 | raise ValueError( |
214 | 331 | f"Temp manifest mismatch for {key} in year {year}: " |
|
0 commit comments