1313
1414from modal_app .images import gpu_image as image # noqa: E402
1515from policyengine_us_data .calibration_package .specs import ( # noqa: E402
16+ CALIBRATION_PACKAGE_CONTRACT_FILENAME ,
1617 calibration_package_artifact_paths ,
1718 stage2_build_context_for_run ,
1819)
1920from policyengine_us_data .fit_weights import ( # noqa: E402
2021 FitResultBytes ,
2122 FitScope ,
23+ FittedWeightsInputBundle ,
2224 NATIONAL_FIT_LAMBDA_L0 ,
2325 fit_artifacts_for_scope ,
2426)
@@ -288,6 +290,9 @@ def _fit_from_package_impl(
288290 branch : str ,
289291 epochs : int ,
290292 volume_package_path : str = None ,
293+ volume_package_contract_path : str = None ,
294+ allow_legacy_no_contract : bool = False ,
295+ fit_scope : str = FitScope .REGIONAL .value ,
291296 target_config : str = None ,
292297 beta : float = None ,
293298 lambda_l0 : float = None ,
@@ -300,6 +305,21 @@ def _fit_from_package_impl(
300305 raise ValueError ("volume_package_path is required" )
301306
302307 _setup_repo ()
308+ input_bundle = FittedWeightsInputBundle (
309+ scope = fit_scope ,
310+ calibration_package_path = Path (volume_package_path ),
311+ calibration_package_contract_path = (
312+ Path (volume_package_contract_path ) if volume_package_contract_path else None
313+ ),
314+ allow_legacy_no_contract = allow_legacy_no_contract ,
315+ )
316+ stage2_identity = input_bundle .stage2_identity ()
317+ if stage2_identity .stage2_contract_mode == "stage2_contract" :
318+ print (
319+ "Validated Stage 2 calibration package contract "
320+ f"{ stage2_identity .calibration_package_contract_fingerprint } " ,
321+ flush = True ,
322+ )
303323
304324 pkg_path = "/root/calibration_package.pkl"
305325 import shutil
@@ -816,11 +836,17 @@ def fit_from_package_t4(
816836 learning_rate : float = None ,
817837 log_freq : int = None ,
818838 volume_package_path : str = None ,
839+ volume_package_contract_path : str = None ,
840+ allow_legacy_no_contract : bool = False ,
841+ fit_scope : str = FitScope .REGIONAL .value ,
819842) -> dict :
820843 return _fit_from_package_impl (
821844 branch ,
822845 epochs ,
823846 volume_package_path = volume_package_path ,
847+ volume_package_contract_path = volume_package_contract_path ,
848+ allow_legacy_no_contract = allow_legacy_no_contract ,
849+ fit_scope = fit_scope ,
824850 target_config = target_config ,
825851 beta = beta ,
826852 lambda_l0 = lambda_l0 ,
@@ -848,11 +874,17 @@ def fit_from_package_a10(
848874 learning_rate : float = None ,
849875 log_freq : int = None ,
850876 volume_package_path : str = None ,
877+ volume_package_contract_path : str = None ,
878+ allow_legacy_no_contract : bool = False ,
879+ fit_scope : str = FitScope .REGIONAL .value ,
851880) -> dict :
852881 return _fit_from_package_impl (
853882 branch ,
854883 epochs ,
855884 volume_package_path = volume_package_path ,
885+ volume_package_contract_path = volume_package_contract_path ,
886+ allow_legacy_no_contract = allow_legacy_no_contract ,
887+ fit_scope = fit_scope ,
856888 target_config = target_config ,
857889 beta = beta ,
858890 lambda_l0 = lambda_l0 ,
@@ -880,11 +912,17 @@ def fit_from_package_a100_40(
880912 learning_rate : float = None ,
881913 log_freq : int = None ,
882914 volume_package_path : str = None ,
915+ volume_package_contract_path : str = None ,
916+ allow_legacy_no_contract : bool = False ,
917+ fit_scope : str = FitScope .REGIONAL .value ,
883918) -> dict :
884919 return _fit_from_package_impl (
885920 branch ,
886921 epochs ,
887922 volume_package_path = volume_package_path ,
923+ volume_package_contract_path = volume_package_contract_path ,
924+ allow_legacy_no_contract = allow_legacy_no_contract ,
925+ fit_scope = fit_scope ,
888926 target_config = target_config ,
889927 beta = beta ,
890928 lambda_l0 = lambda_l0 ,
@@ -912,11 +950,17 @@ def fit_from_package_a100_80(
912950 learning_rate : float = None ,
913951 log_freq : int = None ,
914952 volume_package_path : str = None ,
953+ volume_package_contract_path : str = None ,
954+ allow_legacy_no_contract : bool = False ,
955+ fit_scope : str = FitScope .REGIONAL .value ,
915956) -> dict :
916957 return _fit_from_package_impl (
917958 branch ,
918959 epochs ,
919960 volume_package_path = volume_package_path ,
961+ volume_package_contract_path = volume_package_contract_path ,
962+ allow_legacy_no_contract = allow_legacy_no_contract ,
963+ fit_scope = fit_scope ,
920964 target_config = target_config ,
921965 beta = beta ,
922966 lambda_l0 = lambda_l0 ,
@@ -944,11 +988,17 @@ def fit_from_package_h100(
944988 learning_rate : float = None ,
945989 log_freq : int = None ,
946990 volume_package_path : str = None ,
991+ volume_package_contract_path : str = None ,
992+ allow_legacy_no_contract : bool = False ,
993+ fit_scope : str = FitScope .REGIONAL .value ,
947994) -> dict :
948995 return _fit_from_package_impl (
949996 branch ,
950997 epochs ,
951998 volume_package_path = volume_package_path ,
999+ volume_package_contract_path = volume_package_contract_path ,
1000+ allow_legacy_no_contract = allow_legacy_no_contract ,
1001+ fit_scope = fit_scope ,
9521002 target_config = target_config ,
9531003 beta = beta ,
9541004 lambda_l0 = lambda_l0 ,
@@ -1008,12 +1058,23 @@ def main(
10081058
10091059 if package_path :
10101060 vol_path = f"{ PIPELINE_MOUNT } /artifacts/calibration_package.pkl"
1061+ local_contract_path = Path (package_path ).with_name (
1062+ CALIBRATION_PACKAGE_CONTRACT_FILENAME
1063+ )
1064+ vol_contract_path = (
1065+ f"{ PIPELINE_MOUNT } /artifacts/{ CALIBRATION_PACKAGE_CONTRACT_FILENAME } "
1066+ if local_contract_path .exists ()
1067+ else None
1068+ )
10111069 print (f"Reading package from { package_path } ..." , flush = True )
10121070 import json as _json
10131071 import pickle as _pkl
10141072
10151073 with open (package_path , "rb" ) as f :
10161074 package_bytes = f .read ()
1075+ contract_bytes = (
1076+ local_contract_path .read_bytes () if local_contract_path .exists () else None
1077+ )
10171078 size = len (package_bytes )
10181079 pkg_meta = _pkl .loads (package_bytes ).get ("metadata" , {})
10191080 sidecar_bytes = _json .dumps (pkg_meta , indent = 2 ).encode ()
@@ -1032,6 +1093,11 @@ def main(
10321093 BytesIO (sidecar_bytes ),
10331094 "artifacts/calibration_package_meta.json" ,
10341095 )
1096+ if contract_bytes is not None :
1097+ batch .put_file (
1098+ BytesIO (contract_bytes ),
1099+ f"artifacts/{ CALIBRATION_PACKAGE_CONTRACT_FILENAME } " ,
1100+ )
10351101 pipeline_vol .commit ()
10361102 del package_bytes
10371103 print ("Upload complete." , flush = True )
@@ -1047,6 +1113,9 @@ def main(
10471113 learning_rate = learning_rate ,
10481114 log_freq = log_freq ,
10491115 volume_package_path = vol_path ,
1116+ volume_package_contract_path = vol_contract_path ,
1117+ allow_legacy_no_contract = True ,
1118+ fit_scope = scope .value ,
10501119 )
10511120 elif full_pipeline :
10521121 print (
@@ -1080,6 +1149,9 @@ def main(
10801149 )
10811150 else :
10821151 vol_path = f"{ PIPELINE_MOUNT } /artifacts/calibration_package.pkl"
1152+ vol_contract_path = (
1153+ f"{ PIPELINE_MOUNT } /artifacts/{ CALIBRATION_PACKAGE_CONTRACT_FILENAME } "
1154+ )
10831155 vol_info = check_volume_package .remote ()
10841156 if not vol_info ["exists" ]:
10851157 raise SystemExit (
@@ -1134,6 +1206,9 @@ def main(
11341206 learning_rate = learning_rate ,
11351207 log_freq = log_freq ,
11361208 volume_package_path = vol_path ,
1209+ volume_package_contract_path = vol_contract_path ,
1210+ allow_legacy_no_contract = True ,
1211+ fit_scope = scope .value ,
11371212 )
11381213
11391214 with open (output , "wb" ) as f :
0 commit comments