Skip to content

Commit c1c4083

Browse files
committed
Move fit weight test setup into fixtures
1 parent cf80040 commit c1c4083

3 files changed

Lines changed: 135 additions & 96 deletions

File tree

tests/unit/fit_weights/conftest.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from collections.abc import Callable
2+
from pathlib import Path
3+
4+
import pytest
5+
import yaml
6+
7+
from policyengine_us_data.fit_weights import (
8+
FitScope,
9+
FittedWeightsOutputBundle,
10+
)
11+
12+
13+
class FakeBatch:
    """In-memory test double that records every file handed to ``put_file``."""

    def __init__(self) -> None:
        # Maps each destination string to the bytes read from its file object.
        self.files: dict[str, bytes] = {}

    def put_file(self, file_obj, destination: str) -> None:
        """Read ``file_obj`` to the end and store the result under ``destination``.

        A later call with the same destination replaces the earlier entry.
        """
        self.files[destination] = file_obj.read()
19+
20+
21+
@pytest.fixture
def artifacts_rel() -> str:
    """Return the relative artifacts directory string shared by these tests."""
    return "artifacts/run-1"
24+
25+
26+
@pytest.fixture
def calibration_package_path() -> Path:
    """Return a fixed calibration package path.

    The fixture only constructs the ``Path`` object; it does not create or
    read the file.
    """
    return Path("/pipeline/artifacts/run/calibration_package.pkl")
29+
30+
31+
@pytest.fixture
def fake_batch() -> FakeBatch:
    """Return a new, empty ``FakeBatch`` for capturing written files."""
    return FakeBatch()
34+
35+
36+
@pytest.fixture
def regional_result_bytes() -> dict[str, bytes]:
    """Return placeholder result bytes for a regional fit, keyed by role.

    A new dict is built on every call, so a test may mutate its copy freely.
    """
    return {
        "weights": b"weights",
        "geography": b"regional-geo",
        "config": b"regional-config",
        "log": b"regional-log",
        "cal_log": b"regional-epoch",
    }
45+
46+
47+
@pytest.fixture
def national_result_bytes() -> dict[str, bytes]:
    """Return placeholder result bytes for a national fit, keyed by role.

    A new dict is built on every call, so a test may mutate its copy freely.
    """
    return {
        "weights": b"weights",
        "geography": b"national-geo",
        "config": b"national-config",
        "log": b"national-log",
        "cal_log": b"national-epoch",
    }
56+
57+
58+
@pytest.fixture
def regional_output_bundle(
    regional_result_bytes: dict[str, bytes],
) -> FittedWeightsOutputBundle:
    """Build a regional-scope output bundle from the regional placeholder bytes."""
    return FittedWeightsOutputBundle.from_result_bytes(
        scope=FitScope.REGIONAL,
        result_bytes=regional_result_bytes,
        run_id="run-1",
    )
67+
68+
69+
@pytest.fixture
def national_output_bundle(
    national_result_bytes: dict[str, bytes],
) -> FittedWeightsOutputBundle:
    """Build a national-scope output bundle from the national placeholder bytes."""
    return FittedWeightsOutputBundle.from_result_bytes(
        scope=FitScope.NATIONAL,
        result_bytes=national_result_bytes,
        run_id="run-1",
    )
78+
79+
80+
@pytest.fixture
def stage_3_substage() -> Callable[[str], dict]:
    """Return a lookup from a substage id to its entry in the pipeline map.

    The map is loaded from ``docs/pipeline_map.yaml``, resolved against the
    current working directory. The returned callable raises ``KeyError`` when
    asked for an id that is not listed under ``stages``.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default text encoding.
    text = Path("docs/pipeline_map.yaml").read_text(encoding="utf-8")
    data = yaml.safe_load(text)

    substages: dict[str, dict] = {}
    for substage in data["stages"]:
        # Keep the first entry when an id repeats, rather than letting a
        # later duplicate silently replace it.
        substages.setdefault(substage["id"], substage)
    return substages.__getitem__

tests/unit/fit_weights/test_bundles.py

Lines changed: 48 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -10,93 +10,72 @@
1010
)
1111

1212

13-
class FakeBatch:
14-
def __init__(self) -> None:
15-
self.files: dict[str, bytes] = {}
16-
17-
def put_file(self, file_obj, destination: str) -> None:
18-
self.files[destination] = file_obj.read()
19-
20-
21-
def test_input_bundle_exposes_calibration_package_identity_path() -> None:
13+
def test_input_bundle_exposes_calibration_package_identity_path(
14+
calibration_package_path: Path,
15+
) -> None:
2216
bundle = FittedWeightsInputBundle(
2317
scope="regional",
24-
calibration_package_path=Path(
25-
"/pipeline/artifacts/run/calibration_package.pkl"
26-
),
18+
calibration_package_path=calibration_package_path,
2719
)
2820

2921
assert bundle.scope == FitScope.REGIONAL
3022
assert bundle.artifact_identity_paths() == {
31-
"calibration_package": Path("/pipeline/artifacts/run/calibration_package.pkl")
23+
"calibration_package": calibration_package_path
3224
}
3325

3426

35-
def test_regional_output_bundle_writes_expected_paths() -> None:
36-
bundle = FittedWeightsOutputBundle.from_result_bytes(
37-
scope=FitScope.REGIONAL,
38-
result_bytes={
39-
"weights": b"weights",
40-
"geography": b"geo",
41-
"config": b"config",
42-
"log": b"log",
43-
"cal_log": b"epoch",
44-
},
45-
run_id="run-1",
46-
)
47-
batch = FakeBatch()
48-
49-
written = bundle.write_artifacts(batch, "artifacts/run-1")
27+
def test_regional_output_bundle_writes_expected_paths(
28+
artifacts_rel: str,
29+
fake_batch,
30+
regional_output_bundle: FittedWeightsOutputBundle,
31+
) -> None:
32+
written = regional_output_bundle.write_artifacts(fake_batch, artifacts_rel)
5033

5134
assert written == [
5235
"artifacts/run-1/calibration_weights.npy",
5336
"artifacts/run-1/geography_assignment.npz",
5437
"artifacts/run-1/unified_run_config.json",
5538
]
56-
assert batch.files["artifacts/run-1/calibration_weights.npy"] == b"weights"
57-
assert bundle.artifact_paths("/pipeline/artifacts/run-1") == [
39+
assert fake_batch.files["artifacts/run-1/calibration_weights.npy"] == b"weights"
40+
assert regional_output_bundle.artifact_paths("/pipeline/artifacts/run-1") == [
5841
Path("/pipeline/artifacts/run-1/calibration_weights.npy"),
5942
Path("/pipeline/artifacts/run-1/geography_assignment.npz"),
6043
Path("/pipeline/artifacts/run-1/unified_run_config.json"),
6144
]
6245

6346

64-
def test_national_output_bundle_writes_expected_paths() -> None:
65-
bundle = FittedWeightsOutputBundle.from_result_bytes(
66-
scope="national",
67-
result_bytes={
68-
"weights": b"weights",
69-
"geography": b"geo",
70-
"config": b"config",
71-
},
72-
)
73-
batch = FakeBatch()
74-
75-
written = bundle.write_artifacts(batch, "artifacts/run-1")
47+
def test_national_output_bundle_writes_expected_paths(
48+
artifacts_rel: str,
49+
fake_batch,
50+
national_output_bundle: FittedWeightsOutputBundle,
51+
) -> None:
52+
written = national_output_bundle.write_artifacts(fake_batch, artifacts_rel)
7653

7754
assert written == [
7855
"artifacts/run-1/national_calibration_weights.npy",
7956
"artifacts/run-1/national_geography_assignment.npz",
8057
"artifacts/run-1/national_unified_run_config.json",
8158
]
82-
assert batch.files["artifacts/run-1/national_calibration_weights.npy"] == b"weights"
59+
assert (
60+
fake_batch.files["artifacts/run-1/national_calibration_weights.npy"]
61+
== b"weights"
62+
)
8363

8464

85-
def test_missing_optional_epoch_log_is_allowed() -> None:
65+
def test_missing_optional_epoch_log_is_allowed(
66+
regional_result_bytes: dict[str, bytes],
67+
) -> None:
68+
result_bytes = dict(regional_result_bytes)
69+
result_bytes.pop("cal_log")
8670
bundle = FittedWeightsOutputBundle.from_result_bytes(
8771
scope=FitScope.REGIONAL,
88-
result_bytes={
89-
"weights": b"weights",
90-
"geography": b"geo",
91-
"config": b"config",
92-
"log": b"log",
93-
},
72+
result_bytes=result_bytes,
9473
)
9574

9675
assert bundle.diagnostic_result_bytes() == {
97-
"log": b"log",
76+
"log": b"regional-log",
9877
"cal_log": None,
99-
"config": b"config",
78+
"config": b"regional-config",
10079
}
10180

10281

@@ -118,45 +97,32 @@ def test_missing_weights_is_a_hard_failure() -> None:
11897
def test_missing_required_primary_artifacts_fail_before_writes(
11998
missing_key: str,
12099
expected_role: str,
100+
artifacts_rel: str,
101+
fake_batch,
102+
regional_result_bytes: dict[str, bytes],
121103
) -> None:
122-
result_bytes = {
123-
"weights": b"weights",
124-
"geography": b"geo",
125-
"config": b"config",
126-
}
104+
result_bytes = dict(regional_result_bytes)
127105
result_bytes.pop(missing_key)
128106
bundle = FittedWeightsOutputBundle.from_result_bytes(
129107
scope=FitScope.REGIONAL,
130108
result_bytes=result_bytes,
131109
)
132110

133111
with pytest.raises(MissingFitWeightsOutputError, match=expected_role):
134-
bundle.write_artifacts(FakeBatch(), "artifacts/run-1")
112+
bundle.write_artifacts(fake_batch, artifacts_rel)
135113

136114

137-
def test_diagnostics_are_scoped_to_the_output_bundle() -> None:
138-
regional = FittedWeightsOutputBundle.from_result_bytes(
139-
scope=FitScope.REGIONAL,
140-
result_bytes={
141-
"weights": b"weights",
142-
"geography": b"regional-geo",
143-
"config": b"regional-config",
144-
"log": b"regional-log",
145-
"cal_log": b"regional-epoch",
146-
},
115+
def test_diagnostics_are_scoped_to_the_output_bundle(
116+
regional_output_bundle: FittedWeightsOutputBundle,
117+
national_output_bundle: FittedWeightsOutputBundle,
118+
) -> None:
119+
assert (
120+
regional_output_bundle.artifacts.diagnostics.filename
121+
== "unified_diagnostics.csv"
147122
)
148-
national = FittedWeightsOutputBundle.from_result_bytes(
149-
scope=FitScope.NATIONAL,
150-
result_bytes={
151-
"weights": b"weights",
152-
"geography": b"national-geo",
153-
"config": b"national-config",
154-
"log": b"national-log",
155-
"cal_log": b"national-epoch",
156-
},
123+
assert (
124+
national_output_bundle.artifacts.diagnostics.filename
125+
== "national_unified_diagnostics.csv"
157126
)
158-
159-
assert regional.artifacts.diagnostics.filename == "unified_diagnostics.csv"
160-
assert national.artifacts.diagnostics.filename == "national_unified_diagnostics.csv"
161-
assert regional.diagnostic_result_bytes()["log"] == b"regional-log"
162-
assert national.diagnostic_result_bytes()["log"] == b"national-log"
127+
assert regional_output_bundle.diagnostic_result_bytes()["log"] == b"regional-log"
128+
assert national_output_bundle.diagnostic_result_bytes()["log"] == b"national-log"

tests/unit/fit_weights/test_pipeline_docs.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,7 @@
1-
from pathlib import Path
2-
3-
import yaml
4-
51
from policyengine_us_data.fit_weights import FitScope, fit_artifacts_for_scope
62
from scripts.extract_pipeline_docs import scan_decorated_objects
73

84

9-
def _substage(substage_id: str) -> dict:
10-
data = yaml.safe_load(Path("docs/pipeline_map.yaml").read_text())
11-
return next(
12-
substage for substage in data["stages"] if substage["id"] == substage_id
13-
)
14-
15-
165
def test_fit_weights_identity_nodes_are_in_generated_pipeline_docs() -> None:
176
decorated = scan_decorated_objects()
187

@@ -26,11 +15,11 @@ def test_fit_weights_identity_nodes_are_in_generated_pipeline_docs() -> None:
2615
)
2716

2817

29-
def test_stage_3_pipeline_map_labels_match_scoped_artifacts() -> None:
18+
def test_stage_3_pipeline_map_labels_match_scoped_artifacts(stage_3_substage) -> None:
3019
regional_artifacts = fit_artifacts_for_scope(FitScope.REGIONAL)
3120
national_artifacts = fit_artifacts_for_scope(FitScope.NATIONAL)
32-
regional = _substage("3a_weight_fitting_regional")
33-
national = _substage("3b_weight_fitting_national")
21+
regional = stage_3_substage("3a_weight_fitting_regional")
22+
national = stage_3_substage("3b_weight_fitting_national")
3423

3524
regional_nodes = {node["id"]: node for node in regional["extra_nodes"]}
3625
national_nodes = {node["id"]: node for node in national["extra_nodes"]}

0 commit comments

Comments
 (0)