1010)
1111
1212
13- class FakeBatch :
14- def __init__ (self ) -> None :
15- self .files : dict [str , bytes ] = {}
16-
17- def put_file (self , file_obj , destination : str ) -> None :
18- self .files [destination ] = file_obj .read ()
19-
20-
21- def test_input_bundle_exposes_calibration_package_identity_path () -> None :
13+ def test_input_bundle_exposes_calibration_package_identity_path (
14+ calibration_package_path : Path ,
15+ ) -> None :
2216 bundle = FittedWeightsInputBundle (
2317 scope = "regional" ,
24- calibration_package_path = Path (
25- "/pipeline/artifacts/run/calibration_package.pkl"
26- ),
18+ calibration_package_path = calibration_package_path ,
2719 )
2820
2921 assert bundle .scope == FitScope .REGIONAL
3022 assert bundle .artifact_identity_paths () == {
31- "calibration_package" : Path ( "/pipeline/artifacts/run/calibration_package.pkl" )
23+ "calibration_package" : calibration_package_path
3224 }
3325
3426
35- def test_regional_output_bundle_writes_expected_paths () -> None :
36- bundle = FittedWeightsOutputBundle .from_result_bytes (
37- scope = FitScope .REGIONAL ,
38- result_bytes = {
39- "weights" : b"weights" ,
40- "geography" : b"geo" ,
41- "config" : b"config" ,
42- "log" : b"log" ,
43- "cal_log" : b"epoch" ,
44- },
45- run_id = "run-1" ,
46- )
47- batch = FakeBatch ()
48-
49- written = bundle .write_artifacts (batch , "artifacts/run-1" )
27+ def test_regional_output_bundle_writes_expected_paths (
28+ artifacts_rel : str ,
29+ fake_batch ,
30+ regional_output_bundle : FittedWeightsOutputBundle ,
31+ ) -> None :
32+ written = regional_output_bundle .write_artifacts (fake_batch , artifacts_rel )
5033
5134 assert written == [
5235 "artifacts/run-1/calibration_weights.npy" ,
5336 "artifacts/run-1/geography_assignment.npz" ,
5437 "artifacts/run-1/unified_run_config.json" ,
5538 ]
56- assert batch .files ["artifacts/run-1/calibration_weights.npy" ] == b"weights"
57- assert bundle .artifact_paths ("/pipeline/artifacts/run-1" ) == [
39+ assert fake_batch .files ["artifacts/run-1/calibration_weights.npy" ] == b"weights"
40+ assert regional_output_bundle .artifact_paths ("/pipeline/artifacts/run-1" ) == [
5841 Path ("/pipeline/artifacts/run-1/calibration_weights.npy" ),
5942 Path ("/pipeline/artifacts/run-1/geography_assignment.npz" ),
6043 Path ("/pipeline/artifacts/run-1/unified_run_config.json" ),
6144 ]
6245
6346
64- def test_national_output_bundle_writes_expected_paths () -> None :
65- bundle = FittedWeightsOutputBundle .from_result_bytes (
66- scope = "national" ,
67- result_bytes = {
68- "weights" : b"weights" ,
69- "geography" : b"geo" ,
70- "config" : b"config" ,
71- },
72- )
73- batch = FakeBatch ()
74-
75- written = bundle .write_artifacts (batch , "artifacts/run-1" )
47+ def test_national_output_bundle_writes_expected_paths (
48+ artifacts_rel : str ,
49+ fake_batch ,
50+ national_output_bundle : FittedWeightsOutputBundle ,
51+ ) -> None :
52+ written = national_output_bundle .write_artifacts (fake_batch , artifacts_rel )
7653
7754 assert written == [
7855 "artifacts/run-1/national_calibration_weights.npy" ,
7956 "artifacts/run-1/national_geography_assignment.npz" ,
8057 "artifacts/run-1/national_unified_run_config.json" ,
8158 ]
82- assert batch .files ["artifacts/run-1/national_calibration_weights.npy" ] == b"weights"
59+ assert (
60+ fake_batch .files ["artifacts/run-1/national_calibration_weights.npy" ]
61+ == b"weights"
62+ )
8363
8464
85- def test_missing_optional_epoch_log_is_allowed () -> None :
65+ def test_missing_optional_epoch_log_is_allowed (
66+ regional_result_bytes : dict [str , bytes ],
67+ ) -> None :
68+ result_bytes = dict (regional_result_bytes )
69+ result_bytes .pop ("cal_log" )
8670 bundle = FittedWeightsOutputBundle .from_result_bytes (
8771 scope = FitScope .REGIONAL ,
88- result_bytes = {
89- "weights" : b"weights" ,
90- "geography" : b"geo" ,
91- "config" : b"config" ,
92- "log" : b"log" ,
93- },
72+ result_bytes = result_bytes ,
9473 )
9574
9675 assert bundle .diagnostic_result_bytes () == {
97- "log" : b"log" ,
76+ "log" : b"regional- log" ,
9877 "cal_log" : None ,
99- "config" : b"config" ,
78+ "config" : b"regional- config" ,
10079 }
10180
10281
@@ -118,45 +97,32 @@ def test_missing_weights_is_a_hard_failure() -> None:
11897def test_missing_required_primary_artifacts_fail_before_writes (
11998 missing_key : str ,
12099 expected_role : str ,
100+ artifacts_rel : str ,
101+ fake_batch ,
102+ regional_result_bytes : dict [str , bytes ],
121103) -> None :
122- result_bytes = {
123- "weights" : b"weights" ,
124- "geography" : b"geo" ,
125- "config" : b"config" ,
126- }
104+ result_bytes = dict (regional_result_bytes )
127105 result_bytes .pop (missing_key )
128106 bundle = FittedWeightsOutputBundle .from_result_bytes (
129107 scope = FitScope .REGIONAL ,
130108 result_bytes = result_bytes ,
131109 )
132110
133111 with pytest .raises (MissingFitWeightsOutputError , match = expected_role ):
134- bundle .write_artifacts (FakeBatch (), "artifacts/run-1" )
112+ bundle .write_artifacts (fake_batch , artifacts_rel )
135113
136114
137- def test_diagnostics_are_scoped_to_the_output_bundle () -> None :
138- regional = FittedWeightsOutputBundle .from_result_bytes (
139- scope = FitScope .REGIONAL ,
140- result_bytes = {
141- "weights" : b"weights" ,
142- "geography" : b"regional-geo" ,
143- "config" : b"regional-config" ,
144- "log" : b"regional-log" ,
145- "cal_log" : b"regional-epoch" ,
146- },
115+ def test_diagnostics_are_scoped_to_the_output_bundle (
116+ regional_output_bundle : FittedWeightsOutputBundle ,
117+ national_output_bundle : FittedWeightsOutputBundle ,
118+ ) -> None :
119+ assert (
120+ regional_output_bundle .artifacts .diagnostics .filename
121+ == "unified_diagnostics.csv"
147122 )
148- national = FittedWeightsOutputBundle .from_result_bytes (
149- scope = FitScope .NATIONAL ,
150- result_bytes = {
151- "weights" : b"weights" ,
152- "geography" : b"national-geo" ,
153- "config" : b"national-config" ,
154- "log" : b"national-log" ,
155- "cal_log" : b"national-epoch" ,
156- },
123+ assert (
124+ national_output_bundle .artifacts .diagnostics .filename
125+ == "national_unified_diagnostics.csv"
157126 )
158-
159- assert regional .artifacts .diagnostics .filename == "unified_diagnostics.csv"
160- assert national .artifacts .diagnostics .filename == "national_unified_diagnostics.csv"
161- assert regional .diagnostic_result_bytes ()["log" ] == b"regional-log"
162- assert national .diagnostic_result_bytes ()["log" ] == b"national-log"
127+ assert regional_output_bundle .diagnostic_result_bytes ()["log" ] == b"regional-log"
128+ assert national_output_bundle .diagnostic_result_bytes ()["log" ] == b"national-log"
0 commit comments