Skip to content

Commit f2a3671

Browse files
committed
add calibration log to checkpoint
1 parent 9428a6d commit f2a3671

1 file changed

Lines changed: 23 additions & 9 deletions

File tree

modal_app/data_build.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
VOLUME_MOUNT = "/checkpoints"
2929

3030
# Script to output file mapping for checkpointing
31+
# Values can be a single file path (str) or a list of file paths
3132
SCRIPT_OUTPUTS = {
3233
"policyengine_us_data/utils/uprating.py": (
3334
"policyengine_us_data/storage/uprating_factors.csv"
@@ -47,9 +48,11 @@
4748
"policyengine_us_data/datasets/cps/extended_cps.py": (
4849
"policyengine_us_data/storage/extended_cps_2024.h5"
4950
),
50-
"policyengine_us_data/datasets/cps/enhanced_cps.py": (
51-
"policyengine_us_data/storage/enhanced_cps_2024.h5"
52-
),
51+
# enhanced_cps.py produces both the dataset and calibration log
52+
"policyengine_us_data/datasets/cps/enhanced_cps.py": [
53+
"policyengine_us_data/storage/enhanced_cps_2024.h5",
54+
"calibration_log.csv",
55+
],
5356
"policyengine_us_data/datasets/cps/"
5457
"local_area_calibration/create_stratified_cps.py": (
5558
"policyengine_us_data/storage/stratified_extended_cps_2024.h5"
@@ -161,7 +164,7 @@ def run_script(
161164

162165
def run_script_with_checkpoint(
163166
script_path: str,
164-
output_file: str,
167+
output_files: str | list[str],
165168
branch: str,
166169
volume: modal.Volume,
167170
args: Optional[list] = None,
@@ -171,7 +174,8 @@ def run_script_with_checkpoint(
171174
172175
Args:
173176
script_path: Path to the Python script to run.
174-
output_file: Path to the output file produced by the script.
177+
output_files: Path(s) to output file(s) produced by the script.
178+
Can be a single string or a list of strings.
175179
branch: Git branch name for checkpoint scoping.
176180
volume: Modal volume for checkpointing.
177181
args: Optional list of command-line arguments.
@@ -180,16 +184,26 @@ def run_script_with_checkpoint(
180184
Returns:
181185
The script_path that was executed.
182186
"""
183-
# Try to restore from checkpoint first
184-
if restore_from_checkpoint(branch, output_file):
187+
# Normalize to list
188+
if isinstance(output_files, str):
189+
output_files = [output_files]
190+
191+
# Check if ALL outputs are checkpointed
192+
all_checkpointed = all(is_checkpointed(branch, f) for f in output_files)
193+
194+
if all_checkpointed:
195+
# Restore all files from checkpoint
196+
for output_file in output_files:
197+
restore_from_checkpoint(branch, output_file)
185198
print(f"Skipping {script_path} (restored from checkpoint)")
186199
return script_path
187200

188201
# Run the script
189202
run_script(script_path, args=args, env=env)
190203

191-
# Checkpoint the output
192-
save_checkpoint(branch, output_file, volume)
204+
# Checkpoint all outputs
205+
for output_file in output_files:
206+
save_checkpoint(branch, output_file, volume)
193207

194208
return script_path
195209

0 commit comments

Comments
 (0)