2828VOLUME_MOUNT = "/checkpoints"
2929
3030# Script to output file mapping for checkpointing
31+ # Values can be a single file path (str) or a list of file paths
3132SCRIPT_OUTPUTS = {
3233 "policyengine_us_data/utils/uprating.py" : (
3334 "policyengine_us_data/storage/uprating_factors.csv"
4748 "policyengine_us_data/datasets/cps/extended_cps.py" : (
4849 "policyengine_us_data/storage/extended_cps_2024.h5"
4950 ),
50- "policyengine_us_data/datasets/cps/enhanced_cps.py" : (
51- "policyengine_us_data/storage/enhanced_cps_2024.h5"
52- ),
51+ # enhanced_cps.py produces both the dataset and calibration log
52+ "policyengine_us_data/datasets/cps/enhanced_cps.py" : [
53+ "policyengine_us_data/storage/enhanced_cps_2024.h5" ,
54+ "calibration_log.csv" ,
55+ ],
5356 "policyengine_us_data/datasets/cps/"
5457 "local_area_calibration/create_stratified_cps.py" : (
5558 "policyengine_us_data/storage/stratified_extended_cps_2024.h5"
@@ -161,7 +164,7 @@ def run_script(
161164
162165def run_script_with_checkpoint (
163166 script_path : str ,
164- output_file : str ,
167+ output_files : str | list [ str ] ,
165168 branch : str ,
166169 volume : modal .Volume ,
167170 args : Optional [list ] = None ,
@@ -171,7 +174,8 @@ def run_script_with_checkpoint(
171174
172175 Args:
173176 script_path: Path to the Python script to run.
174- output_file: Path to the output file produced by the script.
177+ output_files: Path(s) to output file(s) produced by the script.
178+ Can be a single string or a list of strings.
175179 branch: Git branch name for checkpoint scoping.
176180 volume: Modal volume for checkpointing.
177181 args: Optional list of command-line arguments.
@@ -180,16 +184,26 @@ def run_script_with_checkpoint(
180184 Returns:
181185 The script_path that was executed.
182186 """
183- # Try to restore from checkpoint first
184- if restore_from_checkpoint (branch , output_file ):
187+ # Normalize to list
188+ if isinstance (output_files , str ):
189+ output_files = [output_files ]
190+
191+ # Check if ALL outputs are checkpointed
192+ all_checkpointed = all (is_checkpointed (branch , f ) for f in output_files )
193+
194+ if all_checkpointed :
195+ # Restore all files from checkpoint
196+ for output_file in output_files :
197+ restore_from_checkpoint (branch , output_file )
185198 print (f"Skipping { script_path } (restored from checkpoint)" )
186199 return script_path
187200
188201 # Run the script
189202 run_script (script_path , args = args , env = env )
190203
191- # Checkpoint the output
192- save_checkpoint (branch , output_file , volume )
204+ # Checkpoint all outputs
205+ for output_file in output_files :
206+ save_checkpoint (branch , output_file , volume )
193207
194208 return script_path
195209
0 commit comments