Skip to content

Commit 6095df6

Browse files
committed
Merge remote-tracking branch 'upstream/main' into codex/tmp-pr669-merge
# Conflicts: # policyengine_us_data/datasets/cps/enhanced_cps.py
2 parents f618250 + b5f3e0e commit 6095df6

21 files changed

Lines changed: 1166 additions & 78 deletions

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
## [1.74.1] - 2026-04-03
2+
3+
### Fixed
4+
5+
- Added fail-closed dataset contract validation for built CPS artifacts, including
6+
`policyengine-us` lockfile version checks, per-entity HDF5 length validation,
7+
and file-based `Microsimulation` smoke tests in both the build and upload paths.
8+
9+
110
## [1.74.0] - 2026-04-02
211

312
### Added

modal_app/data_build.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
if _p not in sys.path:
1818
sys.path.insert(0, _p)
1919

20-
from modal_app.images import cpu_image as image
20+
from modal_app.images import cpu_image as image # noqa: E402
2121

2222
app = modal.App("policyengine-us-data")
2323

@@ -233,6 +233,34 @@ def run_script(
233233
return script_path
234234

235235

236+
def validate_and_maybe_upload_datasets(
    *,
    upload: bool,
    skip_enhanced_cps: bool,
    env: dict,
) -> None:
    """Validate the built datasets and, if requested, run the upload step.

    The storage script ``upload_completed_datasets.py`` is always invoked
    once through ``run_script`` with ``--validate-only``. When ``upload`` is
    true it is invoked a second time without that flag.

    Args:
        upload: When true, run the script a second time without
            ``--validate-only``.
        skip_enhanced_cps: When true, ``--no-require-enhanced-cps`` is
            forwarded to every invocation of the script.
        env: Forwarded unchanged to ``run_script`` as its ``env`` argument.
    """
    script = "policyengine_us_data/storage/upload_completed_datasets.py"
    # Flags common to the validation pass and the upload pass.
    shared_flags = ["--no-require-enhanced-cps"] if skip_enhanced_cps else []

    print("=== Validating built datasets ===")
    # NOTE(review): the upload below is only gated on this validation if
    # run_script raises when the script fails — confirm against run_script.
    run_script(script, args=["--validate-only", *shared_flags], env=env)

    if not upload:
        return

    run_script(script, args=list(shared_flags), env=env)
262+
263+
236264
def run_script_with_checkpoint(
237265
script_path: str,
238266
output_files: str | list[str],
@@ -634,16 +662,11 @@ def build_datasets(
634662
print("=== Running tests with checkpointing ===")
635663
run_tests_with_checkpoints(branch, checkpoint_volume, env)
636664

637-
# Upload if requested (HF publication only)
638-
if upload:
639-
upload_args = []
640-
if skip_enhanced_cps:
641-
upload_args.append("--no-require-enhanced-cps")
642-
run_script(
643-
"policyengine_us_data/storage/upload_completed_datasets.py",
644-
args=upload_args,
645-
env=env,
646-
)
665+
validate_and_maybe_upload_datasets(
666+
upload=upload,
667+
skip_enhanced_cps=skip_enhanced_cps,
668+
env=env,
669+
)
647670

648671
# Clean up checkpoints after successful completion
649672
cleanup_checkpoints(branch, checkpoint_volume)

policyengine_us_data/datasets/cps/enhanced_cps.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def loss(weights):
131131
optimizer.zero_grad()
132132
masked = torch.exp(weights) * gates()
133133
l_main = loss(masked)
134-
loss_value = l_main + l0_lambda * gates.get_penalty()
134+
total_loss = l_main + l0_lambda * gates.get_penalty()
135135
if (log_path is not None) and (i % 10 == 0):
136136
gates.eval()
137137
estimates = (torch.exp(weights) * gates()) @ loss_matrix
@@ -155,11 +155,11 @@ def loss(weights):
155155
if (log_path is not None) and (i % 1000 == 0):
156156
performance.to_csv(log_path, index=False)
157157
if start_loss is None:
158-
start_loss = loss_value.item()
159-
loss_rel_change = (loss_value.item() - start_loss) / start_loss
160-
loss_value.backward()
158+
start_loss = total_loss.item()
159+
loss_rel_change = (total_loss.item() - start_loss) / start_loss
160+
total_loss.backward()
161161
iterator.set_postfix(
162-
{"loss": loss_value.item(), "loss_rel_change": loss_rel_change}
162+
{"loss": total_loss.item(), "loss_rel_change": loss_rel_change}
163163
)
164164
optimizer.step()
165165
if log_path is not None:
@@ -344,6 +344,7 @@ class EnhancedCPS_2024(EnhancedCPS):
344344
input_dataset = ExtendedCPS_2024_Half
345345
start_year = 2024
346346
end_year = 2024
347+
time_period = 2024
347348
name = "enhanced_cps_2024"
348349
label = "Enhanced CPS 2024"
349350
file_path = STORAGE_FOLDER / "enhanced_cps_2024.h5"

policyengine_us_data/db/etl_state_income_tax.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
CENSUS_STC_FLAT_FILE_URLS = {
3333
2023: "https://www2.census.gov/programs-surveys/stc/datasets/2023/FY2023-Flat-File.txt",
3434
}
35+
LATEST_STC_YEAR = max(CENSUS_STC_FLAT_FILE_URLS)
3536
CENSUS_STC_INDIVIDUAL_INCOME_TAX_ITEM = "T40"
3637
CENSUS_STC_NOT_AVAILABLE = "X"
3738

@@ -179,7 +180,9 @@ def transform_state_income_tax_data(df: pd.DataFrame) -> pd.DataFrame:
179180
return result
180181

181182

182-
def load_state_income_tax_data(df: pd.DataFrame, year: int) -> dict:
183+
def load_state_income_tax_data(
184+
df: pd.DataFrame, year: int, source_year: int | None = None
185+
) -> dict:
183186
"""
184187
Load state income tax targets into the calibration database.
185188
@@ -241,7 +244,7 @@ def load_state_income_tax_data(df: pd.DataFrame, year: int) -> dict:
241244
value=row["income_tax_collections"],
242245
active=True,
243246
source="Census STC",
244-
notes=f"Census STC FY{year}",
247+
notes=f"Census STC FY{source_year or year}",
245248
)
246249
)
247250

@@ -263,14 +266,22 @@ def main():
263266
)
264267
_, year = etl_argparser("ETL for state income tax calibration targets")
265268

266-
logger.info(f"Extracting Census STC data for FY{year}...")
267-
raw_df = extract_state_income_tax_data(year)
269+
data_year = min(year, LATEST_STC_YEAR)
270+
if data_year != year:
271+
logger.warning(
272+
f"Census STC data not available for {year}; "
273+
f"using latest available year ({LATEST_STC_YEAR})"
274+
)
275+
logger.info(f"Extracting Census STC data for FY{data_year}...")
276+
raw_df = extract_state_income_tax_data(data_year)
268277

269278
logger.info("Transforming data...")
270279
transformed_df = transform_state_income_tax_data(raw_df)
271280

272281
logger.info(f"Loading {len(transformed_df)} state income tax targets...")
273-
stratum_lookup = load_state_income_tax_data(transformed_df, year)
282+
stratum_lookup = load_state_income_tax_data(
283+
transformed_df, year, source_year=data_year
284+
)
274285

275286
# Print summary
276287
total_collections = transformed_df["income_tax_collections"].sum()

policyengine_us_data/storage/upload_completed_datasets.py

Lines changed: 52 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1-
import h5py
21
from pathlib import Path
32

4-
from policyengine_us_data.datasets import (
5-
EnhancedCPS_2024,
6-
)
3+
import h5py
4+
from policyengine_core.data import Dataset
5+
6+
from policyengine_us_data.datasets import EnhancedCPS_2024
77
from policyengine_us_data.datasets.cps.cps import CPS_2024
88
from policyengine_us_data.storage import STORAGE_FOLDER
99
from policyengine_us_data.utils.data_upload import upload_data_files
10+
from policyengine_us_data.utils.dataset_validation import (
11+
DatasetContractError,
12+
load_dataset_for_validation,
13+
validate_dataset_contract,
14+
)
1015

1116
# Datasets that require full validation before upload.
1217
# These are the main datasets used in production simulations.
@@ -15,14 +20,9 @@
1520
"cps_2024.h5",
1621
}
1722

18-
FILENAME_TO_DATASET = {
19-
"enhanced_cps_2024.h5": EnhancedCPS_2024,
20-
"cps_2024.h5": CPS_2024,
21-
}
22-
2323
# Minimum file sizes in bytes for validated datasets.
2424
MIN_FILE_SIZES = {
25-
"enhanced_cps_2024.h5": 100 * 1024 * 1024, # 100 MB
25+
"enhanced_cps_2024.h5": 95 * 1024 * 1024, # 95 MB
2626
"cps_2024.h5": 50 * 1024 * 1024, # 50 MB
2727
}
2828

@@ -118,15 +118,23 @@ def _check_group_has_data(f, name):
118118
+ "\n".join(f" - {e}" for e in errors)
119119
)
120120

121+
try:
122+
contract_summary = validate_dataset_contract(file_path)
123+
except DatasetContractError as e:
124+
errors.append(f"Dataset contract validation failed: {e}")
125+
raise DatasetValidationError(
126+
f"Validation failed for {filename}:\n"
127+
+ "\n".join(f" - {e}" for e in errors)
128+
) from e
129+
121130
# 3. Aggregate statistics check via Microsimulation
122131
# Import here to avoid heavy import at module level.
123132
from policyengine_us import Microsimulation
124133

125134
try:
126-
dataset_cls = FILENAME_TO_DATASET.get(filename)
127-
if dataset_cls is None:
128-
raise DatasetValidationError(f"No dataset class registered for {filename}")
129-
sim = Microsimulation(dataset=dataset_cls)
135+
sim = Microsimulation(
136+
dataset=load_dataset_for_validation(file_path, Dataset.from_file)
137+
)
130138
year = 2024
131139

132140
emp_income = sim.calculate("employment_income", year).sum()
@@ -159,6 +167,15 @@ def _check_group_has_data(f, name):
159167

160168
print(f" ✓ Validation passed for {filename}")
161169
print(f" File size: {file_size / 1024 / 1024:.1f} MB")
170+
print(
171+
" policyengine-us: "
172+
f"{contract_summary.policyengine_us.version}"
173+
+ (
174+
f" (locked {contract_summary.policyengine_us.locked_version})"
175+
if contract_summary.policyengine_us.locked_version
176+
else ""
177+
)
178+
)
162179
print(f" employment_income sum: ${emp_income:,.0f}")
163180
print(f" Household weight sum: {hh_weight:,.0f}")
164181

@@ -210,14 +227,18 @@ def upload_datasets(require_enhanced_cps: bool = True):
210227

211228
def validate_all_datasets():
212229
"""Validate all main datasets in storage. Called by `make validate-data`."""
213-
for filename in VALIDATED_FILENAMES:
214-
file_path = STORAGE_FOLDER / filename
215-
if file_path.exists():
216-
validate_dataset(file_path)
217-
else:
218-
raise FileNotFoundError(
219-
f"Expected dataset {filename} not found at {file_path}"
220-
)
230+
validate_built_datasets(require_enhanced_cps=True)
231+
232+
233+
def validate_built_datasets(require_enhanced_cps: bool = True):
    """Validate the built dataset files expected in storage.

    ``CPS_2024.file_path`` is always required; ``EnhancedCPS_2024.file_path``
    is required as well when ``require_enhanced_cps`` is true. The files are
    handled in that order: each one is checked for existence and then passed
    to ``validate_dataset`` before the next file is looked at.

    Args:
        require_enhanced_cps: Whether the enhanced CPS file must also be
            present and valid.

    Raises:
        FileNotFoundError: If a required file does not exist. Exceptions
            raised by ``validate_dataset`` propagate unchanged.
    """
    expected_paths = (
        [CPS_2024.file_path, EnhancedCPS_2024.file_path]
        if require_enhanced_cps
        else [CPS_2024.file_path]
    )

    for file_path in expected_paths:
        if file_path.exists():
            validate_dataset(file_path)
            continue
        raise FileNotFoundError(f"Expected dataset not found at {file_path}")

    print("\nAll dataset validations passed.")
222243

223244

@@ -230,5 +251,13 @@ def validate_all_datasets():
230251
action="store_true",
231252
help="Treat enhanced_cps and small_enhanced_cps as optional.",
232253
)
254+
parser.add_argument(
255+
"--validate-only",
256+
action="store_true",
257+
help="Validate built datasets without uploading them.",
258+
)
233259
args = parser.parse_args()
234-
upload_datasets(require_enhanced_cps=not args.no_require_enhanced_cps)
260+
if args.validate_only:
261+
validate_built_datasets(require_enhanced_cps=not args.no_require_enhanced_cps)
262+
else:
263+
upload_datasets(require_enhanced_cps=not args.no_require_enhanced_cps)

0 commit comments

Comments
 (0)