diff --git a/.github/workflows/site-snapshot.yml b/.github/workflows/site-snapshot.yml index ffa7db0..f50f260 100644 --- a/.github/workflows/site-snapshot.yml +++ b/.github/workflows/site-snapshot.yml @@ -29,6 +29,13 @@ jobs: ref: main path: microplex + - name: Check out microunit + uses: actions/checkout@v4 + with: + repository: CosilicoAI/microunit + ref: main + path: microunit + - name: Set up Python uses: actions/setup-python@v5 with: diff --git a/docs/aca-ptc-multiplier-source-choice.md b/docs/aca-ptc-multiplier-source-choice.md new file mode 100644 index 0000000..e25d671 --- /dev/null +++ b/docs/aca-ptc-multiplier-source-choice.md @@ -0,0 +1,113 @@ +# ACA PTC Multiplier Source Choice + +This records the first Microplex-US reconstruction of +`policyengine-us-data`'s `aca_ptc_multipliers_2022_2024.csv` from Arch +publisher-source consumer facts. + +## Recipe + +Inputs: + +- KFF full-year average marketplace effectuated enrollment, 2022 and 2024 +- CMS 2022 OEP state-level average monthly APTC +- CMS 2024 OEP state-level average monthly APTC +- CMS full-year 2022 effectuated-enrollment workbook average monthly APTC + +Source selection: + +- `enroll_2022` and `enroll_2024`: KFF full-year effectuated enrollment +- `aptc_2024`: CMS 2024 OEP average monthly APTC +- `aptc_2022`: CMS 2022 OEP average monthly APTC where published, with CMS + full-year 2022 average monthly APTC as fallback + +Derived columns: + +- `vol_mult = enroll_2024 / enroll_2022` +- `val_mult = aptc_2024 / aptc_2022` +- PE's state `tax_unit_count` factor uses `vol_mult` +- PE's state `aca_ptc` amount factor uses `vol_mult * val_mult` + +## Reproduction + +Build the five Arch source-package suites, then run: + +```bash +uv run microplex-us-build-aca-ptc-multipliers \ + /tmp/mp-aca-ptc-arch-sources/kff-2022/consumer_facts.jsonl \ + /tmp/mp-aca-ptc-arch-sources/kff-2024/consumer_facts.jsonl \ + /tmp/mp-aca-ptc-arch-sources/cms-oep-2022/consumer_facts.jsonl \ + 
/tmp/mp-aca-ptc-arch-sources/cms-oep-2024/consumer_facts.jsonl \ + /tmp/mp-aca-ptc-arch-sources/cms-effectuated-2022/consumer_facts.jsonl \ + --out /tmp/mp-aca-ptc-arch-sources/aca_ptc_multipliers_2022_2024.csv +``` + +The 2026-05-12 run wrote 51 rows. Compared with PE's incumbent +`policyengine_us_data/storage/aca_ptc_multipliers_2022_2024.csv`: + +- state set matches +- `enroll_2022` matches for all 51 states +- `enroll_2024` matches for all 51 states +- `vol_mult` matches for all 51 states +- `aptc_2024` matches for all 51 states +- `aptc_2022` differs for 22 states +- `val_mult` differs for the same 22 states + +## PE Incumbent Provenance Trace + +The local `policyengine-us-data` history does not contain a generator for the +incumbent CSV. `git log --follow` shows the file first appearing at its current +path in `8d2c49fa15a515e2379d1b4b5e2c1856a1d4ebe9` on 2026-02-11: +`Add hierarchical uprating notebook, fix verification, move ACA PTC +multipliers`. The commit adds +`policyengine_us_data/storage/aca_ptc_multipliers_2022_2024.csv` directly, plus +notebooks which document that ACA PTC factors are loaded from the CSV and +described as CMS/KFF enrollment data. Those notebooks do not show row-level +source derivation. + +Spot checks against the raw CMS 2022 OEP state-level source support the +Microplex-US source choice for the mismatching states where OEP publishes a +number. For example, current Arch-selected OEP values are New Jersey `489`, New +Mexico `460`, and Virginia `506`, matching the CMS OEP +`APTC_Cnsmr_Avg_APTC` column. The PE incumbent has `504`, `534`, and `407` for +those states, respectively. Nevada remains the explicit fallback case because +the CMS 2022 OEP state-level file reports no Nevada average monthly APTC fact; +Microplex-US uses the CMS full-year effectuated-enrollment value `429.75`. + +## Reconciliation Queue + +States not listed matched PE's incumbent CSV exactly. 
For listed states, the +Microplex-US value is the Arch publisher-source value selected by the recipe +above. Nevada is the known CMS full-year fallback case because the CMS 2022 OEP +state-level source package has no Nevada average monthly APTC fact. + +| State | PE aptc_2022 | Microplex-US aptc_2022 | PE val_mult | Microplex-US val_mult | +| --- | ---: | ---: | ---: | ---: | +| Nevada | 435 | 429.75 | 1.006896551724138 | 1.019197207678883 | +| New Jersey | 504 | 489 | 1.0337301587301588 | 1.065439672801636 | +| New Mexico | 534 | 460 | 1.0318352059925093 | 1.1978260869565218 | +| New York | 364 | 363 | 1.25 | 1.2534435261707988 | +| North Carolina | 583 | 579 | 0.9571183533447685 | 0.9637305699481865 | +| North Dakota | 436 | 452 | 0.9931192660550459 | 0.9579646017699115 | +| Ohio | 479 | 437 | 1.0396659707724425 | 1.139588100686499 | +| Oklahoma | 577 | 558 | 0.9965337954939342 | 1.0304659498207884 | +| Oregon | 503 | 489 | 1.0417495029821073 | 1.0715746421267893 | +| Pennsylvania | 523 | 501 | 1.0133843212237095 | 1.0578842315369261 | +| Rhode Island | 427 | 403 | 1.063231850117096 | 1.1265508684863523 | +| South Carolina | 566 | 512 | 0.9770318021201413 | 1.080078125 | +| South Dakota | 649 | 640 | 0.9414483821263482 | 0.9546875 | +| Tennessee | 572 | 543 | 1.013986013986014 | 1.0681399631675874 | +| Texas | 539 | 502 | 0.9944341372912802 | 1.0677290836653386 | +| Utah | 385 | 370 | 1.0935064935064935 | 1.1378378378378378 | +| Vermont | 620 | 566 | 1.132258064516129 | 1.2402826855123674 | +| Virginia | 407 | 506 | 0.995085995085995 | 0.8003952569169961 | +| Washington | 438 | 437 | 1.0342465753424657 | 1.036613272311213 | +| West Virginia | 1057 | 1002 | 0.97918637653737 | 1.032934131736527 | +| Wisconsin | 562 | 530 | 1.0177935943060499 | 1.079245283018868 | +| Wyoming | 873 | 812 | 0.9885452462772051 | 1.062807881773399 | + +Open reconciliation decision: + +- Treat the Microplex-US output as the publisher-source reconstruction. 
+- Treat PE byte parity as a separate legacy-compatibility target. Do not add + overrides unless a row-level legacy source or intentional source-choice table + is supplied. diff --git a/docs/arch-target-gap-queue.md b/docs/arch-target-gap-queue.md new file mode 100644 index 0000000..11f4b66 --- /dev/null +++ b/docs/arch-target-gap-queue.md @@ -0,0 +1,135 @@ +# Arch Target Gap Queue + +The Arch target gap queue is a Microplex-side review tool. It compares a +Microplex target profile to a queryable Arch target DB and emits rows that help +humans or agents decide what Arch source work is missing. + +The queue does not make Arch own Microplex target selection. Profile membership, +source aging, reconciliation, activation, and model-variable aliases remain in +`microplex-us`. + +## Boundary Rules + +- Arch stores publisher/source facts with provenance, constraints, periods, + geography, and source lineage. +- Arch should not duplicate a source fact only because Microplex names a model + variable differently. +- Microplex adapters may map one Arch source fact into simulator-specific target + semantics. For example, Arch + `irs_soi.returns_with_income_tax_after_credits` can satisfy the + PolicyEngine `income_tax_positive` count target because SOI Table 1.1 reports + the count of returns with positive income tax after credits. +- A gap row is an authoring hint, not proof that a source exists. +- Rows marked as source-mapping review or deprioritized must be reviewed before + assigning loader work to agents. + +## Categories + +`gap_category` is the high-level agent-readiness taxonomy: + +| Category | Meaning | Default action | +| --- | --- | --- | +| `covered` | An Arch target record already satisfies the target cell. | No task. | +| `ready_primary_loader` | The expected publisher source and Arch variable shape are known, but the record is missing. | Assign source-loader/spec work. 
| +| `ready_rollup_or_geography` | The Arch variable exists but not at the requested geography. | Add rollup/geography records or review source geography. | +| `adapter_or_constraint_review` | The Arch variable exists at the geography, but filters or adapter matching do not cover the cell. | Review constraints and adapter mapping. | +| `source_mapping_review` | The queue cannot identify a defensible source fact or Arch variable shape. | Human source-mapping review first. | +| `survey_or_model_input_deprioritized` | The cell is currently treated as a survey/model-input proxy rather than a primary administrative source task. | Defer unless a primary source is identified. | + +`loader_status` is the lower-level diagnostic used to derive the category. Use +`gap_category` for agent routing and `loader_status` for debugging why a cell +landed there. + +## Current PolicyEngine Profile Boundary + +`pe_native_broad` keeps the raw PolicyEngine parity surface intact. It includes +all currently tracked broad target cells, including survey/model-input rows and +cells whose publisher-source semantics still need review. + +`pe_native_broad_source_backed` is the Arch-backed calibration/profile boundary. +It excludes only cells with explicit reasons in +`src/microplex_us/policyengine/target_profiles.py`, such as: + +- SOI multi-domain cells that would require joint AGI, filing status, and + positive income-tax-before-credits facts not currently published by the loaded + SOI packages +- survey-heavy or model-input cells such as rent, child support, + non-Part-B medical premium/expense components, SPM capped expenses, and + `ssn_card_type` +- source-near but non-equivalent rows such as `childcare_expenses`, where IRS + credit expenses and W-2 dependent-care benefits are narrower tax concepts +- pregnancy stock by state, where live births are a flow rather than a direct + source fact for the PolicyEngine target + +## Current Local Snapshot + +Snapshot date: 2026-05-22. 
+ +Inputs: + +- `/Users/maxghenis/CosilicoAI/arch/arch/fixtures/consumer_facts.jsonl` +- `/Users/maxghenis/CosilicoAI/arch/macro/targets.db` +- `/tmp/arch-suite-hhs-acf-tanf-caseload-2024/consumer_facts.jsonl` +- `/tmp/arch-suite-soi-historic-table-2-2022/consumer_facts.jsonl` +- `/tmp/arch-suite-hhs-acf-liheap-fy2024-national-profile/consumer_facts.jsonl` +- `/tmp/arch-suite-soi-historic-table-2-state-agi-2022/consumer_facts.jsonl` +- `/tmp/arch-suite-soi-w2-statistics-2020/consumer_facts.jsonl` +- `/tmp/arch-suite-soi-table-1-4-2023/consumer_facts.jsonl` +- `/tmp/arch-suite-federal-reserve-z1-household-net-worth/consumer_facts.jsonl` +- `/tmp/arch-suite-cms-medicare-trustees-report-2025-part-b-premium-income/consumer_facts.jsonl` + +Command: + +```bash +uv run --extra policyengine microplex-us-arch-target-refresh \ + --arch-targets-db /Users/maxghenis/CosilicoAI/arch/arch/fixtures/consumer_facts.jsonl \ + --arch-targets-db /Users/maxghenis/CosilicoAI/arch/macro/targets.db \ + --arch-targets-db /tmp/arch-suite-hhs-acf-tanf-caseload-2024/consumer_facts.jsonl \ + --arch-targets-db /tmp/arch-suite-soi-historic-table-2-2022/consumer_facts.jsonl \ + --arch-targets-db /tmp/arch-suite-hhs-acf-liheap-fy2024-national-profile/consumer_facts.jsonl \ + --arch-targets-db /tmp/arch-suite-soi-historic-table-2-state-agi-2022/consumer_facts.jsonl \ + --arch-targets-db /tmp/arch-suite-soi-w2-statistics-2020/consumer_facts.jsonl \ + --arch-targets-db /tmp/arch-suite-soi-table-1-4-2023/consumer_facts.jsonl \ + --arch-targets-db /tmp/arch-suite-federal-reserve-z1-household-net-worth/consumer_facts.jsonl \ + --arch-targets-db /tmp/arch-suite-cms-medicare-trustees-report-2025-part-b-premium-income/consumer_facts.jsonl \ + --period 2024 \ + --profile pe_native_broad_source_backed \ + --output-dir artifacts/arch-target-coverage-source-backed +``` + +Coverage: + +- 174 target cells in `pe_native_broad_source_backed` +- 174 covered +- 0 uncovered +- 100.0% coverage + +The raw 
`pe_native_broad` profile is at 174 of 189 covered with 15 explicitly +reviewed rows outside the source-backed boundary. Federal Reserve Z.1 household +net worth and CMS Medicare Trustees Report Part B premium income are now +source-backed. + +| Category | Rows | +| --- | ---: | +| `adapter_or_constraint_review` | 3 | +| `source_mapping_review` | 2 | +| `survey_or_model_input_deprioritized` | 10 | + +Generated outputs: + +- `artifacts/arch-target-coverage-source-backed/pe_native_broad_source_backed_2024_coverage.json` +- `artifacts/arch-target-coverage-source-backed/pe_native_broad_source_backed_2024_gaps.json` +- `artifacts/arch-target-coverage-source-backed/pe_native_broad_source_backed_2024_gaps.csv` +- `artifacts/arch-target-coverage-source-backed/pe_native_broad_source_backed_2024_summary.md` +- `artifacts/arch-target-coverage-broad-plus/pe_native_broad_2024_coverage.json` +- `artifacts/arch-target-coverage-broad-plus/pe_native_broad_2024_gaps.json` +- `artifacts/arch-target-coverage-broad-plus/pe_native_broad_2024_gaps.csv` +- `artifacts/arch-target-coverage-broad-plus/pe_native_broad_2024_summary.md` + +Remaining work is concentrated in: + +- the raw `pe_native_broad` cells excluded from the source-backed profile, if a + future primary publisher source can support them without changing semantics +- keeping the UK source-backed/raw boundary aligned with the same rule: leave + raw PE target rows visible, and exclude only rows where source equivalence is + not defensible diff --git a/pyproject.toml b/pyproject.toml index e17a836..d3b9b0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ authors = [ ] requires-python = ">=3.13" dependencies = [ - "microplex[calibrate]", + "microplex[calibrate] @ git+https://github.com/PolicyEngine/microplex.git@1e0627182f9df40aacd7043c96956c2895bf9d30", "duckdb>=1.2", "requests>=2.31", ] @@ -23,25 +23,43 @@ dev = [ "pytest>=7.0", "ruff>=0.1", ] +r2 = [ + "boto3>=1.34", +] policyengine = [ "microimpute==1.15.1 ; 
python_full_version >= '3.12' and python_full_version < '3.15'", "policyengine-us==1.587.0; python_version >= '3.11' and python_version < '3.15'", + "spm-calculator>=0.3.1", ] [project.urls] Repository = "https://github.com/PolicyEngine/microplex-us" [project.scripts] +microplex-us-arch-target-coverage = "microplex_us.targets.arch:main_coverage" +microplex-us-arch-target-gaps = "microplex_us.targets.arch:main_gaps" +microplex-us-arch-target-parity = "microplex_us.targets.arch:main_parity" +microplex-us-arch-target-refresh = "microplex_us.targets.arch:main_refresh" +microplex-us-arch-target-smoke = "microplex_us.targets.arch:main_smoke" +microplex-us-build-aca-ptc-multipliers = "microplex_us.targets.aca_ptc:main" microplex-us-backfill-pe-native-audit = "microplex_us.pipelines.backfill_pe_native_audit:main" microplex-us-backfill-pe-native-scores = "microplex_us.pipelines.backfill_pe_native_scores:main" microplex-us-check-site-snapshot = "microplex_us.pipelines.check_site_snapshot:main" +microplex-us-pe-dataset-readiness = "microplex_us.pipelines.pe_us_dataset_readiness:main" +microplex-us-dashboard = "microplex_us.pipelines.dashboard:main" +microplex-us-pe-native-calibration-benchmark = "microplex_us.pipelines.pe_native_calibration_benchmark:main" microplex-us-pe-native-target-diagnostics = "microplex_us.pipelines.pe_native_scores:main_target_diagnostics" +microplex-us-r2-archive-artifact = "microplex_us.pipelines.r2_artifacts:main" +microplex-us-reweight-cd-age-targets = "microplex_us.pipelines.cd_age_reweighting:main" microplex-us-score-pe-native-loss = "microplex_us.pipelines.pe_native_scores:main" microplex-us-version-bump-benchmark = "microplex_us.pipelines.version_benchmark:main" [tool.hatch.build.targets.wheel] packages = ["src/microplex_us"] +[tool.hatch.metadata] +allow-direct-references = true + [tool.hatch.build.targets.wheel.force-include] "src/microplex_us/pipelines/pe_native_scores.py" = "microplex_us/pipelines/pe_native_scores.py" @@ -65,6 +83,3 @@ 
ignore = [ [tool.ruff.lint.per-file-ignores] "examples/**/*.py" = ["E402"] "tests/**/*.py" = ["E402", "N802"] - -[tool.uv.sources] -microplex = { path = "../microplex", editable = true } diff --git a/src/microplex_us/geography.py b/src/microplex_us/geography.py index 056ea0b..6e89ee2 100644 --- a/src/microplex_us/geography.py +++ b/src/microplex_us/geography.py @@ -32,15 +32,86 @@ PACKAGE_ROOT.parent / "microplex" / "data", ) DEFAULT_DATA_DIR = next( - (candidate for candidate in DEFAULT_DATA_DIR_CANDIDATES if candidate.exists()), - DEFAULT_DATA_DIR_CANDIDATES[0], + ( + candidate + for candidate in DEFAULT_DATA_DIR_CANDIDATES + if (candidate / "block_probabilities.parquet").exists() + ), + next( + (candidate for candidate in DEFAULT_DATA_DIR_CANDIDATES if candidate.exists()), + DEFAULT_DATA_DIR_CANDIDATES[0], + ), ) DEFAULT_BLOCK_PROBABILITIES_PATH = DEFAULT_DATA_DIR / "block_probabilities.parquet" +DEFAULT_BLOCK_PROBABILITIES_SPM_GEOGRAPHY_PATH = ( + DEFAULT_DATA_DIR / "block_probabilities_spm_geography.parquet" +) +SPM_METRO_AREA_COLUMN = "spm_metro_area" +CENSUS_CBSA_DELINEATION_URL = ( + "https://www2.census.gov/programs-surveys/metro-micro/geographies/" + "reference-files/2020/delineation-files/list1_2020.xls" +) +US_STATE_ABBR_BY_FIPS = { + "01": "AL", + "02": "AK", + "04": "AZ", + "05": "AR", + "06": "CA", + "08": "CO", + "09": "CT", + "10": "DE", + "11": "DC", + "12": "FL", + "13": "GA", + "15": "HI", + "16": "ID", + "17": "IL", + "18": "IN", + "19": "IA", + "20": "KS", + "21": "KY", + "22": "LA", + "23": "ME", + "24": "MD", + "25": "MA", + "26": "MI", + "27": "MN", + "28": "MS", + "29": "MO", + "30": "MT", + "31": "NE", + "32": "NV", + "33": "NH", + "34": "NJ", + "35": "NM", + "36": "NY", + "37": "NC", + "38": "ND", + "39": "OH", + "40": "OK", + "41": "OR", + "42": "PA", + "44": "RI", + "45": "SC", + "46": "SD", + "47": "TN", + "48": "TX", + "49": "UT", + "50": "VT", + "51": "VA", + "53": "WA", + "54": "WV", + "55": "WI", + "56": "WY", + "72": "PR", +} def 
load_block_probabilities(path: str | Path | None = None) -> pd.DataFrame: """Load US Census block probabilities from parquet.""" - path = DEFAULT_BLOCK_PROBABILITIES_PATH if path is None else Path(path) + path = default_runtime_block_probabilities_path() if path is None else Path(path) + if path is None: + path = DEFAULT_BLOCK_PROBABILITIES_PATH if not path.exists(): raise FileNotFoundError( f"Block probabilities file not found at {path}.\n" @@ -49,16 +120,292 @@ def load_block_probabilities(path: str | Path | None = None) -> pd.DataFrame: return pd.read_parquet(path) +def default_runtime_block_probabilities_path() -> Path | None: + """Return the preferred runtime block probabilities file when available.""" + if DEFAULT_BLOCK_PROBABILITIES_SPM_GEOGRAPHY_PATH.exists(): + return DEFAULT_BLOCK_PROBABILITIES_SPM_GEOGRAPHY_PATH + if DEFAULT_BLOCK_PROBABILITIES_PATH.exists(): + return DEFAULT_BLOCK_PROBABILITIES_PATH + return None + + def normalize_us_state_fips(value: Any) -> str: """Normalize US state FIPS values to two-character strings.""" return str(int(round(float(value)))).zfill(2) +def _normalize_us_state_fips_series(values: pd.Series) -> pd.Series: + numeric = pd.to_numeric(values, errors="coerce") + normalized = numeric.round().astype("Int64").astype("string").str.zfill(2) + return normalized.mask(numeric.isna()) + + +def normalize_us_county_fips(value: Any) -> str | None: + """Normalize US county FIPS values to five-character strings.""" + try: + if pd.isna(value): + return None + text = str(value).strip() + if not text: + return None + numeric = pd.to_numeric(pd.Series([text]), errors="coerce").iloc[0] + if pd.notna(numeric): + return str(int(round(float(numeric)))).zfill(COUNTY_GEOID_LEN) + digits = "".join(character for character in text if character.isdigit()) + return digits.zfill(COUNTY_GEOID_LEN) if digits else None + except (TypeError, ValueError, OverflowError): + return None + + +def normalize_state_legislative_district_id( + value: Any, + *, + 
chamber: str | None = None, +) -> str | None: + """Normalize state legislative district IDs to STATE-SLDU/SLDL-NNN.""" + if value is None: + return None + try: + if pd.isna(value): + return None + except (TypeError, ValueError): + pass + raw = str(value).strip() + if not raw: + return None + + labeled = _normalize_labeled_state_legislative_district_id(raw) + if labeled is not None: + return labeled + + if raw.startswith("610U900US"): + raw = raw[-5:] + chamber = "upper" + elif raw.startswith("620L900US"): + raw = raw[-5:] + chamber = "lower" + + if raw.isdigit() and len(raw) >= 5 and chamber in {"upper", "lower"}: + state_fips = raw[:2] + district_code = raw[2:] + state_abbr = US_STATE_ABBR_BY_FIPS.get(state_fips, state_fips) + chamber_label = "SLDU" if chamber == "upper" else "SLDL" + return f"{state_abbr}-{chamber_label}-{_normalize_sld_district(district_code)}" + + return raw + + +def _normalize_labeled_state_legislative_district_id(raw: str) -> str | None: + parts = raw.split("-") + if len(parts) != 3: + return None + state, chamber_label, district_code = (part.strip() for part in parts) + if not state or not chamber_label or not district_code: + return None + canonical_chamber_label = { + "SLDU": "SLDU", + "SD": "SLDU", + "SLDL": "SLDL", + "HD": "SLDL", + "AD": "SLDL", + }.get(chamber_label.upper()) + if canonical_chamber_label is None: + return None + return ( + f"{state.upper()}-{canonical_chamber_label}-" + f"{_normalize_sld_district(district_code)}" + ) + + +def _normalize_sld_district(value: Any) -> str: + text = str(value).strip() + if text.endswith(".0") and text[:-2].isdigit(): + text = text[:-2] + return f"{int(text):03d}" if text.isdigit() else text + + +@lru_cache(maxsize=8) +def _spm_metro_area_codes(year: int) -> frozenset[str]: + try: + from spm_calculator.geoadj import list_metro_areas + except ImportError: + return frozenset() + return frozenset(str(area["code"]) for area in list_metro_areas(year)) + + +def state_nonmetro_spm_area_code( + 
state_fips: Any, + *, + year: int = 2024, +) -> str | None: + """Return the Census SPM state-nonmetro area code when one exists.""" + return _state_spm_area_code(state_fips, suffix=2, year=year) + + +def _state_spm_area_code( + state_fips: Any, + *, + suffix: int, + year: int, +) -> str | None: + try: + normalized_state = normalize_us_state_fips(state_fips) + except (TypeError, ValueError, OverflowError): + return None + if normalized_state == "00": + return None + code = str(int(normalized_state) * 1_000 + suffix) + return code if code in _spm_metro_area_codes(year) else None + + +def _state_spm_area_code_series( + values: pd.Series, + *, + suffix: int, + year: int, +) -> pd.Series: + states = _normalize_us_state_fips_series(values) + numeric_states = pd.to_numeric(states, errors="coerce") + codes = (numeric_states * 1_000 + suffix).round().astype("Int64").astype("string") + valid_codes = _spm_metro_area_codes(year) + return codes.where( + states.notna() & states.ne("00") & codes.isin(valid_codes) + ).astype("string") + + +@lru_cache(maxsize=1) +def _census_cbsa_crosswalk() -> dict[str, str] | None: + """Load county-to-CBSA crosswalk from Census' official delineation file.""" + try: + delineation = pd.read_excel( + CENSUS_CBSA_DELINEATION_URL, + header=2, + dtype=str, + ) + except Exception: + return None + + required_columns = {"CBSA Code", "FIPS State Code", "FIPS County Code"} + if not required_columns.issubset(delineation.columns): + return None + + rows = delineation.dropna( + subset=["CBSA Code", "FIPS State Code", "FIPS County Code"] + ).copy() + rows["county_fips"] = rows["FIPS State Code"].astype(str).str.zfill(2) + rows[ + "FIPS County Code" + ].astype(str).str.zfill(3) + rows["cbsa_code"] = rows["CBSA Code"].astype(str).str.strip() + return dict(zip(rows["county_fips"], rows["cbsa_code"], strict=False)) + + +def _normalize_census_area_code(value: Any) -> str | None: + if value is None or pd.isna(value): + return None + text = str(value).strip() + if not 
text or text.lower() in {"nan", "none", ""} or text in {"0", "00000"}: + return None + numeric = pd.to_numeric(pd.Series([text]), errors="coerce").iloc[0] + return str(int(round(float(numeric)))) if pd.notna(numeric) else text + + +def _normalize_census_area_code_series(values: pd.Series) -> pd.Series: + result = values.astype("string").str.strip() + invalid = ( + result.isna() + | result.str.lower().isin({"", "nan", "none", ""}) + | result.isin({"0", "00000"}) + ) + numeric = pd.to_numeric(result, errors="coerce") + numeric_codes = numeric.round().astype("Int64").astype("string") + result = result.mask(numeric.notna(), numeric_codes) + return result.mask(invalid).astype("string") + + +def _normalize_spm_area_code(value: Any, *, year: int) -> str | None: + code = _normalize_census_area_code(value) + if code is None: + return None + return code if code in _spm_metro_area_codes(year) else None + + +def add_spm_metro_area_geography( + frame: pd.DataFrame, + *, + year: int = 2024, + county_column: str = "county_fips", + state_column: str = "state_fips", + cbsa_column: str = "cbsa_code", + spm_metro_area_column: str = SPM_METRO_AREA_COLUMN, + derive_cbsa_from_primary_source: bool = True, +) -> pd.DataFrame: + """Attach Census SPM metro/nonmetro area IDs from block-derived geography. + + The final SPM threshold geography is a Census metropolitan area code when + SPM publishes one; otherwise it is a state metro/nonmetro SPM code. We only + classify blank, micropolitan, or unsupported CBSA values into a state area + when they came from a trusted CBSA source: either an existing CBSA column or + the Census delineation input. 
+ """ + if frame.empty: + result = frame.copy() + if spm_metro_area_column not in result.columns: + result[spm_metro_area_column] = pd.Series(dtype="string") + return result + + result = frame.copy() + trusted_cbsa_source = cbsa_column in result.columns + if cbsa_column in result.columns: + cbsa_values = _normalize_census_area_code_series(result[cbsa_column]) + else: + cbsa_values = pd.Series(pd.NA, index=result.index, dtype="string") + + if ( + derive_cbsa_from_primary_source + and county_column in result.columns + and cbsa_values.isna().any() + ): + cbsa_crosswalk = _census_cbsa_crosswalk() + if cbsa_crosswalk is not None: + trusted_cbsa_source = True + county_values = result[county_column].map(normalize_us_county_fips) + derived_cbsa = county_values.map(cbsa_crosswalk) + cbsa_values = cbsa_values.combine_first( + _normalize_census_area_code_series(derived_cbsa) + ) + + result[cbsa_column] = cbsa_values.astype("string") + + spm_area = cbsa_values.where(cbsa_values.isin(_spm_metro_area_codes(year))).astype( + "string" + ) + if trusted_cbsa_source and state_column in result.columns: + nonmetro_codes = _state_spm_area_code_series( + result[state_column], + suffix=2, + year=year, + ) + state_metro_codes = _state_spm_area_code_series( + result[state_column], + suffix=1, + year=year, + ) + state_fallback_codes = nonmetro_codes.combine_first( + state_metro_codes.where(cbsa_values.notna()) + ) + spm_area = spm_area.combine_first(state_fallback_codes) + + result[spm_metro_area_column] = spm_area + return result + + def derive_geographies( block_geoids: list[str] | np.ndarray | pd.Series, include_cd: bool = False, include_sld: bool = False, + include_spm_metro_area: bool = False, block_data: pd.DataFrame | None = None, + year: int = 2024, ) -> pd.DataFrame: """Derive parent geographies from Census block GEOIDs.""" geoids = pd.Series(block_geoids).astype(str) @@ -73,12 +420,20 @@ def derive_geographies( if include_cd or include_sld: block_data = load_block_probabilities() 
def _get_spm_metro_area(self, block_geoid: str) -> str | None:
    """Return the SPM metro-area code for one block GEOID, or None.

    If ``self.data`` has the SPM metro-area column and a row for the block,
    that stored value is returned (None when it is missing). Otherwise the
    code is derived with ``add_spm_metro_area_geography`` from the block's
    state and county, plus the stored ``cbsa_code`` when that column and a
    matching row exist. The Census crosswalk is not consulted on this path
    (``derive_cbsa_from_primary_source=False``).

    Args:
        block_geoid: Block GEOID, compared to ``self.data["geoid"]`` as text.

    Returns:
        The area code as a string, or None when no value is available.
    """
    columns = self.data.columns
    has_spm_column = SPM_METRO_AREA_COLUMN in columns
    has_cbsa_column = "cbsa_code" in columns

    # Filter the table once; both branches below need the same matching rows.
    match = None
    if has_spm_column or has_cbsa_column:
        match = self.data.loc[self.data["geoid"].astype(str).eq(block_geoid)]

    if has_spm_column and not match.empty:
        value = match[SPM_METRO_AREA_COLUMN].iloc[0]
        return None if pd.isna(value) else str(value)

    rows = {
        "state_fips": [self.get_state(block_geoid)],
        "county_fips": [self.get_county(block_geoid)],
    }
    if has_cbsa_column and not match.empty:
        rows["cbsa_code"] = [match["cbsa_code"].iloc[0]]
    value = add_spm_metro_area_geography(
        pd.DataFrame(rows),
        derive_cbsa_from_primary_source=False,
    )[SPM_METRO_AREA_COLUMN].iloc[0]
    return None if pd.isna(value) else str(value)
"county"}.issubset(crosswalk.columns): - crosswalk["county_fips"] = ( - crosswalk["state_fips"].astype(str) + crosswalk["county"].astype(str) - ) - if "tract_geoid" not in crosswalk.columns and {"state_fips", "county", "tract"}.issubset(crosswalk.columns): + if "county_fips" not in crosswalk.columns and {"state_fips", "county"}.issubset( + crosswalk.columns + ): + crosswalk["county_fips"] = crosswalk["state_fips"].astype(str) + crosswalk[ + "county" + ].astype(str) + if "tract_geoid" not in crosswalk.columns and { + "state_fips", + "county", + "tract", + }.issubset(crosswalk.columns): crosswalk["tract_geoid"] = ( crosswalk["state_fips"].astype(str) + crosswalk["county"].astype(str) @@ -182,7 +565,16 @@ def to_crosswalk(self) -> AtomicGeographyCrosswalk: ) geography_columns = tuple( column - for column in ("state_fips", "county_fips", "tract_geoid", "cd_id", "sldu_id", "sldl_id") + for column in ( + "state_fips", + "county_fips", + "tract_geoid", + "cd_id", + "sldu_id", + "sldl_id", + "cbsa_code", + SPM_METRO_AREA_COLUMN, + ) if column in crosswalk.columns ) return AtomicGeographyCrosswalk( @@ -192,7 +584,9 @@ def to_crosswalk(self) -> AtomicGeographyCrosswalk: probability_column="prob" if "prob" in crosswalk.columns else None, ) - def load_crosswalk(self, query: GeographyQuery | None = None) -> AtomicGeographyCrosswalk: + def load_crosswalk( + self, query: GeographyQuery | None = None + ) -> AtomicGeographyCrosswalk: query = query or GeographyQuery() crosswalk = self.to_crosswalk() if not query.geography_columns and query.probability_column is None: @@ -200,7 +594,8 @@ def load_crosswalk(self, query: GeographyQuery | None = None) -> AtomicGeography return AtomicGeographyCrosswalk( data=crosswalk.data.copy(), atomic_id_column=crosswalk.atomic_id_column, - geography_columns=tuple(query.geography_columns) or crosswalk.geography_columns, + geography_columns=tuple(query.geography_columns) + or crosswalk.geography_columns, probability_column=query.probability_column or 
crosswalk.probability_column, ) @@ -353,8 +748,16 @@ def __repr__(self) -> str: "BLOCK_GEOID_LEN", "DEFAULT_DATA_DIR", "DEFAULT_BLOCK_PROBABILITIES_PATH", + "DEFAULT_BLOCK_PROBABILITIES_SPM_GEOGRAPHY_PATH", + "CENSUS_CBSA_DELINEATION_URL", + "SPM_METRO_AREA_COLUMN", + "default_runtime_block_probabilities_path", "load_block_probabilities", "normalize_us_state_fips", + "normalize_us_county_fips", + "normalize_state_legislative_district_id", + "state_nonmetro_spm_area_code", + "add_spm_metro_area_geography", "derive_geographies", "BlockGeography", ] diff --git a/src/microplex_us/microdata_roles.py b/src/microplex_us/microdata_roles.py new file mode 100644 index 0000000..99774e4 --- /dev/null +++ b/src/microplex_us/microdata_roles.py @@ -0,0 +1,198 @@ +"""Source-specific microdata variable role metadata. + +This is the Microplex-side bridge to the richer Arch source-data contract: +Arch preserves what a source says, while Microplex decides which source columns +are model inputs versus source-reported outputs or diagnostics. 
+""" + +from __future__ import annotations + +from enum import Enum + + +class MicrodataVariableRole(Enum): + """How Microplex should treat one source-native microdata variable.""" + + SOURCE_INPUT = "source_input" + REPORTED_RETURN_LINE_INPUT = "reported_return_line_input" + CALCULATED_TAX_OUTPUT = "calculated_tax_output" + + +class PolicyEngineUSVariableRole(Enum): + """How Microplex should treat a PolicyEngine US variable at export time.""" + + PRESERVED_INPUT = "preserved_input" + TAKEUP_INPUT = "takeup_input" + REPORTED_OUTPUT = "reported_output" + CALCULATED_OUTPUT = "calculated_output" + + +PUF_CALCULATED_TAX_OUTPUT_VARIABLES: frozenset[str] = frozenset( + { + "american_opportunity_credit", + "amt_foreign_tax_credit", + "early_withdrawal_penalty", + "energy_efficient_home_improvement_credit", + "excess_withheld_payroll_tax", + "foreign_tax_credit", + "general_business_credit", + "other_credits", + "prior_year_minimum_tax_credit", + "recapture_of_investment_credit", + "savers_credit", + "state_and_local_sales_or_income_tax", + "state_income_tax_paid", + "taxable_social_security", + "taxable_unemployment_compensation", + "unreported_payroll_tax", + } +) + +POLICYENGINE_US_TAKEUP_INPUT_VARIABLES: frozenset[str] = frozenset( + { + "takes_up_aca_if_eligible", + "takes_up_chip_if_eligible", + "takes_up_early_head_start_if_eligible", + "takes_up_eitc", + "takes_up_head_start_if_eligible", + "takes_up_medicaid_if_eligible", + "takes_up_snap_if_eligible", + "takes_up_ssi_if_eligible", + "takes_up_tanf_if_eligible", + "would_claim_wic", + "would_file_if_eligible_for_refundable_credit", + "would_file_taxes_voluntarily", + } +) + +POLICYENGINE_US_REPORTED_BENEFIT_AMOUNT_VARIABLES: frozenset[str] = frozenset( + { + "snap_reported", + "ssi_reported", + "tanf_reported", + } +) + +POLICYENGINE_US_REPORTED_TAX_OUTPUT_VARIABLES: frozenset[str] = frozenset( + PUF_CALCULATED_TAX_OUTPUT_VARIABLES + | { + "state_income_tax_reported", + } +) + 
+POLICYENGINE_US_REPORTED_OUTPUT_VARIABLES: frozenset[str] = frozenset( + POLICYENGINE_US_REPORTED_BENEFIT_AMOUNT_VARIABLES + | POLICYENGINE_US_REPORTED_TAX_OUTPUT_VARIABLES +) + +POLICYENGINE_US_CALCULATED_OUTPUT_VARIABLES: frozenset[str] = frozenset( + { + "aca_ptc", + "additional_ctc", + "assigned_aca_ptc", + "filing_status", + "loss_limited_net_capital_gains", + "net_capital_gains", + "chip_enrolled", + "ctc", + "early_head_start", + "eitc", + "head_start", + "income_tax", + "income_tax_positive", + "is_aca_ptc_eligible", + "medicaid", + "medicaid_cost", + "medicaid_enrolled", + "non_refundable_ctc", + "premium_tax_credit", + "refundable_ctc", + "snap", + "ssi", + "state_income_tax", + "tanf", + "total_income_tax", + "wic", + } +) + +POLICYENGINE_US_DIRECT_EXPORT_BLOCKED_VARIABLES: frozenset[str] = frozenset( + POLICYENGINE_US_CALCULATED_OUTPUT_VARIABLES + | POLICYENGINE_US_REPORTED_OUTPUT_VARIABLES +) + + +def source_name_matches_prefix(source_name: str, prefix: str) -> bool: + """Return whether a source name is an exact or year-suffixed source prefix.""" + return source_name == prefix or source_name.startswith(f"{prefix}_") + + +def microdata_variable_role( + source_name: str, + variable_name: str, +) -> MicrodataVariableRole: + """Resolve the source-specific role for one microdata variable.""" + if ( + source_name_matches_prefix(source_name, "irs_soi_puf") + and variable_name in PUF_CALCULATED_TAX_OUTPUT_VARIABLES + ): + return MicrodataVariableRole.CALCULATED_TAX_OUTPUT + return MicrodataVariableRole.SOURCE_INPUT + + +def is_model_input_microdata_variable( + source_name: str, + variable_name: str, +) -> bool: + """Return whether a source column should enter model-ready microdata.""" + return microdata_variable_role( + source_name, + variable_name, + ) is not MicrodataVariableRole.CALCULATED_TAX_OUTPUT + + +def non_model_input_microdata_variables( + source_name: str, + variable_names: list[str] | tuple[str, ...] 
| set[str] | frozenset[str], +) -> tuple[str, ...]: + """Return source columns that should stay out of model-ready microdata.""" + return tuple( + variable_name + for variable_name in variable_names + if not is_model_input_microdata_variable(source_name, variable_name) + ) + + +def policyengine_us_variable_role(variable_name: str) -> PolicyEngineUSVariableRole: + """Resolve the Microplex role for a PolicyEngine US variable name.""" + if variable_name in POLICYENGINE_US_CALCULATED_OUTPUT_VARIABLES: + return PolicyEngineUSVariableRole.CALCULATED_OUTPUT + if variable_name in POLICYENGINE_US_REPORTED_OUTPUT_VARIABLES: + return PolicyEngineUSVariableRole.REPORTED_OUTPUT + if variable_name in POLICYENGINE_US_TAKEUP_INPUT_VARIABLES: + return PolicyEngineUSVariableRole.TAKEUP_INPUT + return PolicyEngineUSVariableRole.PRESERVED_INPUT + + +def is_policyengine_us_direct_export_blocked(variable_name: str) -> bool: + """Return whether a source column may not override a PE-US variable.""" + return ( + policyengine_us_variable_role(variable_name) + in { + PolicyEngineUSVariableRole.CALCULATED_OUTPUT, + PolicyEngineUSVariableRole.REPORTED_OUTPUT, + } + ) + + +def blocked_policyengine_us_direct_export_variables( + variable_names: list[str] | tuple[str, ...] 
| set[str] | frozenset[str], +) -> tuple[str, ...]: + """Return requested direct overrides that violate the variable contract.""" + return tuple( + sorted( + variable_name + for variable_name in variable_names + if is_policyengine_us_direct_export_blocked(variable_name) + ) + ) diff --git a/src/microplex_us/pipelines/cd_age_reweighting.py b/src/microplex_us/pipelines/cd_age_reweighting.py new file mode 100644 index 0000000..45c9681 --- /dev/null +++ b/src/microplex_us/pipelines/cd_age_reweighting.py @@ -0,0 +1,569 @@ +"""Reweight PE-US H5 datasets to congressional-district age targets.""" + +from __future__ import annotations + +import argparse +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import h5py +import numpy as np +import pandas as pd +from scipy.optimize import minimize + +from microplex_us.pipelines.pe_native_optimization import ( + rewrite_policyengine_us_dataset_weights, +) +from microplex_us.policyengine import PolicyEngineUSDBTargetProvider +from microplex_us.policyengine.us import PolicyEngineUSConstraint + + +@dataclass(frozen=True) +class CDAgeTarget: + """One congressional-district person-count-by-age target.""" + + target_id: int + district_geoid: int + value: float + age_constraints: tuple[PolicyEngineUSConstraint, ...] 
+ period: int + + @property + def age_key(self) -> tuple[tuple[str, str], ...]: + return tuple( + sorted((constraint.operation, str(constraint.value)) for constraint in self.age_constraints) + ) + + +def normalize_at_large_cd_geoids(values: np.ndarray) -> np.ndarray: + """Normalize statewide at-large districts from ``xx00`` to PE target ``xx01``.""" + result = np.asarray(values).copy() + finite = np.isfinite(result.astype(float, copy=False)) + as_int = result.astype(np.int64, copy=False) + at_large = finite & (as_int > 0) & (as_int % 100 == 0) + result[at_large] = as_int[at_large] + 1 + return result.astype(np.int64, copy=False) + + +def load_cd_age_targets( + target_db: str | Path, + *, + period: int = 2024, +) -> list[CDAgeTarget]: + """Load active district person-count-by-age targets from PE's target DB.""" + provider = PolicyEngineUSDBTargetProvider(target_db) + raw_targets = provider.load_targets( + period=period, + variables=["person_count"], + domain_variables=["age"], + geo_levels=["district"], + active_only=True, + ) + targets: list[CDAgeTarget] = [] + for target in raw_targets: + district_constraints = [ + constraint + for constraint in target.constraints + if constraint.variable == "congressional_district_geoid" + ] + age_constraints = tuple( + constraint for constraint in target.constraints if constraint.variable == "age" + ) + if len(district_constraints) != 1 or not age_constraints: + continue + targets.append( + CDAgeTarget( + target_id=int(target.target_id), + district_geoid=int(district_constraints[0].value), + value=float(target.value), + age_constraints=age_constraints, + period=int(target.period), + ) + ) + targets.sort(key=lambda target: (target.district_geoid, target.age_key, target.target_id)) + return targets + + +def reweight_h5_to_cd_age_targets( + *, + input_dataset: str | Path, + target_db: str | Path, + output_dataset: str | Path, + period: int = 2024, + max_iter: int = 300, + tol: float = 1e-9, + preserve_district_weight_sum: bool = 
True, + details_output: str | Path | None = None, +) -> dict[str, Any]: + """Apply independent per-CD entropy reweighting for age-distribution targets.""" + period_key = str(period) + targets = load_cd_age_targets(target_db, period=period) + if not targets: + raise ValueError("No district person_count-by-age targets were loaded") + + with h5py.File(input_dataset, "r") as handle: + household_ids = np.asarray(handle["household_id"][period_key]) + input_weights = np.asarray(handle["household_weight"][period_key], dtype=np.float64) + household_cd = normalize_at_large_cd_geoids( + np.asarray(handle["congressional_district_geoid"][period_key]) + ) + person_household_id = np.asarray(handle["person_household_id"][period_key]) + age = np.asarray(handle["age"][period_key], dtype=np.float64) + + person_household_index = _map_person_households_to_indices( + household_ids, + person_household_id, + ) + unique_age_keys = sorted({target.age_key for target in targets}) + household_age_counts = _build_household_age_count_matrix( + n_households=len(household_ids), + person_household_index=person_household_index, + age=age, + age_keys=unique_age_keys, + ) + age_key_to_col = {age_key: index for index, age_key in enumerate(unique_age_keys)} + + output_weights = input_weights.copy() + detail_rows: list[dict[str, Any]] = [] + district_failures: list[dict[str, Any]] = [] + targets_by_district: dict[int, list[CDAgeTarget]] = {} + for target in targets: + targets_by_district.setdefault(target.district_geoid, []).append(target) + + for district_geoid, district_targets in sorted(targets_by_district.items()): + household_mask = household_cd == district_geoid + household_indices = np.flatnonzero(household_mask) + if len(household_indices) == 0: + district_failures.append( + { + "district_geoid": district_geoid, + "reason": "no_households", + "target_count": len(district_targets), + } + ) + _append_detail_rows( + detail_rows, + targets=district_targets, + age_key_to_col=age_key_to_col, + 
household_indices=household_indices, + household_age_counts=household_age_counts, + input_weights=input_weights, + output_weights=output_weights, + status="no_households", + ) + continue + + row_cols = [age_key_to_col[target.age_key] for target in district_targets] + design = household_age_counts[np.ix_(household_indices, row_cols)].T.astype( + np.float64, + copy=False, + ) + target_values = np.asarray([target.value for target in district_targets], dtype=np.float64) + base_weights = input_weights[household_indices] + fit_design = design + fit_targets = target_values + if preserve_district_weight_sum: + fit_design = np.vstack( + [ + design, + np.ones((1, design.shape[1]), dtype=np.float64), + ] + ) + fit_targets = np.concatenate( + [target_values, np.asarray([base_weights.sum()], dtype=np.float64)] + ) + solution = _solve_entropy_weights( + design=fit_design, + base_weights=base_weights, + targets=fit_targets, + max_iter=max_iter, + tol=tol, + ) + output_weights[household_indices] = solution["weights"] + if not solution["success"]: + district_failures.append( + { + "district_geoid": district_geoid, + "reason": solution["message"], + "target_count": len(district_targets), + "max_abs_relative_error": solution["max_abs_relative_error"], + } + ) + _append_detail_rows( + detail_rows, + targets=district_targets, + age_key_to_col=age_key_to_col, + household_indices=household_indices, + household_age_counts=household_age_counts, + input_weights=input_weights, + output_weights=output_weights, + status="ok" if solution["success"] else "not_converged", + ) + + output_path = rewrite_policyengine_us_dataset_weights( + input_dataset_path=input_dataset, + output_dataset_path=output_dataset, + household_weights=output_weights, + period=period, + ) + _normalize_cd_geoids_in_h5(output_path, period=period) + + detail_frame = pd.DataFrame(detail_rows) + if details_output is not None: + detail_path = Path(details_output).expanduser().resolve() + detail_path.parent.mkdir(parents=True, 
exist_ok=True) + detail_frame.to_csv(detail_path, index=False) + + summary = _summarize_detail_frame( + detail_frame, + input_weight_sum=float(input_weights.sum()), + output_weight_sum=float(output_weights.sum()), + n_households=len(input_weights), + n_persons=len(age), + n_age_bins=len(unique_age_keys), + district_failures=district_failures, + ) + summary["preserve_district_weight_sum"] = bool(preserve_district_weight_sum) + summary["input_dataset"] = str(Path(input_dataset).expanduser().resolve()) + summary["output_dataset"] = str(Path(output_path).expanduser().resolve()) + summary["target_db"] = str(Path(target_db).expanduser().resolve()) + summary["period"] = int(period) + return summary + + +def build_cd_age_constraint_matrix( + *, + input_dataset: str | Path, + target_db: str | Path, + period: int = 2024, + target_weight: float = 1.0, +) -> dict[str, Any]: + """Build scaled sparse rows for CD person-count-by-age targets. + + The returned matrix has shape ``(targets, households)`` and uses the same + ``((estimate - target + 1) / (target + 1)) ** 2`` row scaling convention as + the PE-native broad matrix. 
+ """ + if target_weight <= 0: + raise ValueError("target_weight must be positive") + period_key = str(period) + targets = load_cd_age_targets(target_db, period=period) + if not targets: + raise ValueError("No district person_count-by-age targets were loaded") + + with h5py.File(input_dataset, "r") as handle: + household_ids = np.asarray(handle["household_id"][period_key]) + household_cd = normalize_at_large_cd_geoids( + np.asarray(handle["congressional_district_geoid"][period_key]) + ) + person_household_id = np.asarray(handle["person_household_id"][period_key]) + age = np.asarray(handle["age"][period_key], dtype=np.float64) + + person_household_index = _map_person_households_to_indices( + household_ids, + person_household_id, + ) + unique_age_keys = sorted({target.age_key for target in targets}) + household_age_counts = _build_household_age_count_matrix( + n_households=len(household_ids), + person_household_index=person_household_index, + age=age, + age_keys=unique_age_keys, + ) + age_key_to_col = {age_key: index for index, age_key in enumerate(unique_age_keys)} + + rows: list[np.ndarray] = [] + cols: list[np.ndarray] = [] + vals: list[np.ndarray] = [] + target_values = np.asarray([target.value for target in targets], dtype=np.float64) + scaling = np.sqrt(float(target_weight) / float(len(targets))) / ( + target_values + 1.0 + ) + target_names: list[str] = [] + for row_index, target in enumerate(targets): + count_col = age_key_to_col[target.age_key] + household_indices = np.flatnonzero(household_cd == target.district_geoid) + counts = household_age_counts[household_indices, count_col] + nonzero = counts != 0 + if nonzero.any(): + rows.append(np.full(int(nonzero.sum()), row_index, dtype=np.int32)) + cols.append(household_indices[nonzero].astype(np.int32)) + vals.append((counts[nonzero] * scaling[row_index]).astype(np.float32)) + target_names.append( + "district/census/person_count_by_age/" + f"{target.district_geoid}/{json.dumps(target.age_key, separators=(',', 
':'))}" + ) + + if rows: + import scipy.sparse as sp + + matrix = sp.csr_matrix( + ( + np.concatenate(vals), + (np.concatenate(rows), np.concatenate(cols)), + ), + shape=(len(targets), len(household_ids)), + dtype=np.float32, + ) + else: + import scipy.sparse as sp + + matrix = sp.csr_matrix((len(targets), len(household_ids)), dtype=np.float32) + + scaled_target = ((target_values - 1.0) * scaling).astype(np.float32) + return { + "matrix": matrix, + "target": scaled_target, + "metadata": { + "target_names": target_names, + "n_targets_total": int(len(targets)), + "n_targets_kept": int(len(targets)), + "n_districts": int(len({target.district_geoid for target in targets})), + "n_age_bins": int(len(unique_age_keys)), + "target_weight": float(target_weight), + "target_db": str(Path(target_db).expanduser().resolve()), + "family": "district_age_distribution", + }, + } + + +def _map_person_households_to_indices( + household_ids: np.ndarray, + person_household_ids: np.ndarray, +) -> np.ndarray: + household_index = {int(household_id): index for index, household_id in enumerate(household_ids)} + try: + return np.asarray( + [household_index[int(household_id)] for household_id in person_household_ids], + dtype=np.int64, + ) + except KeyError as exc: + raise ValueError(f"person_household_id references missing household_id {exc}") from exc + + +def _build_household_age_count_matrix( + *, + n_households: int, + person_household_index: np.ndarray, + age: np.ndarray, + age_keys: list[tuple[tuple[str, str], ...]], +) -> np.ndarray: + counts = np.zeros((n_households, len(age_keys)), dtype=np.float32) + for col, age_key in enumerate(age_keys): + mask = _evaluate_age_key(age, age_key) + np.add.at(counts[:, col], person_household_index[mask], 1.0) + return counts + + +def _evaluate_age_key( + age: np.ndarray, + age_key: tuple[tuple[str, str], ...], +) -> np.ndarray: + mask = np.ones(len(age), dtype=bool) + for operation, raw_value in age_key: + value = float(raw_value) + if operation == 
"==": + mask &= age == value + elif operation == "!=": + mask &= age != value + elif operation == ">": + mask &= age > value + elif operation == ">=": + mask &= age >= value + elif operation == "<": + mask &= age < value + elif operation == "<=": + mask &= age <= value + else: + raise ValueError(f"Unsupported age target operation: {operation!r}") + return mask + + +def _solve_entropy_weights( + *, + design: np.ndarray, + base_weights: np.ndarray, + targets: np.ndarray, + max_iter: int, + tol: float, +) -> dict[str, Any]: + support = design.sum(axis=1) > 0 + unsupported = (~support) & (np.abs(targets) > tol) + if unsupported.any(): + estimates = design @ base_weights + return { + "weights": base_weights.copy(), + "success": False, + "message": "unsupported_positive_targets", + "max_abs_relative_error": float( + _abs_relative_error(estimates, targets).max(initial=0.0) + ), + } + + def objective(lam: np.ndarray) -> tuple[float, np.ndarray]: + linear_predictor = np.clip(lam @ design, -50.0, 50.0) + weights = base_weights * np.exp(linear_predictor) + value = float(weights.sum() - np.dot(targets, lam)) + gradient = design @ weights - targets + return value, gradient + + result = minimize( + fun=lambda lam: objective(lam)[0], + x0=np.zeros(design.shape[0], dtype=np.float64), + jac=lambda lam: objective(lam)[1], + method="L-BFGS-B", + options={"maxiter": int(max_iter), "ftol": tol, "gtol": tol}, + ) + linear_predictor = np.clip(result.x @ design, -50.0, 50.0) + weights = base_weights * np.exp(linear_predictor) + estimates = design @ weights + max_error = float(_abs_relative_error(estimates, targets).max(initial=0.0)) + success = bool(result.success) or max_error <= max(1e-4, tol * 100) + return { + "weights": weights, + "success": success, + "message": str(result.message), + "max_abs_relative_error": max_error, + } + + +def _append_detail_rows( + rows: list[dict[str, Any]], + *, + targets: list[CDAgeTarget], + age_key_to_col: dict[tuple[tuple[str, str], ...], int], + 
household_indices: np.ndarray, + household_age_counts: np.ndarray, + input_weights: np.ndarray, + output_weights: np.ndarray, + status: str, +) -> None: + for target in targets: + col = age_key_to_col[target.age_key] + counts = household_age_counts[household_indices, col] + before = float(np.dot(counts, input_weights[household_indices])) + after = float(np.dot(counts, output_weights[household_indices])) + rows.append( + { + "target_id": target.target_id, + "district_geoid": target.district_geoid, + "age_key": json.dumps(target.age_key), + "target": target.value, + "estimate_before": before, + "estimate_after": after, + "relative_error_before": _relative_error(before, target.value), + "relative_error_after": _relative_error(after, target.value), + "abs_relative_error_before": abs(_relative_error(before, target.value)), + "abs_relative_error_after": abs(_relative_error(after, target.value)), + "period": target.period, + "status": status, + } + ) + + +def _relative_error(estimate: float, target: float) -> float: + if abs(target) <= 1e-12: + return 0.0 if abs(estimate) <= 1e-12 else float("inf") + return float((estimate - target) / abs(target)) + + +def _abs_relative_error(estimate: np.ndarray, target: np.ndarray) -> np.ndarray: + denominator = np.where(np.abs(target) <= 1e-12, 1.0, np.abs(target)) + return np.abs((estimate - target) / denominator) + + +def _summarize_detail_frame( + detail_frame: pd.DataFrame, + *, + input_weight_sum: float, + output_weight_sum: float, + n_households: int, + n_persons: int, + n_age_bins: int, + district_failures: list[dict[str, Any]], +) -> dict[str, Any]: + before = detail_frame["abs_relative_error_before"].to_numpy(dtype=np.float64) + after = detail_frame["abs_relative_error_after"].to_numpy(dtype=np.float64) + return { + "n_targets": int(len(detail_frame)), + "n_districts": int(detail_frame["district_geoid"].nunique()), + "n_households": int(n_households), + "n_persons": int(n_persons), + "n_age_bins": int(n_age_bins), + 
"input_weight_sum": float(input_weight_sum), + "output_weight_sum": float(output_weight_sum), + "weight_sum_relative_change": float( + (output_weight_sum - input_weight_sum) / input_weight_sum + ), + "mean_abs_relative_error_before": float(before.mean()), + "mean_abs_relative_error_after": float(after.mean()), + "median_abs_relative_error_before": float(np.median(before)), + "median_abs_relative_error_after": float(np.median(after)), + "p90_abs_relative_error_before": float(np.quantile(before, 0.9)), + "p90_abs_relative_error_after": float(np.quantile(after, 0.9)), + "p99_abs_relative_error_before": float(np.quantile(before, 0.99)), + "p99_abs_relative_error_after": float(np.quantile(after, 0.99)), + "max_abs_relative_error_before": float(before.max(initial=0.0)), + "max_abs_relative_error_after": float(after.max(initial=0.0)), + "failed_district_count": int(len(district_failures)), + "district_failures": district_failures, + } + + +def _normalize_cd_geoids_in_h5(path: str | Path, *, period: int) -> None: + period_key = str(period) + with h5py.File(path, "r+") as handle: + if "congressional_district_geoid" not in handle: + return + group = handle["congressional_district_geoid"] + if period_key not in group: + return + values = np.asarray(group[period_key]) + group[period_key][...] = normalize_at_large_cd_geoids(values).astype(values.dtype) + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--input-dataset", required=True) + parser.add_argument("--target-db", required=True) + parser.add_argument("--output-dataset", required=True) + parser.add_argument("--period", type=int, default=2024) + parser.add_argument("--max-iter", type=int, default=300) + parser.add_argument("--tol", type=float, default=1e-9) + parser.add_argument( + "--no-preserve-district-weight-sum", + dest="preserve_district_weight_sum", + action="store_false", + help=( + "Do not append a per-district household-weight preservation row. 
" + "The default preserves district household totals while fitting CD-age targets." + ), + ) + parser.set_defaults(preserve_district_weight_sum=True) + parser.add_argument("--summary-output") + parser.add_argument("--details-output") + args = parser.parse_args(argv) + + summary = reweight_h5_to_cd_age_targets( + input_dataset=args.input_dataset, + target_db=args.target_db, + output_dataset=args.output_dataset, + period=args.period, + max_iter=args.max_iter, + tol=args.tol, + preserve_district_weight_sum=args.preserve_district_weight_sum, + details_output=args.details_output, + ) + payload = json.dumps(summary, indent=2, sort_keys=True, allow_nan=False) + if args.summary_output: + summary_path = Path(args.summary_output).expanduser().resolve() + summary_path.parent.mkdir(parents=True, exist_ok=True) + summary_path.write_text(payload + "\n") + print(payload) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/microplex_us/pipelines/dashboard.py b/src/microplex_us/pipelines/dashboard.py new file mode 100644 index 0000000..593f331 --- /dev/null +++ b/src/microplex_us/pipelines/dashboard.py @@ -0,0 +1,1630 @@ +"""Build the living Microplex diagnostic dashboard payload.""" + +from __future__ import annotations + +import argparse +import csv +import json +import re +import subprocess +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +_ROOT = Path(__file__).resolve().parents[3] +_DEFAULT_ARTIFACT_ROOT = _ROOT / "artifacts" +_DEFAULT_OUTPUT_PATH = _DEFAULT_ARTIFACT_ROOT / "microplex_dashboard_current.json" +_DEFAULT_TARGET_DIAGNOSTICS_PATH = ( + _DEFAULT_ARTIFACT_ROOT / "pe_native_target_diagnostics_current.json" +) +_DEFAULT_POLICYENGINE_US_DATA_REPO = Path( + "/Users/maxghenis/PolicyEngine/policyengine-us-data" +) + +_PE_MODEL_SLOTS = ( + { + "id": "policyengine_legacy_ecps", + "label": "PE legacy enhanced CPS", + "status": "available_as_baseline", + "notes": ( 
+ "The incumbent enhanced CPS is represented as the baseline side of " + "PE-native score artifacts." + ), + }, + { + "id": "policyengine_small_l0", + "label": "PE small-L0 local model", + "status": "missing_weight_package", + "notes": ( + "Mapped to policyengine-us-data local_net_worth_100 when present. " + "The weight package is not itself a scored H5 dataset." + ), + }, + { + "id": "policyengine_big_l0", + "label": "PE big-L0 local model", + "status": "missing_weight_package", + "notes": ( + "Mapped to policyengine-us-data local_net_worth_100_e300 when " + "present. The weight package is not itself a scored H5 dataset." + ), + }, +) + +_PE_L0_MODEL_SPECS = ( + { + "id": "policyengine_small_l0", + "label": "PE small-L0 local model", + "relative_dir": "policyengine_us_data/storage/calibration/local_net_worth_100", + }, + { + "id": "policyengine_big_l0", + "label": "PE big-L0 local model", + "relative_dir": ( + "policyengine_us_data/storage/calibration/local_net_worth_100_e300" + ), + }, +) + + +@dataclass(frozen=True) +class DashboardPaths: + """Filesystem inputs for the dashboard payload.""" + + artifact_root: Path = _DEFAULT_ARTIFACT_ROOT + target_diagnostics_path: Path = _DEFAULT_TARGET_DIAGNOSTICS_PATH + output_path: Path = _DEFAULT_OUTPUT_PATH + + +def build_dashboard_payload( + *, + artifact_root: str | Path = _DEFAULT_ARTIFACT_ROOT, + target_diagnostics_path: str | Path = _DEFAULT_TARGET_DIAGNOSTICS_PATH, + policyengine_us_data_repo: str | Path | None = _DEFAULT_POLICYENGINE_US_DATA_REPO, + include_tmux: bool = True, +) -> dict[str, Any]: + """Collect scores, local screens, active logs, and target diagnostics.""" + + artifact_root = Path(artifact_root) + target_diagnostics_path = Path(target_diagnostics_path) + score_runs = collect_score_runs(artifact_root) + local_screens = collect_local_target_screens(artifact_root) + pe_l0_models = collect_policyengine_l0_models(policyengine_us_data_repo) + actual_l0_runs = collect_actual_l0_objective_runs(artifact_root) 
+ materialized_l0_scores = collect_materialized_policyengine_l0_scores( + artifact_root + ) + run_contracts = collect_run_contracts(artifact_root) + active_logs = collect_recent_log_summaries(artifact_root) + tmux_sessions = collect_tmux_sessions() if include_tmux else [] + target_diagnostics = _read_json(target_diagnostics_path) + generated_at = datetime.now(timezone.utc).isoformat() + + return { + "dashboard_schema_version": 1, + "generated_at": generated_at, + "artifact_root": str(artifact_root), + "target_diagnostics_path": ( + str(target_diagnostics_path) if target_diagnostics is not None else None + ), + "target_diagnostics": target_diagnostics, + "run_board": { + "generated_at": generated_at, + "score_runs": score_runs, + "local_target_screens": local_screens, + "policyengine_l0_models": pe_l0_models, + "actual_l0_objective_runs": actual_l0_runs, + "materialized_policyengine_l0_scores": materialized_l0_scores, + "run_contracts": run_contracts, + "active_logs": active_logs, + "tmux_sessions": tmux_sessions, + "comparison_matrix": build_comparison_matrix( + score_runs, + local_screens, + pe_l0_models, + materialized_l0_scores, + ), + "apples_to_apples": build_apples_to_apples_groups( + score_runs, + local_screens, + pe_l0_models, + materialized_l0_scores, + ), + "assertions": build_dashboard_assertions( + score_runs, + local_screens, + pe_l0_models, + materialized_l0_scores, + ), + }, + } + + +def collect_score_runs(artifact_root: str | Path) -> list[dict[str, Any]]: + """Read completed PE-native score artifacts under ``artifact_root``.""" + + artifact_root = Path(artifact_root) + runs: list[dict[str, Any]] = [] + for path in sorted(_iter_score_paths(artifact_root)): + payload = _read_json(path) + if payload is None: + continue + runs.extend(_score_entries_from_payload(path, payload)) + return sorted( + runs, + key=lambda row: ( + row.get("candidate_loss") is None, + row.get("candidate_loss") or float("inf"), + row.get("artifact_path") or "", + ), + ) + + +def 
collect_run_contracts(artifact_root: str | Path) -> list[dict[str, Any]]: + """Read machine-readable run contract summaries under ``artifact_root``.""" + + artifact_root = Path(artifact_root) + contracts: list[dict[str, Any]] = [] + for path in sorted(artifact_root.rglob("run_summary.json")): + summary = _read_json(path) + if not isinstance(summary, dict): + continue + manifest = _read_json(path.parent / "run_manifest.json") or {} + contracts.append( + { + "artifact_dir": str(path.parent), + "summary_path": str(path), + "manifest_path": str(path.parent / "run_manifest.json"), + "events_path": str(path.parent / "run_events.jsonl"), + "status_source": "contract", + "run_id": summary.get("run_id") or manifest.get("run_id"), + "attempt_id": summary.get("attempt_id") + or manifest.get("attempt_id"), + "status": summary.get("status"), + "active": summary.get("active"), + "started_at": summary.get("started_at"), + "updated_at": summary.get("updated_at"), + "failed_at": summary.get("failed_at"), + "completed_at": summary.get("completed_at"), + "failed_event_id": summary.get("failed_event_id"), + "failure": summary.get("failure"), + "restart": summary.get("restart"), + "completed_stages": summary.get("completed_stages") or [], + } + ) + return sorted( + contracts, + key=lambda row: str(row.get("updated_at") or ""), + reverse=True, + ) + + +def collect_local_target_screens(artifact_root: str | Path) -> list[dict[str, Any]]: + """Read cheap matrix-side local target screen summaries.""" + + artifact_root = Path(artifact_root) + screens = [] + for path in sorted(artifact_root.rglob("split_loss_summary.json")): + payload = _read_json(path) + if not isinstance(payload, dict): + continue + score_summary = _local_screen_score_summary(path.parent / "scores.json") + screens.append( + { + "label": payload.get("candidate") or path.parent.name, + "artifact_path": str(path), + "artifact_dir": str(path.parent), + "metric": "latest_pe_matrix_plus_cd_age_screen", + "status": ( + 
def _local_screen_score_summary(path: Path) -> dict[str, Any] | None:
    """Summarize the latest-PE score file that sits beside a local target screen.

    Args:
        path: Location of the score JSON handed to ``_read_json``.

    Returns:
        A dict with ``candidate_loss``, ``baseline_loss``, ``loss_delta`` and
        ``candidate_beats_baseline``, or ``None`` when the file does not yield a
        mapping or carries no usable candidate loss.
    """

    document = _read_json(path)
    if isinstance(document, list):
        # Only the first element of a list payload is considered.
        document = document[0] if document else None
    if not isinstance(document, dict):
        return None

    # Metrics are read from the nested "summary" mapping when present, otherwise
    # from the top-level mapping itself.
    nested = document.get("summary")
    fields = nested if isinstance(nested, dict) else document

    candidate = _number_or_none(fields.get("candidate_enhanced_cps_native_loss"))
    baseline = _number_or_none(fields.get("baseline_enhanced_cps_native_loss"))
    if candidate is None:
        # A missing baseline is tolerated; a missing candidate loss is not.
        return None

    result: dict[str, Any] = {
        "candidate_loss": candidate,
        "baseline_loss": baseline,
    }
    result["loss_delta"] = _number_or_none(
        fields.get("enhanced_cps_native_loss_delta")
    )
    # Passed through as stored (not coerced to bool).
    result["candidate_beats_baseline"] = fields.get("candidate_beats_baseline")
    return result
_number_or_none(config.get("n_clones")) + if isinstance(config, dict) + else None + ), + "epochs": ( + _number_or_none(config.get("epochs")) + if isinstance(config, dict) + else None + ), + "n_targets": ( + _number_or_none(config.get("n_targets")) + if isinstance(config, dict) + else None + ), + "n_records": ( + _number_or_none(config.get("n_records")) + if isinstance(config, dict) + else None + ), + "weight_sum": ( + _number_or_none(config.get("weight_sum")) + if isinstance(config, dict) + else None + ), + "weight_nonzero": ( + _number_or_none(config.get("weight_nonzero")) + if isinstance(config, dict) + else None + ), + "mean_error_pct": ( + _number_or_none(config.get("mean_error_pct")) + if isinstance(config, dict) + else None + ), + "elapsed_seconds": ( + _number_or_none(config.get("elapsed_seconds")) + if isinstance(config, dict) + else None + ), + "diagnostics": diagnostics, + "same_harness_materialization": _inspect_l0_materialization( + model_dir=model_dir, + config=config, + weights_path=weights_path, + ), + "notes": ( + "PE local-L0 fit metrics come from unified_diagnostics.csv. " + "Same-harness broad/latest score remains missing until this " + "weight package is materialized as a scored H5." 
def collect_actual_l0_objective_runs(
    artifact_root: str | Path,
) -> list[dict[str, Any]]:
    """Collect local unified-calibration runs scored on the actual L0 objective.

    Every ``unified_diagnostics.csv`` found below ``artifact_root`` whose summary
    can be built becomes one row; files that cannot be summarized are skipped.

    Args:
        artifact_root: Directory searched recursively for run folders.

    Returns:
        Run rows ordered by ascending ``actual_l0_data_loss``. Rows without a
        loss come last and ties are broken by ``artifact_dir``.
    """

    def _sort_key(row: dict[str, Any]) -> tuple[bool, float, str]:
        loss = row.get("actual_l0_data_loss")
        # Test for None explicitly: ``loss or inf`` would treat a genuine 0.0
        # loss as missing and sort the best possible run behind all others.
        return (
            loss is None,
            float("inf") if loss is None else loss,
            row.get("artifact_dir") or "",
        )

    artifact_root = Path(artifact_root)
    runs: list[dict[str, Any]] = []
    for diagnostics_path in sorted(artifact_root.rglob("unified_diagnostics.csv")):
        diagnostics = _summarize_unified_diagnostics(diagnostics_path)
        if diagnostics is None:
            continue
        run_dir = diagnostics_path.parent
        weights_path = run_dir / "calibration_weights.npy"
        config_path = run_dir / "unified_run_config.json"
        config = _read_json(config_path)
        has_config = isinstance(config, dict)
        weight_summary = _weight_file_summary(weights_path)
        runs.append(
            {
                "label": run_dir.name,
                "artifact_dir": str(run_dir),
                "diagnostics_path": str(diagnostics_path),
                # Paths are only reported when the artifact was actually usable.
                "config_path": str(config_path) if has_config else None,
                "weights_path": str(weights_path) if weights_path.exists() else None,
                "status": "complete",
                "model_id": _infer_actual_l0_model_id(run_dir),
                "actual_l0_data_loss": diagnostics.get("actual_l0_data_loss"),
                "actual_l0_mean_abs_relative_error_pct": diagnostics.get(
                    "actual_l0_mean_abs_relative_error_pct"
                ),
                "n_targets": diagnostics.get("n_targets"),
                "n_achievable": diagnostics.get("n_achievable"),
                "n_clones": (
                    _number_or_none(config.get("n_clones")) if has_config else None
                ),
                "epochs": (
                    _number_or_none(config.get("epochs")) if has_config else None
                ),
                "weights": weight_summary,
                "diagnostics": diagnostics,
            }
        )
    return sorted(runs, key=_sort_key)
+ scores: list[dict[str, Any]] = [] + for path in sorted( + artifact_root.rglob("pe_local_area_l0_state_stack_vs_legacy_ecps.json") + ): + payload = _read_json(path) + if not isinstance(payload, dict): + continue + summary = payload.get("summary") + if not isinstance(summary, dict): + continue + candidate_loss = _number_or_none(summary.get("to_loss")) + baseline_loss = _number_or_none(summary.get("from_loss")) + if candidate_loss is None or baseline_loss is None: + continue + scores.append( + { + "id": "policyengine_local_area_l0_state_stack", + "label": "PE local-area L0 state stack", + "status": "same_harness_scored_experimental", + "artifact_path": str(path), + "artifact_dir": str(path.parent), + "metric": payload.get("metric") + or "enhanced_cps_native_loss_target_delta", + "metric_runtime": "legacy_or_patched_runtime", + "candidate_loss": candidate_loss, + "baseline_loss": baseline_loss, + "candidate_beats_baseline": candidate_loss < baseline_loss, + "loss_delta": _number_or_none(summary.get("loss_delta")), + "n_targets": _number_or_none(summary.get("n_targets")), + "state_score_count": _number_or_none( + payload.get("state_score_count") + ), + "state_weight_sum": _number_or_none( + payload.get("state_weight_sum") + ), + "notes": ( + "This is an experimental materialized state-stack score. " + "It is a broad same-harness artifact, but it is not the " + "small-L0 or big-L0 weight package unless the source path " + "says so." 
def build_comparison_matrix(
    score_runs: list[dict[str, Any]],
    local_screens: list[dict[str, Any]],
    pe_l0_models: list[dict[str, Any]],
    materialized_l0_scores: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
    """Build a compact answer matrix for the current PE comparison question.

    Args:
        score_runs: Score entries; the best Microplex entry per metric runtime is
            selected with ``_best_score``.
        local_screens: Local screen rows; only the first element is used.
        pe_l0_models: PE local-L0 model rows, looked up by their ``id``.
        materialized_l0_scores: Optional materialized state-stack scores; the
            best one (if any) is appended as its own row.

    Returns:
        One row per entry of ``_PE_MODEL_SLOTS``, then an optional materialized
        L0 row, then the Microplex row.
    """

    def _baseline_outcome(score: dict[str, Any] | None) -> str:
        # Three-way status: a score that exists but does not beat the baseline
        # must not be reported as "missing".
        if score is None:
            return "missing"
        if score.get("candidate_beats_baseline"):
            return "beats_legacy_pe_baseline"
        return "worse_than_legacy_pe_baseline"

    materialized_l0_scores = materialized_l0_scores or []
    best_latest = _best_score(
        score_runs,
        predicate=lambda row: row.get("metric_runtime") == "latest_policyengine_us"
        and row.get("model_id") == "microplex_current_best",
    )
    best_legacy = _best_score(
        score_runs,
        predicate=lambda row: row.get("metric_runtime") == "legacy_or_patched_runtime"
        and row.get("model_id") == "microplex_current_best",
    )
    best_local = local_screens[0] if local_screens else None
    pe_l0_by_id = {row.get("id"): row for row in pe_l0_models}

    rows: list[dict[str, Any]] = []
    for slot in _PE_MODEL_SLOTS:
        row = dict(slot)
        if slot["id"] == "policyengine_legacy_ecps" and (
            best_latest is not None or best_legacy is not None
        ):
            # The legacy eCPS baseline is read off whichever Microplex score
            # exists; each runtime is filled in independently so a legacy-only
            # score is not discarded when the latest-runtime score is absent.
            row.update(
                {
                    "latest_pe_broad_loss": (
                        best_latest.get("baseline_loss")
                        if best_latest is not None
                        else None
                    ),
                    "latest_pe_status": (
                        "available" if best_latest is not None else "missing"
                    ),
                    "legacy_metric_loss": (
                        best_legacy.get("baseline_loss")
                        if best_legacy is not None
                        else None
                    ),
                    "legacy_metric_status": (
                        "available" if best_legacy is not None else "missing"
                    ),
                }
            )
        elif slot["id"] in pe_l0_by_id:
            model = pe_l0_by_id[slot["id"]]
            diagnostics = model.get("diagnostics") or {}
            latest_score = _best_model_metric_score(
                score_runs,
                model_id=str(slot["id"]),
                metric_runtime="latest_policyengine_us",
            )
            legacy_score = _best_model_metric_score(
                score_runs,
                model_id=str(slot["id"]),
                metric_runtime="legacy_or_patched_runtime",
            )
            has_h5_score = latest_score is not None or legacy_score is not None
            row.update(
                {
                    "status": (
                        "same_harness_scored" if has_h5_score else model.get("status")
                    ),
                    "artifact_dir": model.get("artifact_dir"),
                    "latest_pe_broad_loss": (
                        latest_score.get("candidate_loss")
                        if latest_score is not None
                        else None
                    ),
                    "latest_pe_status": (
                        "scored" if latest_score is not None else "missing_h5_score"
                    ),
                    "legacy_metric_loss": (
                        legacy_score.get("candidate_loss")
                        if legacy_score is not None
                        else None
                    ),
                    "legacy_metric_status": (
                        "scored" if legacy_score is not None else "missing_h5_score"
                    ),
                    # Diagnostics values win; the model row is the fallback.
                    "pe_local_l0_mean_abs_error_pct": diagnostics.get(
                        "mean_abs_relative_error_pct"
                    )
                    or model.get("mean_error_pct"),
                    "pe_local_l0_median_abs_error_pct": diagnostics.get(
                        "median_abs_relative_error_pct"
                    ),
                    "pe_local_l0_p90_abs_error_pct": diagnostics.get(
                        "p90_abs_relative_error_pct"
                    ),
                    "pe_local_l0_targets": diagnostics.get("n_targets")
                    or model.get("n_targets"),
                    "pe_local_l0_epochs": model.get("epochs"),
                    "pe_local_l0_weight_nonzero": model.get("weight_nonzero"),
                    "notes": (
                        "Same-harness H5 score is available."
                        if has_h5_score
                        else model.get("notes") or row.get("notes")
                    ),
                }
            )
        else:
            row.update(
                {
                    "latest_pe_broad_loss": None,
                    "latest_pe_status": "missing",
                    "legacy_metric_loss": None,
                    "legacy_metric_status": "missing",
                }
            )
        rows.append(row)

    best_materialized_l0 = _best_materialized_l0_score(materialized_l0_scores)
    if best_materialized_l0 is not None:
        rows.append(
            {
                "id": best_materialized_l0.get("id"),
                "label": best_materialized_l0.get("label"),
                "status": best_materialized_l0.get("status"),
                "latest_pe_broad_loss": None,
                "latest_pe_status": None,
                "legacy_metric_loss": best_materialized_l0.get("candidate_loss"),
                "legacy_metric_baseline_loss": best_materialized_l0.get(
                    "baseline_loss"
                ),
                "legacy_metric_status": (
                    "beats_legacy_pe_baseline"
                    if best_materialized_l0.get("candidate_beats_baseline")
                    else "worse_than_legacy_pe_baseline"
                ),
                "artifact_path": best_materialized_l0.get("artifact_path"),
                "notes": best_materialized_l0.get("notes"),
            }
        )

    rows.append(
        {
            "id": "microplex_current_best",
            "label": "Microplex current best",
            "status": "available" if best_latest is not None else "missing",
            "latest_pe_broad_loss": (
                best_latest.get("candidate_loss") if best_latest is not None else None
            ),
            "latest_pe_baseline_loss": (
                best_latest.get("baseline_loss") if best_latest is not None else None
            ),
            "latest_pe_status": _baseline_outcome(best_latest),
            "legacy_metric_loss": (
                best_legacy.get("candidate_loss") if best_legacy is not None else None
            ),
            "legacy_metric_baseline_loss": (
                best_legacy.get("baseline_loss") if best_legacy is not None else None
            ),
            "legacy_metric_status": _baseline_outcome(best_legacy),
            "local_cd_age_screen_loss": (
                best_local.get("broad_loss") if best_local is not None else None
            ),
            "local_cd_age_mare": (
                best_local.get("cd_age_mean_abs_relative_error")
                if best_local is not None
                else None
            ),
            "artifact_path": (
                best_latest.get("artifact_path") if best_latest is not None else None
            ),
            "notes": (
                "This is the best completed Microplex score found locally. "
                "The CD-age row is a matrix screen until the latest-PE row-batch "
                "score finishes."
            ),
        }
    )
    return rows
_best_model_metric_score( + score_runs, + model_id="policyengine_big_l0", + metric_runtime="latest_policyengine_us", + ) + legacy_small = _best_model_metric_score( + score_runs, + model_id="policyengine_small_l0", + metric_runtime="legacy_or_patched_runtime", + ) + legacy_big = _best_model_metric_score( + score_runs, + model_id="policyengine_big_l0", + metric_runtime="legacy_or_patched_runtime", + ) + + groups = [ + { + "id": "latest_pe_broad", + "label": "Latest PolicyEngine broad target loss", + "metric_scope": "same_harness_latest_pe_broad", + "status": ( + "complete" + if best_latest and latest_small and latest_big + else "partial" + if best_latest + else "missing" + ), + "rows": [ + _comparison_row( + model_id="policyengine_legacy_ecps", + label="PE legacy enhanced CPS", + score=( + best_latest.get("baseline_loss") + if best_latest is not None + else None + ), + status="scored_baseline" if best_latest else "missing", + ), + _comparison_row( + model_id="microplex_current_best", + label="Microplex current best", + score=( + best_latest.get("candidate_loss") + if best_latest is not None + else None + ), + status=( + "scored_candidate_beats_baseline" + if best_latest + and best_latest.get("candidate_beats_baseline") + else "missing" + ), + artifact_path=( + best_latest.get("artifact_path") + if best_latest is not None + else None + ), + ), + _scored_or_missing_l0_row( + pe_l0_by_id, + "policyengine_small_l0", + latest_small, + ), + _scored_or_missing_l0_row( + pe_l0_by_id, + "policyengine_big_l0", + latest_big, + ), + ], + }, + { + "id": "legacy_broad", + "label": "Legacy broad target loss", + "metric_scope": "same_harness_legacy_broad", + "status": ( + "complete" + if best_legacy and legacy_small and legacy_big + else "partial" + if best_legacy + else "missing" + ), + "rows": [ + _comparison_row( + model_id="policyengine_legacy_ecps", + label="PE legacy enhanced CPS", + score=( + best_legacy.get("baseline_loss") + if best_legacy is not None + else None + ), + 
status="scored_baseline" if best_legacy else "missing", + ), + _comparison_row( + model_id="microplex_current_best", + label="Microplex current best", + score=( + best_legacy.get("candidate_loss") + if best_legacy is not None + else None + ), + status=( + "scored_candidate_beats_baseline" + if best_legacy + and best_legacy.get("candidate_beats_baseline") + else "missing" + ), + artifact_path=( + best_legacy.get("artifact_path") + if best_legacy is not None + else None + ), + ), + _comparison_row( + model_id="policyengine_local_area_l0_state_stack", + label="PE local-area L0 state stack", + score=( + best_materialized_l0.get("candidate_loss") + if best_materialized_l0 is not None + else None + ), + status=( + best_materialized_l0.get("status") + if best_materialized_l0 is not None + else "missing" + ), + artifact_path=( + best_materialized_l0.get("artifact_path") + if best_materialized_l0 is not None + else None + ), + detail=( + "Experimental materialization" + if best_materialized_l0 is not None + else None + ), + ), + _scored_or_missing_l0_row( + pe_l0_by_id, + "policyengine_small_l0", + legacy_small, + ), + _scored_or_missing_l0_row( + pe_l0_by_id, + "policyengine_big_l0", + legacy_big, + ), + ], + }, + { + "id": "pe_local_l0_native", + "label": "PE local-L0 native target diagnostics", + "metric_scope": "pe_native_local_l0_diagnostics", + "status": "native_only", + "rows": [ + _native_pe_l0_row(pe_l0_by_id, "policyengine_small_l0"), + _native_pe_l0_row(pe_l0_by_id, "policyengine_big_l0"), + _comparison_row( + model_id="microplex_cd_age_screen", + label="Microplex CD-age screen", + score=( + 100 * best_local.get("cd_age_mean_abs_relative_error") + if best_local is not None + and best_local.get("cd_age_mean_abs_relative_error") + is not None + else None + ), + status=( + "different_target_set_screen_only" + if best_local is not None + else "missing" + ), + artifact_path=( + best_local.get("artifact_path") + if best_local is not None + else None + ), + detail=( + 
def build_dashboard_assertions(
    score_runs: list[dict[str, Any]],
    local_screens: list[dict[str, Any]],
    pe_l0_models: list[dict[str, Any]],
    materialized_l0_scores: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
    """Report which comparison claims the completed artifacts can back up.

    Args:
        score_runs: Score entries searched for Microplex and PE L0 results.
        local_screens: Local screen rows; only their presence is reported.
        pe_l0_models: PE local-L0 model rows, looked up by ``id``.
        materialized_l0_scores: Optional materialized state-stack scores.

    Returns:
        A mapping of boolean claim flags plus a ``caveat`` string.
    """

    def _microplex_best(runtime: str) -> dict[str, Any] | None:
        # Best Microplex entry scored under the given metric runtime.
        return _best_score(
            score_runs,
            predicate=lambda entry: entry.get("metric_runtime") == runtime
            and entry.get("model_id") == "microplex_current_best",
        )

    def _scored_on_both_runtimes(model_id: str) -> bool:
        # A PE L0 model counts as complete only with a latest AND a legacy score.
        latest = _best_model_metric_score(
            score_runs,
            model_id=model_id,
            metric_runtime="latest_policyengine_us",
        )
        legacy = _best_model_metric_score(
            score_runs,
            model_id=model_id,
            metric_runtime="legacy_or_patched_runtime",
        )
        return bool(latest and legacy)

    materialized_l0_scores = materialized_l0_scores or []
    microplex_latest = _microplex_best("latest_policyengine_us")
    microplex_legacy = _microplex_best("legacy_or_patched_runtime")

    models_by_id = {model.get("id"): model for model in pe_l0_models}
    package_available = {
        model_id: models_by_id.get(model_id, {}).get("status")
        == "available_weight_package"
        for model_id in ("policyengine_small_l0", "policyengine_big_l0")
    }

    materialized_best = _best_materialized_l0_score(materialized_l0_scores)
    small_done = _scored_on_both_runtimes("policyengine_small_l0")
    big_done = _scored_on_both_runtimes("policyengine_big_l0")
    everything_done = bool(
        microplex_latest and microplex_legacy and small_done and big_done
    )

    assertions: dict[str, Any] = {}
    assertions["microplex_beats_legacy_ecps_latest_pe_broad"] = bool(
        microplex_latest and microplex_latest.get("candidate_beats_baseline")
    )
    assertions["microplex_beats_legacy_ecps_legacy_metric"] = bool(
        microplex_legacy and microplex_legacy.get("candidate_beats_baseline")
    )
    assertions["microplex_vs_small_l0_complete"] = small_done
    assertions["microplex_vs_big_l0_complete"] = big_done
    assertions["microplex_vs_all_three_pe_models_on_both_metrics"] = everything_done
    assertions["policyengine_small_l0_weight_package_available"] = package_available[
        "policyengine_small_l0"
    ]
    assertions["policyengine_big_l0_weight_package_available"] = package_available[
        "policyengine_big_l0"
    ]
    assertions["policyengine_materialized_l0_same_harness_available"] = bool(
        materialized_best
    )
    assertions["local_cd_age_screen_available"] = bool(local_screens)
    assertions["apples_to_apples_groups_available"] = True
    assertions["caveat"] = (
        "Small-L0 and big-L0 PE weight packages are wired into the run "
        "board when available. The all-three-PE-model claim is supported "
        "only when both materialized PE L0 packages have legacy and latest "
        "same-harness scores."
    )
    return assertions
Any] | None, +) -> dict[str, Any]: + if score is None: + return _missing_h5_row(pe_l0_by_id, model_id) + model = pe_l0_by_id.get(model_id) or {} + return _comparison_row( + model_id=model_id, + label=str(model.get("label") or model_id), + score=score.get("candidate_loss"), + status=( + "scored_candidate_beats_legacy_ecps" + if score.get("candidate_beats_baseline") + else "scored_candidate_worse_than_legacy_ecps" + ), + artifact_path=score.get("artifact_path"), + detail=( + f"{int(score['n_targets_kept']):,} targets" + if _number_or_none(score.get("n_targets_kept")) is not None + else None + ), + ) + + +def _native_pe_l0_row( + pe_l0_by_id: dict[str, dict[str, Any]], model_id: str +) -> dict[str, Any]: + model = pe_l0_by_id.get(model_id) or {} + diagnostics = model.get("diagnostics") or {} + score = diagnostics.get("mean_abs_relative_error_pct") or model.get( + "mean_error_pct" + ) + targets = diagnostics.get("n_targets") or model.get("n_targets") + return _comparison_row( + model_id=model_id, + label=str(model.get("label") or model_id), + score=_number_or_none(score), + status=( + "native_diagnostics_available" + if _number_or_none(score) is not None + else "missing_native_diagnostics" + ), + artifact_path=model.get("diagnostics_path") or model.get("artifact_dir"), + detail=( + f"{format(int(targets), ',')} PE-local targets" + if _number_or_none(targets) is not None + else None + ), + ) + + +def _best_materialized_l0_score( + rows: list[dict[str, Any]], +) -> dict[str, Any] | None: + candidates = [ + row + for row in rows + if _number_or_none(row.get("candidate_loss")) is not None + ] + if not candidates: + return None + return min(candidates, key=lambda row: row["candidate_loss"]) + + +def _weight_file_summary(path: Path) -> dict[str, Any] | None: + if not path.exists(): + return None + try: + import numpy as np + + weights = np.asarray(np.load(path), dtype=float) + except Exception: # pragma: no cover - defensive artifact read + return {"status": "unreadable", 
"path": str(path)} + return { + "status": "ok", + "path": str(path), + "records": int(weights.size), + "nonzero": int((weights > 0.0).sum()), + "greater_than_1": int((weights > 1.0).sum()), + "greater_than_100": int((weights > 100.0).sum()), + "sum": float(weights.sum()), + } + + +def _infer_actual_l0_model_id(run_dir: Path) -> str: + text = str(run_dir).lower() + if "microplex" in text or "mp_" in text: + return "microplex_actual_l0" + if "local_net_worth_100_e300" in text: + return "policyengine_big_l0" + if "local_net_worth_100" in text: + return "policyengine_small_l0" + return "unknown_actual_l0" + + +def _inspect_l0_materialization( + *, + model_dir: Path, + config: Any, + weights_path: Path, +) -> dict[str, Any]: + """Return a cheap compatibility check for materializing a PE-L0 package.""" + + result: dict[str, Any] = {"status": "unknown"} + if not weights_path.exists(): + result["status"] = "missing_weights" + return result + + try: + import numpy as np + + weights = np.load(weights_path, mmap_mode="r") + weight_count = int(weights.shape[0]) + result["weight_count"] = weight_count + except Exception as error: # pragma: no cover - defensive artifact read + result["status"] = "weights_unreadable" + result["error"] = str(error) + return result + + geography_path = model_dir / "geography.npz" + if geography_path.exists(): + try: + import numpy as np + + with np.load(geography_path, allow_pickle=True) as geography: + if "block_geoid" in geography: + result["geography_row_count"] = int( + geography["block_geoid"].shape[0] + ) + if "n_records" in geography: + result["geography_n_records"] = int( + geography["n_records"][0] + ) + if "n_clones" in geography: + result["geography_n_clones"] = int( + geography["n_clones"][0] + ) + except Exception as error: # pragma: no cover - defensive artifact read + result["geography_error"] = str(error) + + dataset_path = None + if isinstance(config, dict) and config.get("dataset"): + dataset_path = Path(str(config["dataset"])) + 
result["dataset_path"] = str(dataset_path) + if dataset_path is None or not dataset_path.exists(): + result["status"] = "source_h5_missing" + return result + + household_count = _h5_period_length(dataset_path, "household_id") + result["source_household_count"] = household_count + if household_count is None: + result["status"] = "source_h5_unreadable" + return result + + if household_count > 0 and weight_count % household_count == 0: + result["status"] = "materializable_against_current_source_h5" + result["implied_clone_count"] = weight_count // household_count + else: + result["status"] = "incompatible_current_source_h5" + result["detail"] = ( + "Weight count is not divisible by the current source H5 household " + "count; same-harness scoring needs the matching source dataset or " + "a regenerated L0 package." + ) + return result + + +def _h5_period_length(path: Path, variable: str) -> int | None: + try: + import h5py + + with h5py.File(path, "r") as handle: + if variable not in handle: + return None + obj = handle[variable] + if hasattr(obj, "keys"): + keys = list(obj.keys()) + if not keys: + return None + return int(obj[keys[0]].shape[0]) + return int(obj.shape[0]) + except Exception: # pragma: no cover - defensive artifact read + return None + + +def write_dashboard_payload( + output_path: str | Path = _DEFAULT_OUTPUT_PATH, + *, + artifact_root: str | Path = _DEFAULT_ARTIFACT_ROOT, + target_diagnostics_path: str | Path = _DEFAULT_TARGET_DIAGNOSTICS_PATH, + policyengine_us_data_repo: str | Path | None = _DEFAULT_POLICYENGINE_US_DATA_REPO, + include_tmux: bool = True, +) -> Path: + """Write the living dashboard JSON payload.""" + + payload = build_dashboard_payload( + artifact_root=artifact_root, + target_diagnostics_path=target_diagnostics_path, + policyengine_us_data_repo=policyengine_us_data_repo, + include_tmux=include_tmux, + ) + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + 
output_path.write_text(json.dumps(payload, indent=2, sort_keys=True)) + return output_path + + +def _iter_score_paths(artifact_root: Path) -> list[Path]: + paths = list(artifact_root.rglob("scores.json")) + paths.extend(artifact_root.rglob("policyengine_native_scores.json")) + paths.extend(artifact_root.rglob("*_score.json")) + return [path for path in paths if path.is_file()] + + +def _score_entries_from_payload( + path: Path, payload: Any +) -> list[dict[str, Any]]: + if isinstance(payload, list): + raw_entries = payload + elif isinstance(payload, dict) and "broad_loss" in payload: + raw_entries = [payload] + elif isinstance(payload, dict) and "summary" in payload: + raw_entries = [payload] + elif isinstance(payload, dict) and "candidate_enhanced_cps_native_loss" in payload: + raw_entries = [payload] + else: + return [] + + entries = [] + for index, item in enumerate(raw_entries): + if not isinstance(item, dict): + continue + if "candidate_enhanced_cps_native_loss" in item: + summary = item + broad_loss = item + else: + summary = item.get("summary") if isinstance(item.get("summary"), dict) else {} + broad_loss = ( + item.get("broad_loss") + if isinstance(item.get("broad_loss"), dict) + else {} + ) + candidate_loss = _number_or_none( + summary.get("candidate_enhanced_cps_native_loss") + ) + baseline_loss = _number_or_none( + summary.get("baseline_enhanced_cps_native_loss") + ) + if candidate_loss is None or baseline_loss is None: + continue + candidate_dataset = broad_loss.get("candidate_dataset") + baseline_dataset = broad_loss.get("baseline_dataset") + metric_runtime = _infer_metric_runtime(path, summary) + model_id = _infer_score_model_id(path, candidate_dataset) + label = _score_label(path, candidate_dataset, index) + entries.append( + { + "label": label, + "model_id": model_id, + "artifact_path": str(path), + "artifact_dir": str(path.parent), + "entry_index": index, + "metric": item.get("metric") or "pe_native_broad_loss", + "metric_runtime": metric_runtime, 
def _summarize_unified_diagnostics(path: Path) -> dict[str, Any] | None:
    """Summarize a per-target unified diagnostics CSV.

    The file is read with ``csv.DictReader`` and the ``achievable``,
    ``abs_rel_error``, ``estimate`` and ``true_value`` columns are aggregated.

    Args:
        path: Location of the diagnostics CSV.

    Returns:
        ``None`` when the file cannot be read or parsed, or has no data rows.
        Otherwise a dict with target counts, ``abs_rel_error`` statistics
        scaled by 100, and the loss recomputed from ``estimate``/``true_value``
        as ``((estimate - target) / (target + 1)) ** 2``.
    """
    try:
        with path.open(newline="") as file:
            rows = list(csv.DictReader(file))
    except (OSError, UnicodeDecodeError, csv.Error):
        # A malformed or undecodable CSV is as unusable as a missing one, so it
        # gets the same "no summary" result instead of crashing the caller.
        return None
    if not rows:
        return None

    abs_errors: list[float] = []
    actual_l0_abs_errors: list[float] = []
    actual_l0_squared_errors: list[float] = []
    achievable_count = 0
    for row in rows:
        if str(row.get("achievable", "")).lower() == "true":
            achievable_count += 1
        error = _number_or_none(row.get("abs_rel_error"))
        if error is not None:
            abs_errors.append(error)
        estimate = _number_or_none(row.get("estimate"))
        true_value = _number_or_none(row.get("true_value"))
        if estimate is None or true_value is None:
            continue
        denominator = true_value + 1.0
        if denominator == 0.0:
            # A target of exactly -1 has no defined relative error under this
            # objective; skip the row rather than raise ZeroDivisionError.
            continue
        actual_error = (estimate - true_value) / denominator
        actual_l0_abs_errors.append(abs(actual_error))
        actual_l0_squared_errors.append(actual_error * actual_error)

    sorted_errors = sorted(abs_errors)
    return {
        "n_targets": len(rows),
        "n_achievable": achievable_count,
        "actual_l0_objective": (
            "sum(((estimate - target) / (target + 1)) ** 2)"
        ),
        "actual_l0_data_loss": (
            sum(actual_l0_squared_errors) if actual_l0_squared_errors else None
        ),
        "actual_l0_mean_abs_relative_error_pct": (
            100 * sum(actual_l0_abs_errors) / len(actual_l0_abs_errors)
            if actual_l0_abs_errors
            else None
        ),
        "mean_abs_relative_error_pct": (
            100 * sum(abs_errors) / len(abs_errors) if abs_errors else None
        ),
        "median_abs_relative_error_pct": _percentile(sorted_errors, 0.5),
        "p90_abs_relative_error_pct": _percentile(sorted_errors, 0.9),
        "p99_abs_relative_error_pct": _percentile(sorted_errors, 0.99),
        "max_abs_relative_error_pct": (
            100 * sorted_errors[-1] if sorted_errors else None
        ),
        "share_under_10pct": _share_under(abs_errors, 0.10),
        "share_under_25pct": _share_under(abs_errors, 0.25),
    }
if "policyengine_local_area_l0" in text or "state_stack" in text: + return "policyengine_local_area_l0_state_stack" + return "microplex_current_best" + + +def _percentile(sorted_values: list[float], quantile: float) -> float | None: + if not sorted_values: + return None + if len(sorted_values) == 1: + return 100 * sorted_values[0] + position = quantile * (len(sorted_values) - 1) + lower = int(position) + upper = min(lower + 1, len(sorted_values) - 1) + weight = position - lower + return 100 * ( + sorted_values[lower] * (1 - weight) + sorted_values[upper] * weight + ) + + +def _share_under(values: list[float], threshold: float) -> float | None: + if not values: + return None + return sum(value < threshold for value in values) / len(values) + + +def _best_score( + rows: list[dict[str, Any]], *, predicate: Any +) -> dict[str, Any] | None: + candidates = [ + row + for row in rows + if predicate(row) and _number_or_none(row.get("candidate_loss")) is not None + ] + if not candidates: + return None + return min(candidates, key=lambda row: row["candidate_loss"]) + + +def _best_model_metric_score( + rows: list[dict[str, Any]], + *, + model_id: str, + metric_runtime: str, +) -> dict[str, Any] | None: + return _best_score( + rows, + predicate=lambda row: row.get("model_id") == model_id + and row.get("metric_runtime") == metric_runtime, + ) + + +def _parse_row_batch_progress(text: str) -> dict[str, Any] | None: + pattern = re.compile( + r"PE-native row batch (?P[^:]+): " + r"(?P\d+)/(?P\d+) households " + r"\((?P[0-9.]+)s\)" + ) + matches = list(pattern.finditer(text)) + if not matches: + return None + match = matches[-1] + done = int(match.group("done")) + total = int(match.group("total")) + return { + "dataset": match.group("dataset"), + "households_done": done, + "households_total": total, + "fraction": done / total if total else None, + "elapsed_seconds": float(match.group("elapsed")), + } + + +def _is_relevant_tmux_session(name: str) -> bool: + lowered = name.lower() + 
return ( + lowered.startswith("mp_") + or "microplex" in lowered + or lowered.startswith("dashboard") + ) + + +def _tail_text(path: Path, max_bytes: int = 8192) -> str: + try: + with path.open("rb") as file: + file.seek(0, 2) + size = file.tell() + file.seek(max(size - max_bytes, 0)) + return file.read().decode("utf-8", errors="replace") + except OSError: + return "" + + +def _read_json(path: Path) -> Any | None: + try: + return json.loads(path.read_text()) + except (OSError, json.JSONDecodeError): + return None + + +def _number_or_none(value: Any) -> float | None: + try: + number = float(value) + except (TypeError, ValueError): + return None + if number != number: + return None + return number + + +def main(argv: list[str] | None = None) -> int: + """CLI for the living Microplex dashboard payload.""" + + parser = argparse.ArgumentParser( + description="Build the living Microplex diagnostic dashboard JSON." + ) + parser.add_argument("--artifact-root", default=str(_DEFAULT_ARTIFACT_ROOT)) + parser.add_argument( + "--target-diagnostics-path", + default=str(_DEFAULT_TARGET_DIAGNOSTICS_PATH), + help="Existing per-target diagnostics JSON to embed when available.", + ) + parser.add_argument( + "--policyengine-us-data-repo", + default=str(_DEFAULT_POLICYENGINE_US_DATA_REPO), + help=( + "Local policyengine-us-data checkout used to discover PE local-L0 " + "weight packages. Pass an empty string to skip discovery." 
+ ), + ) + parser.add_argument("--output-path", default=str(_DEFAULT_OUTPUT_PATH)) + parser.add_argument( + "--no-tmux", + action="store_true", + help="Skip tmux session discovery for deterministic tests.", + ) + args = parser.parse_args(argv) + output = write_dashboard_payload( + args.output_path, + artifact_root=args.artifact_root, + target_diagnostics_path=args.target_diagnostics_path, + policyengine_us_data_repo=args.policyengine_us_data_repo or None, + include_tmux=not args.no_tmux, + ) + print(output) + return 0 diff --git a/src/microplex_us/pipelines/pe_native_calibration_benchmark.py b/src/microplex_us/pipelines/pe_native_calibration_benchmark.py new file mode 100644 index 0000000..af4a6f7 --- /dev/null +++ b/src/microplex_us/pipelines/pe_native_calibration_benchmark.py @@ -0,0 +1,678 @@ +"""Benchmark PE-native calibration strategies on a common target surface.""" + +from __future__ import annotations + +import argparse +import json +import re +import subprocess +import sys +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from tempfile import TemporaryDirectory +from time import perf_counter +from typing import Any + +import h5py +import numpy as np + +from microplex_us.pipelines.pe_native_optimization import ( + _PE_NATIVE_BROAD_MATRIX_SCRIPT, + optimize_pe_native_loss_weights, + rewrite_policyengine_us_dataset_weights, +) +from microplex_us.pipelines.pe_native_scores import ( + _DEFAULT_PE_NATIVE_BASELINE_CACHE_DIR, + _ENHANCED_CPS_BAD_TARGETS, + build_policyengine_us_data_subprocess_env, + compute_batch_us_pe_native_scores, + resolve_policyengine_us_data_repo_root, + validate_policyengine_us_data_runtime, +) + + +@dataclass(frozen=True) +class CalibrationBenchmarkVariant: + """One dataset variant to score in a PE-native calibration benchmark.""" + + label: str + method: str + dataset_path: str + generated: bool = False + optimization: dict[str, Any] = 
field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return { + "label": self.label, + "method": self.method, + "dataset_path": self.dataset_path, + "generated": self.generated, + "optimization": dict(self.optimization), + } + + +def _household_weights( + dataset_path: str | Path, + *, + period: int, +) -> tuple[np.ndarray, np.ndarray]: + path = Path(dataset_path).expanduser().resolve() + period_key = str(period) + with h5py.File(path, "r") as handle: + if "household_id" not in handle or period_key not in handle["household_id"]: + raise ValueError(f"{path} is missing household_id/{period_key}") + if ( + "household_weight" not in handle + or period_key not in handle["household_weight"] + ): + raise ValueError(f"{path} is missing household_weight/{period_key}") + household_ids = np.asarray(handle["household_id"][period_key], dtype=np.int64) + weights = np.asarray( + handle["household_weight"][period_key], + dtype=np.float64, + ) + if household_ids.shape[0] != weights.shape[0]: + raise ValueError(f"{path} household_id and household_weight lengths differ") + return household_ids, weights + + +def _reference_aligned_weights( + household_ids: np.ndarray, + reference_dataset_path: str | Path, + *, + period: int, +) -> tuple[str, np.ndarray | None]: + reference_ids, reference_weights = _household_weights( + reference_dataset_path, + period=period, + ) + if household_ids.shape == reference_ids.shape and np.array_equal( + household_ids, + reference_ids, + ): + return "same_order", reference_weights + if len(np.unique(reference_ids)) != len(reference_ids): + return "reference_duplicate_household_ids", None + reference_by_id = { + int(household_id): float(weight) + for household_id, weight in zip(reference_ids, reference_weights, strict=True) + } + if all(int(household_id) in reference_by_id for household_id in household_ids): + return ( + "matched_by_household_id", + np.asarray( + [reference_by_id[int(household_id)] for household_id in household_ids], + 
dtype=np.float64, + ), + ) + return "not_comparable", None + + +def compute_household_weight_diagnostics( + dataset_path: str | Path, + *, + period: int = 2024, + reference_dataset_path: str | Path | None = None, +) -> dict[str, Any]: + """Summarize household weight quality and optional distance from a reference.""" + + resolved = Path(dataset_path).expanduser().resolve() + household_ids, weights = _household_weights(resolved, period=period) + n_households = int(len(weights)) + positive = weights > 0.0 + weight_sum = float(weights.sum()) + square_sum = float(np.dot(weights, weights)) + effective_sample_size = ( + weight_sum * weight_sum / square_sum if square_sum > 0.0 else 0.0 + ) + diagnostics: dict[str, Any] = { + "dataset_path": str(resolved), + "period": int(period), + "household_count": n_households, + "positive_household_count": int(positive.sum()), + "zero_household_count": int((weights == 0.0).sum()), + "negative_household_count": int((weights < 0.0).sum()), + "weight_sum": weight_sum, + "weight_mean": float(weights.mean()) if n_households else 0.0, + "weight_median": float(np.median(weights)) if n_households else 0.0, + "weight_min": float(weights.min()) if n_households else 0.0, + "weight_max": float(weights.max()) if n_households else 0.0, + "weight_p95": float(np.quantile(weights, 0.95)) if n_households else 0.0, + "weight_p99": float(np.quantile(weights, 0.99)) if n_households else 0.0, + "max_to_mean_weight_ratio": ( + float(weights.max() / weights.mean()) + if n_households and weights.mean() > 0.0 + else None + ), + "effective_sample_size": float(effective_sample_size), + "effective_sample_size_share": ( + float(effective_sample_size / n_households) if n_households else None + ), + } + + if reference_dataset_path is None: + return diagnostics + + alignment, reference_weights = _reference_aligned_weights( + household_ids, + reference_dataset_path, + period=period, + ) + diagnostics["reference_dataset_path"] = str( + 
Path(reference_dataset_path).expanduser().resolve() + ) + diagnostics["reference_alignment"] = alignment + if reference_weights is None: + return diagnostics + + delta = weights - reference_weights + reference_sum = float(reference_weights.sum()) + diagnostics.update( + { + "reference_weight_sum": reference_sum, + "weight_sum_delta": float(weight_sum - reference_sum), + "l1_delta_as_share_of_reference_sum": ( + float(np.abs(delta).sum() / abs(reference_sum)) + if reference_sum != 0.0 + else None + ), + "mean_abs_weight_delta": float(np.abs(delta).mean()), + "rms_weight_delta": float(np.sqrt(np.mean(delta * delta))), + "max_abs_weight_delta": float(np.abs(delta).max()) if len(delta) else 0.0, + "changed_household_count": int((np.abs(delta) > 1e-9).sum()), + "changed_household_share": ( + float((np.abs(delta) > 1e-9).mean()) if len(delta) else None + ), + } + ) + return diagnostics + + +def _slugify_label(label: str) -> str: + slug = re.sub(r"[^A-Za-z0-9_.-]+", "-", label.strip()).strip("-") + return slug or "variant" + + +def _log(message: str) -> None: + timestamp = datetime.now().isoformat(timespec="seconds") + print(f"[{timestamp}] {message}", file=sys.stderr, flush=True) + + +def _penalty_label(value: float) -> str: + if value == 0.0: + return "pe_native_unconstrained" + return f"pe_native_l2_{value:g}".replace("+", "") + + +def _parse_existing_candidates(values: Sequence[str] | None) -> dict[str, Path]: + candidates: dict[str, Path] = {} + for value in values or (): + if "=" not in value: + raise ValueError( + "--existing-candidate must be formatted as label=/path/to/file.h5" + ) + label, path = value.split("=", 1) + label = label.strip() + if not label: + raise ValueError("--existing-candidate label cannot be empty") + candidates[label] = Path(path).expanduser() + return candidates + + +def _parse_float_list(value: str | None) -> tuple[float, ...]: + if value is None: + return () + stripped = value.strip() + if not stripped: + return () + return 
tuple(float(item.strip()) for item in stripped.split(",") if item.strip()) + + +def _resolve_target_total_weight( + *, + input_dataset_path: str | Path, + baseline_dataset_path: str | Path, + period: int, + target_total_weight: float | None, + target_total_weight_source: str, +) -> tuple[float | None, str]: + if target_total_weight is not None: + return float(target_total_weight), "explicit" + if target_total_weight_source == "preserve-input": + return None, "preserve-input" + if target_total_weight_source == "input": + _, input_weights = _household_weights(input_dataset_path, period=period) + return float(input_weights.sum()), "input" + if target_total_weight_source == "baseline": + _, baseline_weights = _household_weights(baseline_dataset_path, period=period) + return float(baseline_weights.sum()), "baseline" + raise ValueError( + "target_total_weight_source must be one of preserve-input, input, baseline" + ) + + +def _extract_pe_native_loss_inputs( + *, + input_dataset_path: str | Path, + period: int, + policyengine_us_data_repo: str | Path | None, + policyengine_us_data_python: str | Path | None, + skip_tax_expenditure_targets: bool, +) -> dict[str, Any]: + resolved_repo = resolve_policyengine_us_data_repo_root(policyengine_us_data_repo) + env = build_policyengine_us_data_subprocess_env(resolved_repo) + if policyengine_us_data_python is not None: + command = [str(Path(policyengine_us_data_python).expanduser())] + else: + command = ["uv", "run", "--project", str(resolved_repo), "python"] + validate_policyengine_us_data_runtime( + command, + repo_root=resolved_repo, + env=env, + ) + _log("extracting PE-native loss matrix") + with TemporaryDirectory(prefix="microplex-us-pe-native-benchmark-") as temp_dir: + prefix = Path(temp_dir) / "pe_native_matrix" + started_at = perf_counter() + completed = subprocess.run( + [ + *command, + "-c", + _PE_NATIVE_BROAD_MATRIX_SCRIPT, + str(resolved_repo), + json.dumps(_ENHANCED_CPS_BAD_TARGETS), + str(int(period)), + 
str(Path(input_dataset_path).expanduser().resolve()), + "1" if skip_tax_expenditure_targets else "0", + str(prefix), + ], + cwd=resolved_repo, + env=env, + capture_output=True, + text=True, + check=False, + ) + if completed.returncode != 0: + detail = completed.stderr.strip() or completed.stdout.strip() or str( + completed.returncode + ) + raise RuntimeError(f"PE-native loss-matrix extraction failed: {detail}") + _log(f"extracted PE-native loss matrix in {perf_counter() - started_at:.1f}s") + return { + "scaled_matrix": np.load(prefix.with_suffix(".matrix.npy")), + "scaled_target": np.load(prefix.with_suffix(".target.npy")), + "initial_weights": np.load(prefix.with_suffix(".weights.npy")), + "metadata": json.loads(prefix.with_suffix(".meta.json").read_text()), + } + + +def build_policyengine_us_native_calibration_benchmark( + *, + input_dataset_path: str | Path, + baseline_dataset_path: str | Path, + output_dir: str | Path, + period: int = 2024, + l2_penalties: Sequence[float] = (0.0, 1e-12, 1e-10, 1e-8), + max_iter: int = 200, + tol: float = 1e-8, + budget: int | None = None, + target_total_weight: float | None = None, + target_total_weight_source: str = "preserve-input", + existing_candidates: Mapping[str, str | Path] | None = None, + policyengine_us_data_repo: str | Path | None = None, + policyengine_us_data_python: str | Path | None = None, + batch_households: int | None = None, + baseline_cache_dir: str | Path | None = _DEFAULT_PE_NATIVE_BASELINE_CACHE_DIR, + skip_tax_expenditure_targets: bool = False, + force: bool = False, +) -> dict[str, Any]: + """Run and score PE-native calibration variants against one baseline.""" + + started_at = perf_counter() + input_path = Path(input_dataset_path).expanduser().resolve() + baseline_path = Path(baseline_dataset_path).expanduser().resolve() + destination = Path(output_dir).expanduser().resolve() + destination.mkdir(parents=True, exist_ok=True) + + resolved_target_total_weight, target_total_weight_resolved_from = ( + 
_resolve_target_total_weight( + input_dataset_path=input_path, + baseline_dataset_path=baseline_path, + period=period, + target_total_weight=target_total_weight, + target_total_weight_source=target_total_weight_source, + ) + ) + + variants: list[CalibrationBenchmarkVariant] = [ + CalibrationBenchmarkVariant( + label="input", + method="existing_input", + dataset_path=str(input_path), + ) + ] + for label, path in (existing_candidates or {}).items(): + variants.append( + CalibrationBenchmarkVariant( + label=label, + method="existing_candidate", + dataset_path=str(Path(path).expanduser().resolve()), + ) + ) + + loss_inputs = ( + _extract_pe_native_loss_inputs( + input_dataset_path=input_path, + period=period, + policyengine_us_data_repo=policyengine_us_data_repo, + policyengine_us_data_python=policyengine_us_data_python, + skip_tax_expenditure_targets=skip_tax_expenditure_targets, + ) + if l2_penalties + else None + ) + + for penalty in l2_penalties: + penalty = float(penalty) + label = _penalty_label(penalty) + if resolved_target_total_weight is not None: + label = f"{label}_{target_total_weight_resolved_from}_total" + output_path = destination / f"{_slugify_label(label)}.h5" + optimization_path = output_path.with_suffix(".optimization.json") + if force or not output_path.exists(): + if loss_inputs is None: + raise RuntimeError("PE-native loss inputs were not extracted") + _log(f"optimizing {label} with l2_penalty={penalty:g}") + optimization_started_at = perf_counter() + optimized_weights, summary = optimize_pe_native_loss_weights( + scaled_matrix=loss_inputs["scaled_matrix"], + scaled_target=loss_inputs["scaled_target"], + initial_weights=loss_inputs["initial_weights"], + budget=budget, + max_iter=max_iter, + l2_penalty=penalty, + tol=tol, + target_total_weight=resolved_target_total_weight, + ) + _log( + f"optimized {label} in " + f"{perf_counter() - optimization_started_at:.1f}s; " + f"loss {summary['initial_loss']:.6g} -> " + f"{summary['optimized_loss']:.6g}" + ) 
+ _log(f"rewriting weights for {label}") + rewritten = rewrite_policyengine_us_dataset_weights( + input_dataset_path=input_path, + output_dataset_path=output_path, + household_weights=optimized_weights, + period=period, + ) + optimization = { + "metric": "enhanced_cps_native_loss_weight_optimization", + "period": int(period), + "input_dataset": str(input_path), + "output_dataset": str(rewritten), + "initial_loss": float(summary["initial_loss"]), + "optimized_loss": float(summary["optimized_loss"]), + "loss_delta": float(summary["loss_delta"]), + "initial_weight_sum": float(summary["initial_weight_sum"]), + "optimized_weight_sum": float(summary["optimized_weight_sum"]), + "household_count": int(summary["household_count"]), + "positive_household_count": int( + summary["positive_household_count"] + ), + "budget": summary["budget"], + "converged": bool(summary["converged"]), + "iterations": int(summary["iterations"]), + "target_names": list(loss_inputs["metadata"]["target_names"]), + "skip_tax_expenditure_targets": bool( + loss_inputs["metadata"].get( + "skip_tax_expenditure_targets", + skip_tax_expenditure_targets, + ) + ), + "l2_penalty": penalty, + "target_total_weight": resolved_target_total_weight, + "target_total_weight_resolved_from": target_total_weight_resolved_from, + "step_size": summary.get("step_size"), + "history_interval": summary.get("history_interval"), + "loss_history": summary.get("loss_history", []), + "reused_existing_output": False, + } + optimization_path.write_text( + json.dumps(optimization, indent=2, sort_keys=True, allow_nan=False) + ) + else: + _log(f"reusing existing optimized dataset for {label}") + optimization = ( + json.loads(optimization_path.read_text()) + if optimization_path.exists() + else {} + ) + optimization.update( + { + "l2_penalty": penalty, + "target_total_weight": resolved_target_total_weight, + "target_total_weight_resolved_from": ( + target_total_weight_resolved_from + ), + "reused_existing_output": True, + } + ) + 
variants.append( + CalibrationBenchmarkVariant( + label=label, + method="pe_native_weight_optimization", + dataset_path=str(output_path.resolve()), + generated=True, + optimization=optimization, + ) + ) + + _log(f"scoring {len(variants)} calibration variants") + scoring_started_at = perf_counter() + scores = compute_batch_us_pe_native_scores( + candidate_dataset_paths=[variant.dataset_path for variant in variants], + baseline_dataset_path=baseline_path, + period=period, + policyengine_us_data_repo=policyengine_us_data_repo, + policyengine_us_data_python=policyengine_us_data_python, + batch_households=batch_households, + baseline_cache_dir=baseline_cache_dir, + skip_tax_expenditure_targets=skip_tax_expenditure_targets, + ) + _log(f"scored variants in {perf_counter() - scoring_started_at:.1f}s") + scores_by_dataset = { + str(Path(score["broad_loss"]["candidate_dataset"]).resolve()): score + for score in scores + } + + rows: list[dict[str, Any]] = [] + for variant in variants: + dataset_key = str(Path(variant.dataset_path).resolve()) + score = scores_by_dataset[dataset_key] + broad_loss = score["broad_loss"] + rows.append( + { + **variant.to_dict(), + "score_summary": score["summary"], + "broad_loss": broad_loss, + "family_breakdown": score.get("family_breakdown", []), + "weight_diagnostics": compute_household_weight_diagnostics( + variant.dataset_path, + period=period, + reference_dataset_path=input_path, + ), + } + ) + + ranked_rows = sorted( + rows, + key=lambda row: row["score_summary"]["candidate_enhanced_cps_native_loss"], + ) + baseline_loss = ( + float(rows[0]["score_summary"]["baseline_enhanced_cps_native_loss"]) + if rows + else None + ) + payload: dict[str, Any] = { + "schema_version": 1, + "metric": "pe_native_calibration_strategy_benchmark", + "period": int(period), + "input_dataset": str(input_path), + "baseline_dataset": str(baseline_path), + "output_dir": str(destination), + "skip_tax_expenditure_targets": bool(skip_tax_expenditure_targets), + 
"target_total_weight": resolved_target_total_weight, + "target_total_weight_resolved_from": target_total_weight_resolved_from, + "budget": None if budget is None else int(budget), + "max_iter": int(max_iter), + "tol": float(tol), + "l2_penalties": [float(value) for value in l2_penalties], + "baseline_enhanced_cps_native_loss": baseline_loss, + "best_variant_label": ranked_rows[0]["label"] if ranked_rows else None, + "best_variant_loss": ( + float( + ranked_rows[0]["score_summary"][ + "candidate_enhanced_cps_native_loss" + ] + ) + if ranked_rows + else None + ), + "variant_count": len(rows), + "rows": rows, + "ranking": [ + { + "label": row["label"], + "method": row["method"], + "candidate_enhanced_cps_native_loss": row["score_summary"][ + "candidate_enhanced_cps_native_loss" + ], + "enhanced_cps_native_loss_delta": row["score_summary"][ + "enhanced_cps_native_loss_delta" + ], + "effective_sample_size_share": row["weight_diagnostics"][ + "effective_sample_size_share" + ], + "l1_delta_as_share_of_reference_sum": row["weight_diagnostics"].get( + "l1_delta_as_share_of_reference_sum" + ), + } + for row in ranked_rows + ], + "elapsed_seconds": perf_counter() - started_at, + } + return payload + + +def write_policyengine_us_native_calibration_benchmark( + output_path: str | Path, + **kwargs: Any, +) -> Path: + """Build a PE-native calibration benchmark and write it as JSON.""" + + payload = build_policyengine_us_native_calibration_benchmark(**kwargs) + destination = Path(output_path).expanduser().resolve() + destination.parent.mkdir(parents=True, exist_ok=True) + destination.write_text(json.dumps(payload, indent=2, sort_keys=True)) + return destination + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser( + description=( + "Benchmark input, existing, unconstrained, and penalized PE-native " + "calibration variants on the same PE-native broad target surface." 
+ ) + ) + parser.add_argument("--input-dataset", required=True) + parser.add_argument("--baseline-dataset", required=True) + parser.add_argument("--output-dir", required=True) + parser.add_argument( + "--output-path", + help=( + "Benchmark JSON path. Defaults to " + "/pe_native_calibration_benchmark.json." + ), + ) + parser.add_argument("--period", type=int, default=2024) + parser.add_argument( + "--l2-penalties", + default="0,1e-12,1e-10,1e-8", + help=( + "Comma-separated PE-native optimization penalties. " + "Use an empty string to score only existing datasets." + ), + ) + parser.add_argument("--max-iter", type=int, default=200) + parser.add_argument("--tol", type=float, default=1e-8) + parser.add_argument("--budget", type=int) + parser.add_argument("--target-total-weight", type=float) + parser.add_argument( + "--target-total-weight-source", + choices=("preserve-input", "input", "baseline"), + default="preserve-input", + ) + parser.add_argument( + "--existing-candidate", + action="append", + help="Add a precomputed variant as label=/path/to/candidate.h5.", + ) + parser.add_argument("--policyengine-us-data-python") + parser.add_argument("--policyengine-us-data-repo") + parser.add_argument("--batch-households", type=int) + parser.add_argument( + "--baseline-cache-dir", + default=str(_DEFAULT_PE_NATIVE_BASELINE_CACHE_DIR), + help="Pass an empty string to disable PE-native baseline estimate caching.", + ) + parser.add_argument( + "--skip-tax-expenditure-targets", + action="store_true", + ) + parser.add_argument( + "--force", + action="store_true", + help="Regenerate optimized H5 variants even if outputs already exist.", + ) + args = parser.parse_args(argv) + + output_dir = Path(args.output_dir).expanduser() + output_path = ( + Path(args.output_path).expanduser() + if args.output_path + else output_dir / "pe_native_calibration_benchmark.json" + ) + written = write_policyengine_us_native_calibration_benchmark( + output_path, + input_dataset_path=args.input_dataset, + 
baseline_dataset_path=args.baseline_dataset, + output_dir=output_dir, + period=args.period, + l2_penalties=_parse_float_list(args.l2_penalties), + max_iter=args.max_iter, + tol=args.tol, + budget=args.budget, + target_total_weight=args.target_total_weight, + target_total_weight_source=args.target_total_weight_source, + existing_candidates=_parse_existing_candidates(args.existing_candidate), + policyengine_us_data_repo=args.policyengine_us_data_repo, + policyengine_us_data_python=args.policyengine_us_data_python, + batch_households=args.batch_households, + baseline_cache_dir=args.baseline_cache_dir or None, + skip_tax_expenditure_targets=args.skip_tax_expenditure_targets, + force=args.force, + ) + print(str(written)) + return 0 + + +__all__ = [ + "CalibrationBenchmarkVariant", + "build_policyengine_us_native_calibration_benchmark", + "compute_household_weight_diagnostics", + "write_policyengine_us_native_calibration_benchmark", +] diff --git a/src/microplex_us/pipelines/pe_us_dataset_readiness.py b/src/microplex_us/pipelines/pe_us_dataset_readiness.py new file mode 100644 index 0000000..6f87445 --- /dev/null +++ b/src/microplex_us/pipelines/pe_us_dataset_readiness.py @@ -0,0 +1,502 @@ +"""Lightweight readiness audit for exported PolicyEngine-US datasets.""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Any + +import h5py +import numpy as np + +DEFAULT_PERIOD = 2024 +DEFAULT_REQUIRED_VARIABLES: dict[str, str] = { + "household_id": "household", + "household_weight": "household", + "person_id": "person", + "person_household_id": "person", + "tax_unit_id": "tax_unit", + "person_tax_unit_id": "person", + "spm_unit_id": "spm_unit", + "person_spm_unit_id": "person", + "state_fips": "household", + "county_fips": "household", + "congressional_district_geoid": "household", + "spm_unit_spm_threshold": "spm_unit", + "spm_unit_tenure_type": "spm_unit", +} +DEFAULT_EXPECTED_MATERIALIZED_VARIABLES = ( + 
"income_tax", + "income_tax_positive", + "eitc", + "ctc", + "refundable_ctc", + "non_refundable_ctc", + "snap", + "ssi", + "tanf", + "medicaid", + "aca_ptc", +) +DEFAULT_EXPECTED_SPINES = ("cps_asec", "acs_pums") + + +def build_policyengine_us_dataset_readiness_audit( + path: str | Path, + *, + period: int | str = DEFAULT_PERIOD, + expected_materialized_variables: tuple[str, ...] = DEFAULT_EXPECTED_MATERIALIZED_VARIABLES, + required_variables: dict[str, str] | None = None, + expected_spines: tuple[str, ...] = DEFAULT_EXPECTED_SPINES, + minimum_nonmissing_share: float = 0.999, +) -> dict[str, Any]: + """Inspect a saved artifact bundle or ``policyengine_us.h5`` export. + + This audit intentionally avoids running PolicyEngine. It only checks that the + exported H5 has the structural, geography, SPM-threshold, and materialized + policy-output arrays we expect before expensive native scoring starts. + """ + + input_path = Path(path).expanduser() + artifact_dir = input_path if input_path.is_dir() else None + dataset_path = _resolve_dataset_path(input_path) + manifest = _load_optional_json(artifact_dir / "manifest.json") if artifact_dir else None + source_spine_composition = ( + _load_optional_json(artifact_dir / "source_spine_composition.json") + if artifact_dir + else None + ) + period_key = str(period) + required = dict(required_variables or DEFAULT_REQUIRED_VARIABLES) + expected_variables = tuple(dict.fromkeys(expected_materialized_variables)) + issues: list[dict[str, Any]] = [] + + with h5py.File(dataset_path, "r") as handle: + entity_counts = _entity_counts(handle, period_key) + variable_summaries = { + variable: _variable_summary( + handle, + variable, + period_key=period_key, + entity_counts=entity_counts, + preferred_entity=required.get(variable), + ) + for variable in sorted(set(required) | set(expected_variables)) + } + + for variable, expected_entity in required.items(): + summary = variable_summaries[variable] + _append_variable_presence_issues( + issues, + 
variable=variable, + summary=summary, + expected_entity=expected_entity, + minimum_nonmissing_share=minimum_nonmissing_share, + required=True, + ) + for variable in expected_variables: + summary = variable_summaries[variable] + _append_variable_presence_issues( + issues, + variable=variable, + summary=summary, + expected_entity=None, + minimum_nonmissing_share=0.0, + required=True, + ) + + _append_source_spine_issues( + issues, + source_spine_composition=source_spine_composition, + expected_spines=expected_spines, + ) + valid = not any(issue["severity"] == "error" for issue in issues) + return { + "schemaVersion": 1, + "valid": valid, + "inputPath": str(input_path), + "artifactDir": str(artifact_dir) if artifact_dir is not None else None, + "datasetPath": str(dataset_path), + "period": int(period) if str(period).isdigit() else str(period), + "entityCounts": entity_counts, + "requiredVariables": required, + "expectedMaterializedVariables": list(expected_variables), + "variableSummaries": variable_summaries, + "sourceSpineComposition": _source_spine_summary(source_spine_composition), + "manifestSummary": _manifest_summary(manifest), + "issues": issues, + } + + +def write_policyengine_us_dataset_readiness_audit( + path: str | Path, + output_path: str | Path | None = None, + **kwargs: Any, +) -> Path: + """Write a readiness audit JSON sidecar.""" + + input_path = Path(path).expanduser() + destination = ( + Path(output_path) + if output_path is not None + else ( + input_path / "policyengine_dataset_readiness.json" + if input_path.is_dir() + else input_path.with_name(f"{input_path.stem}_readiness.json") + ) + ) + payload = build_policyengine_us_dataset_readiness_audit(input_path, **kwargs) + destination.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n") + return destination + + +def _resolve_dataset_path(path: Path) -> Path: + if path.is_file(): + return path.resolve() + if not path.is_dir(): + raise FileNotFoundError(f"Dataset or artifact directory not 
found: {path}") + manifest_path = path / "manifest.json" + if manifest_path.exists(): + manifest = json.loads(manifest_path.read_text()) + dataset_name = dict(manifest.get("artifacts", {})).get("policyengine_dataset") + if isinstance(dataset_name, str) and dataset_name: + dataset_path = path / dataset_name + if dataset_path.exists(): + return dataset_path.resolve() + dataset_path = path / "policyengine_us.h5" + if dataset_path.exists(): + return dataset_path.resolve() + raise FileNotFoundError(f"No policyengine_us.h5 export found under {path}") + + +def _load_optional_json(path: Path) -> dict[str, Any] | None: + if not path.exists(): + return None + return json.loads(path.read_text()) + + +def _entity_counts(handle: h5py.File, period_key: str) -> dict[str, int | None]: + variables = { + "household": "household_id", + "person": "person_id", + "tax_unit": "tax_unit_id", + "spm_unit": "spm_unit_id", + } + return { + entity: _dataset_length(handle, variable, period_key) + for entity, variable in variables.items() + } + + +def _dataset_length( + handle: h5py.File, + variable: str, + period_key: str, +) -> int | None: + if variable not in handle or period_key not in handle[variable]: + return None + return int(len(handle[variable][period_key])) + + +def _variable_summary( + handle: h5py.File, + variable: str, + *, + period_key: str, + entity_counts: dict[str, int | None], + preferred_entity: str | None = None, +) -> dict[str, Any]: + if variable not in handle: + return {"exists": False, "hasPeriod": False} + group = handle[variable] + if period_key not in group: + return { + "exists": True, + "hasPeriod": False, + "availablePeriods": sorted(str(key) for key in group.keys()), + } + values = np.asarray(group[period_key]) + length = int(values.shape[0]) if values.shape else 1 + profile = _array_profile(values) + return { + "exists": True, + "hasPeriod": True, + "length": length, + "dtype": str(values.dtype), + "entity": _infer_entity( + length, + entity_counts, + 
preferred_entity=preferred_entity, + ), + **profile, + } + + +def _array_profile(values: np.ndarray) -> dict[str, Any]: + flat = np.ravel(values) + if flat.dtype.kind in {"b", "i", "u", "f"}: + numeric = flat.astype(float, copy=False) + finite = np.isfinite(numeric) + positive = finite & (numeric > 0.0) + nonzero = finite & (numeric != 0.0) + return { + "finiteCount": int(finite.sum()), + "finiteShare": _share(finite.sum(), len(flat)), + "nonmissingCount": int(finite.sum()), + "nonmissingShare": _share(finite.sum(), len(flat)), + "positiveCount": int(positive.sum()), + "positiveShare": _share(positive.sum(), len(flat)), + "nonzeroCount": int(nonzero.sum()), + "nonzeroShare": _share(nonzero.sum(), len(flat)), + } + decoded = _decode_string_array(flat) + nonmissing = np.array( + [bool(value.strip()) and value.strip().lower() != "nan" for value in decoded], + dtype=bool, + ) + return { + "nonmissingCount": int(nonmissing.sum()), + "nonmissingShare": _share(nonmissing.sum(), len(flat)), + } + + +def _decode_string_array(values: np.ndarray) -> list[str]: + result: list[str] = [] + for value in values.tolist(): + if isinstance(value, bytes): + result.append(value.decode("utf-8", errors="replace")) + else: + result.append(str(value)) + return result + + +def _share(numerator: int | np.integer, denominator: int) -> float | None: + if denominator == 0: + return None + return float(numerator) / float(denominator) + + +def _infer_entity( + length: int, + entity_counts: dict[str, int | None], + *, + preferred_entity: str | None = None, +) -> str | None: + matches = [ + entity + for entity, count in entity_counts.items() + if count is not None and int(count) == length + ] + if preferred_entity is not None and preferred_entity in matches: + return preferred_entity + if len(matches) == 1: + return matches[0] + if len(matches) > 1: + return "|".join(matches) + return None + + +def _append_variable_presence_issues( + issues: list[dict[str, Any]], + *, + variable: str, + summary: 
dict[str, Any], + expected_entity: str | None, + minimum_nonmissing_share: float, + required: bool, +) -> None: + severity = "error" if required else "warning" + if not summary.get("exists"): + issues.append( + { + "severity": severity, + "code": "missing_variable", + "variable": variable, + "message": f"Dataset is missing {variable!r}", + } + ) + return + if not summary.get("hasPeriod"): + issues.append( + { + "severity": severity, + "code": "missing_period", + "variable": variable, + "message": f"Dataset variable {variable!r} is missing the requested period", + } + ) + return + entity = summary.get("entity") + if expected_entity is not None and entity != expected_entity: + issues.append( + { + "severity": "error", + "code": "entity_length_mismatch", + "variable": variable, + "expectedEntity": expected_entity, + "observedEntity": entity, + "message": ( + f"Variable {variable!r} has length matching {entity!r}, " + f"expected {expected_entity!r}" + ), + } + ) + nonmissing_share = summary.get("nonmissingShare") + if ( + isinstance(nonmissing_share, int | float) + and nonmissing_share < minimum_nonmissing_share + ): + issues.append( + { + "severity": "error" if required else "warning", + "code": "low_nonmissing_share", + "variable": variable, + "nonmissingShare": float(nonmissing_share), + "minimumNonmissingShare": minimum_nonmissing_share, + "message": ( + f"Variable {variable!r} nonmissing share " + f"{nonmissing_share:.4f} is below {minimum_nonmissing_share:.4f}" + ), + } + ) + if variable == "spm_unit_spm_threshold": + positive_share = summary.get("positiveShare") + if isinstance(positive_share, int | float) and positive_share < 0.999: + issues.append( + { + "severity": "error", + "code": "low_positive_spm_threshold_share", + "variable": variable, + "positiveShare": float(positive_share), + "message": "SPM thresholds should be positive for nearly all SPM units", + } + ) + + +def _append_source_spine_issues( + issues: list[dict[str, Any]], + *, + 
source_spine_composition: dict[str, Any] | None, + expected_spines: tuple[str, ...], +) -> None: + if not expected_spines: + return + if source_spine_composition is None: + issues.append( + { + "severity": "warning", + "code": "missing_source_spine_composition", + "message": "Artifact has no source_spine_composition.json sidecar", + } + ) + return + observed = { + str(group.get("spine")) + for group in source_spine_composition.get("groups", ()) + if group.get("spine") is not None + } + missing = sorted(set(expected_spines) - observed) + if missing: + issues.append( + { + "severity": "error", + "code": "missing_expected_spines", + "missingSpines": missing, + "observedSpines": sorted(observed), + "message": "Source-spine composition is missing expected spines", + } + ) + + +def _source_spine_summary(payload: dict[str, Any] | None) -> dict[str, Any] | None: + if payload is None: + return None + return { + "householdCount": payload.get("household_count"), + "nonzeroHouseholdCount": payload.get("nonzero_household_count"), + "totalActiveWeight": payload.get("total_active_weight"), + "effectiveSampleSize": payload.get("effective_sample_size"), + "groups": [ + { + "spine": group.get("spine"), + "householdCount": group.get("household_count"), + "nonzeroHouseholdCount": group.get("nonzero_household_count"), + "totalActiveWeight": group.get("total_active_weight"), + "totalSourceWeight": group.get("total_source_weight"), + } + for group in payload.get("groups", ()) + ], + } + + +def _manifest_summary(payload: dict[str, Any] | None) -> dict[str, Any] | None: + if payload is None: + return None + artifacts = dict(payload.get("artifacts", {})) + return { + "rows": payload.get("rows"), + "weights": payload.get("weights"), + "policyengineDataset": artifacts.get("policyengine_dataset"), + "policyengineNativeScores": artifacts.get("policyengine_native_scores"), + "sourceSpineComposition": artifacts.get("source_spine_composition"), + } + + +def main(argv: list[str] | None = None) -> 
int: + """CLI entry point for dataset readiness audits.""" + + parser = argparse.ArgumentParser( + description="Audit a saved MicroPlex PE-US H5 export before native scoring.", + ) + parser.add_argument("path", help="Artifact directory or policyengine_us.h5 path") + parser.add_argument("--period", default=DEFAULT_PERIOD) + parser.add_argument("--output-path") + parser.add_argument( + "--expected-materialized-variable", + action="append", + default=None, + help="Calculated PolicyEngine variable expected in the H5. Repeatable.", + ) + parser.add_argument( + "--expected-spine", + action="append", + default=None, + help="Source spine expected in source_spine_composition.json. Repeatable.", + ) + args = parser.parse_args(argv) + + output = write_policyengine_us_dataset_readiness_audit( + args.path, + output_path=args.output_path, + period=args.period, + expected_materialized_variables=tuple( + args.expected_materialized_variable + if args.expected_materialized_variable is not None + else DEFAULT_EXPECTED_MATERIALIZED_VARIABLES + ), + expected_spines=tuple( + args.expected_spine + if args.expected_spine is not None + else DEFAULT_EXPECTED_SPINES + ), + ) + payload = json.loads(output.read_text()) + print( + json.dumps( + { + "output": str(output), + "valid": payload["valid"], + "datasetPath": payload["datasetPath"], + "entityCounts": payload["entityCounts"], + "issueCount": len(payload["issues"]), + }, + indent=2, + sort_keys=True, + ) + ) + return 0 if payload["valid"] else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/microplex_us/pipelines/r2_artifacts.py b/src/microplex_us/pipelines/r2_artifacts.py new file mode 100644 index 0000000..ebd83bb --- /dev/null +++ b/src/microplex_us/pipelines/r2_artifacts.py @@ -0,0 +1,420 @@ +"""Archive Microplex artifact directories to Cloudflare R2.""" + +from __future__ import annotations + +import argparse +import hashlib +import json +import os +import sys +from dataclasses import dataclass +from 
datetime import UTC, datetime +from pathlib import Path +from typing import Any + +R2_ARCHIVE_MANIFEST_FILENAME = "r2_archive_manifest.json" +R2_ARCHIVE_INDEX_FILENAME = "r2_archive_index.jsonl" +DEFAULT_R2_PREFIX = "microplex-us/artifacts" +DEFAULT_REGION = "auto" +SUMMARY_FILENAMES = frozenset( + { + "manifest.json", + "summary.md", + "scores.json", + "fit_summary.json", + "target_deltas_top50.json", + "matrix_residual_drilldown_top100.json", + "calibration_summary.json", + "source_spine_composition.json", + "support_audit.json", + "run_manifest.json", + } +) + + +@dataclass(frozen=True) +class R2ArchiveConfig: + """R2 destination and credentials for one archive operation.""" + + bucket: str + endpoint_url: str + prefix: str = DEFAULT_R2_PREFIX + region: str = DEFAULT_REGION + access_key_id: str | None = None + secret_access_key: str | None = None + session_token: str | None = None + + @classmethod + def from_env(cls) -> R2ArchiveConfig: + """Build config from Microplex-specific env vars with R2/AWS fallbacks.""" + bucket = _first_env("MICROPLEX_R2_BUCKET", "R2_BUCKET", "AWS_BUCKET") + if not bucket: + raise ValueError( + "Missing R2 bucket. Set MICROPLEX_R2_BUCKET or pass --bucket." + ) + endpoint_url = _first_env("MICROPLEX_R2_ENDPOINT_URL", "R2_ENDPOINT_URL") + account_id = _first_env("MICROPLEX_R2_ACCOUNT_ID", "CLOUDFLARE_ACCOUNT_ID") + if not endpoint_url: + if not account_id: + raise ValueError( + "Missing R2 endpoint. Set MICROPLEX_R2_ENDPOINT_URL, " + "R2_ENDPOINT_URL, or CLOUDFLARE_ACCOUNT_ID." 
+ ) + endpoint_url = f"https://{account_id}.r2.cloudflarestorage.com" + return cls( + bucket=bucket, + endpoint_url=endpoint_url, + prefix=( + _first_env("MICROPLEX_R2_PREFIX", "R2_PREFIX") + or DEFAULT_R2_PREFIX + ), + region=_first_env("MICROPLEX_R2_REGION", "AWS_DEFAULT_REGION") + or DEFAULT_REGION, + access_key_id=_first_env( + "MICROPLEX_R2_ACCESS_KEY_ID", + "R2_ACCESS_KEY_ID", + "AWS_ACCESS_KEY_ID", + ), + secret_access_key=_first_env( + "MICROPLEX_R2_SECRET_ACCESS_KEY", + "R2_SECRET_ACCESS_KEY", + "AWS_SECRET_ACCESS_KEY", + ), + session_token=_first_env( + "MICROPLEX_R2_SESSION_TOKEN", + "R2_SESSION_TOKEN", + "AWS_SESSION_TOKEN", + ), + ) + + +def _first_env(*names: str) -> str | None: + for name in names: + value = os.environ.get(name) + if value: + return value + return None + + +def normalize_r2_prefix(value: str) -> str: + """Normalize an R2 key prefix without changing internal separators.""" + return value.strip("/") + + +def build_r2_object_key(prefix: str, artifact_id: str, relative_path: str) -> str: + """Return a stable R2 object key for one artifact file.""" + parts = [ + normalize_r2_prefix(prefix), + artifact_id.strip("/"), + relative_path.replace(os.sep, "/").strip("/"), + ] + return "/".join(part for part in parts if part) + + +def iter_artifact_files(artifact_dir: str | Path) -> list[Path]: + """List regular artifact files, excluding the local R2 archive sidecar.""" + root = Path(artifact_dir) + return sorted( + path + for path in root.rglob("*") + if path.is_file() and path.name != R2_ARCHIVE_MANIFEST_FILENAME + ) + + +def file_sha256(path: str | Path, *, chunk_size: int = 1024 * 1024) -> str: + """Compute a SHA-256 digest for a local file.""" + digest = hashlib.sha256() + with Path(path).open("rb") as file: + for chunk in iter(lambda: file.read(chunk_size), b""): + digest.update(chunk) + return digest.hexdigest() + + +def build_archive_manifest( + artifact_dir: str | Path, + config: R2ArchiveConfig, + *, + artifact_id: str | None = None, + 
hash_files: bool = True, + status: str = "planned", +) -> dict[str, Any]: + """Build the local manifest describing files and destination object keys.""" + root = Path(artifact_dir).resolve() + if not root.is_dir(): + raise NotADirectoryError(f"Artifact directory not found: {root}") + resolved_artifact_id = artifact_id or root.name + files: list[dict[str, Any]] = [] + total_bytes = 0 + for path in iter_artifact_files(root): + relative_path = path.relative_to(root).as_posix() + size_bytes = path.stat().st_size + total_bytes += size_bytes + entry: dict[str, Any] = { + "path": relative_path, + "size_bytes": size_bytes, + "object_key": build_r2_object_key( + config.prefix, + resolved_artifact_id, + relative_path, + ), + "status": status, + "summary": path.name in SUMMARY_FILENAMES, + } + if hash_files: + entry["sha256"] = file_sha256(path) + files.append(entry) + return { + "schema_version": 1, + "created_at": datetime.now(UTC).isoformat(), + "artifact_id": resolved_artifact_id, + "artifact_dir": str(root), + "r2": { + "bucket": config.bucket, + "endpoint_url": config.endpoint_url, + "prefix": normalize_r2_prefix(config.prefix), + "region": config.region, + "manifest_object_key": build_r2_object_key( + config.prefix, + resolved_artifact_id, + R2_ARCHIVE_MANIFEST_FILENAME, + ), + }, + "summary_files": [ + entry["path"] for entry in files if bool(entry.get("summary")) + ], + "file_count": len(files), + "total_bytes": total_bytes, + "files": files, + } + + +def create_r2_s3_client(config: R2ArchiveConfig) -> Any: + """Create a boto3 S3 client configured for Cloudflare R2.""" + try: + import boto3 + except ImportError as error: # pragma: no cover - exercised by CLI environment. + raise RuntimeError( + "boto3 is required for R2 uploads. Install the optional extra with " + "`uv sync --extra r2` or run through `uv run --extra r2 ...`." 
+ ) from error + client_kwargs: dict[str, Any] = { + "service_name": "s3", + "endpoint_url": config.endpoint_url, + "region_name": config.region, + } + if config.access_key_id is not None: + client_kwargs["aws_access_key_id"] = config.access_key_id + if config.secret_access_key is not None: + client_kwargs["aws_secret_access_key"] = config.secret_access_key + if config.session_token is not None: + client_kwargs["aws_session_token"] = config.session_token + return boto3.client(**client_kwargs) + + +def upload_artifact_manifest_to_r2( + artifact_dir: str | Path, + config: R2ArchiveConfig, + *, + artifact_id: str | None = None, + client: Any | None = None, + dry_run: bool = False, + force: bool = False, + hash_files: bool = True, +) -> dict[str, Any]: + """Upload an artifact directory to R2 and write a local upload manifest.""" + root = Path(artifact_dir).resolve() + manifest = build_archive_manifest( + root, + config, + artifact_id=artifact_id, + hash_files=hash_files, + status="dry_run" if dry_run else "pending", + ) + local_manifest_path = root / R2_ARCHIVE_MANIFEST_FILENAME + if dry_run: + _write_json(local_manifest_path, manifest) + return manifest + s3 = client or create_r2_s3_client(config) + for entry in manifest["files"]: + path = root / entry["path"] + object_key = entry["object_key"] + if not force and _object_exists(s3, config.bucket, object_key): + entry["status"] = "already_exists" + continue + s3.upload_file(str(path), config.bucket, object_key) + entry["status"] = "uploaded" + entry["uploaded_at"] = datetime.now(UTC).isoformat() + manifest["completed_at"] = datetime.now(UTC).isoformat() + manifest["status"] = "uploaded" + _write_json(local_manifest_path, manifest) + manifest_key = manifest["r2"]["manifest_object_key"] + s3.upload_file(str(local_manifest_path), config.bucket, manifest_key) + return manifest + + +def append_archive_index_entry( + index_path: str | Path, + manifest: dict[str, Any], + *, + pruned_local: bool = False, +) -> Path: + 
"""Append a compact archive record to a local JSONL index.""" + path = Path(index_path) + path.parent.mkdir(parents=True, exist_ok=True) + entry = { + "recorded_at": datetime.now(UTC).isoformat(), + "artifact_id": manifest["artifact_id"], + "artifact_dir": manifest["artifact_dir"], + "bucket": manifest["r2"]["bucket"], + "prefix": manifest["r2"]["prefix"], + "manifest_object_key": manifest["r2"]["manifest_object_key"], + "file_count": manifest["file_count"], + "total_bytes": manifest["total_bytes"], + "status": manifest.get("status"), + "pruned_local": pruned_local, + } + with path.open("a") as file: + file.write(json.dumps(entry, sort_keys=True) + "\n") + return path + + +def _object_exists(client: Any, bucket: str, key: str) -> bool: + try: + client.head_object(Bucket=bucket, Key=key) + except Exception as error: # noqa: BLE001 - boto3 exposes provider-specific errors. + response = getattr(error, "response", None) + code = None + if isinstance(response, dict): + code = str(response.get("Error", {}).get("Code", "")) + if code in {"404", "NoSuchKey", "NotFound"}: + return False + # Some fakes and S3-compatible clients use a generic missing-object error. 
def _build_config_from_args(args: argparse.Namespace) -> R2ArchiveConfig:
    """Merge CLI arguments with environment fallbacks into an ``R2ArchiveConfig``.

    Every field is resolved independently: the CLI value wins, then the
    environment variables also read by ``R2ArchiveConfig.from_env``, then the
    module default. A CLI ``--bucket`` can therefore be combined with an
    endpoint taken from the environment (or the reverse), and environment
    credentials, prefix and region still apply when both ``--bucket`` and
    ``--endpoint-url`` are given on the command line.

    Raises:
        ValueError: If no bucket or no endpoint can be resolved.
    """
    # NOTE: keep these variable names in sync with R2ArchiveConfig.from_env.
    bucket = args.bucket or _first_env("MICROPLEX_R2_BUCKET", "R2_BUCKET", "AWS_BUCKET")
    if not bucket:
        raise ValueError(
            "Missing R2 bucket. Set MICROPLEX_R2_BUCKET or pass --bucket."
        )

    endpoint_url = args.endpoint_url or _first_env(
        "MICROPLEX_R2_ENDPOINT_URL", "R2_ENDPOINT_URL"
    )
    if not endpoint_url:
        # Without an explicit endpoint, derive it from the Cloudflare account id.
        account_id = _first_env("MICROPLEX_R2_ACCOUNT_ID", "CLOUDFLARE_ACCOUNT_ID")
        if not account_id:
            raise ValueError(
                "Missing R2 endpoint. Set MICROPLEX_R2_ENDPOINT_URL, "
                "R2_ENDPOINT_URL, or CLOUDFLARE_ACCOUNT_ID, or pass --endpoint-url."
            )
        endpoint_url = f"https://{account_id}.r2.cloudflarestorage.com"

    return R2ArchiveConfig(
        bucket=bucket,
        endpoint_url=endpoint_url,
        prefix=(
            args.prefix
            or _first_env("MICROPLEX_R2_PREFIX", "R2_PREFIX")
            or DEFAULT_R2_PREFIX
        ),
        region=(
            args.region
            or _first_env("MICROPLEX_R2_REGION", "AWS_DEFAULT_REGION")
            or DEFAULT_REGION
        ),
        access_key_id=args.access_key_id
        or _first_env(
            "MICROPLEX_R2_ACCESS_KEY_ID",
            "R2_ACCESS_KEY_ID",
            "AWS_ACCESS_KEY_ID",
        ),
        secret_access_key=args.secret_access_key
        or _first_env(
            "MICROPLEX_R2_SECRET_ACCESS_KEY",
            "R2_SECRET_ACCESS_KEY",
            "AWS_SECRET_ACCESS_KEY",
        ),
        session_token=args.session_token
        or _first_env(
            "MICROPLEX_R2_SESSION_TOKEN",
            "R2_SESSION_TOKEN",
            "AWS_SESSION_TOKEN",
        ),
    )
+ ) + parser.add_argument("artifact_dir", type=Path) + parser.add_argument("--artifact-id", default=None) + parser.add_argument("--bucket", default=None) + parser.add_argument("--endpoint-url", default=None) + parser.add_argument("--prefix", default=None) + parser.add_argument("--region", default=None) + parser.add_argument("--access-key-id", default=None) + parser.add_argument("--secret-access-key", default=None) + parser.add_argument("--session-token", default=None) + parser.add_argument( + "--dry-run", + action="store_true", + help="Write the local archive manifest without uploading to R2.", + ) + parser.add_argument( + "--force", + action="store_true", + help="Upload files even when an object with the same key already exists.", + ) + parser.add_argument( + "--no-hash", + action="store_true", + help="Skip SHA-256 file hashing when building the archive manifest.", + ) + parser.add_argument( + "--index-path", + type=Path, + default=None, + help=( + "Optional local JSONL archive index. Defaults to " + "/r2_archive_index.jsonl when uploading." + ), + ) + parser.add_argument( + "--mark-pruned-local", + action="store_true", + help="Mark the local archive-index row as pruned after external cleanup.", + ) + args = parser.parse_args(argv) + try: + config = _build_config_from_args(args) + manifest = upload_artifact_manifest_to_r2( + args.artifact_dir, + config, + artifact_id=args.artifact_id, + dry_run=args.dry_run, + force=args.force, + hash_files=not args.no_hash, + ) + except Exception as error: # noqa: BLE001 - CLI should report a concise failure. 
+ print(f"R2 archive failed: {error}", file=sys.stderr) + return 1 + if not args.dry_run: + index_path = args.index_path or args.artifact_dir.parent / R2_ARCHIVE_INDEX_FILENAME + append_archive_index_entry( + index_path, + manifest, + pruned_local=args.mark_pruned_local, + ) + uploaded = sum( + 1 + for entry in manifest["files"] + if entry["status"] in {"uploaded", "already_exists"} + ) + mode = "planned" if args.dry_run else "archived" + print( + f"R2 artifact {mode}: {manifest['artifact_id']} " + f"({uploaded}/{manifest['file_count']} files, " + f"{manifest['total_bytes']} bytes)" + ) + print(args.artifact_dir / R2_ARCHIVE_MANIFEST_FILENAME) + return 0 + + +if __name__ == "__main__": # pragma: no cover + raise SystemExit(main()) diff --git a/src/microplex_us/pipelines/us.py b/src/microplex_us/pipelines/us.py index ec541a2..552a844 100644 --- a/src/microplex_us/pipelines/us.py +++ b/src/microplex_us/pipelines/us.py @@ -83,6 +83,7 @@ from microplex_us.policyengine.us import ( subset_policyengine_tables_by_households as _subset_policyengine_tables_by_households, ) +from microplex_us.targets.arch import resolve_arch_sqlite_target_provider from microplex_us.variables import ( PE_STYLE_PUF_IRS_DEMOGRAPHIC_PREDICTORS, DonorMatchStrategy, @@ -1407,6 +1408,8 @@ class USMicroplexBuildConfig: policyengine_prefer_existing_tax_unit_ids: bool = False policyengine_quantity_targets: tuple[PolicyEngineUSQuantityTarget, ...] = () policyengine_targets_db: str | None = None + arch_targets_db: str | tuple[str, ...] | None = None + calibration_target_source: Literal["policyengine", "arch"] = "policyengine" policyengine_target_period: int | None = None policyengine_target_variables: tuple[str, ...] = () policyengine_target_domains: tuple[str, ...] 
= () @@ -1786,7 +1789,7 @@ def build_from_frames( rows=int(len(synthetic_data)), columns=int(len(synthetic_data.columns)), ) - if self.config.policyengine_targets_db is not None: + if self._has_policyengine_calibration_targets(): _emit_us_pipeline_progress( "US microplex build: policyengine tables start", rows=int(len(synthetic_data)), @@ -2816,10 +2819,7 @@ def calibrate_policyengine_tables( tables: PolicyEngineUSEntityTableBundle, ) -> tuple[PolicyEngineUSEntityTableBundle, pd.DataFrame, dict[str, Any]]: """Calibrate household weights using PolicyEngine US target DB constraints.""" - if self.config.policyengine_targets_db is None: - raise ValueError("policyengine_targets_db is required for DB calibration") - - provider = PolicyEngineUSDBTargetProvider(self.config.policyengine_targets_db) + provider, _source = self._resolve_calibration_target_provider() target_period = ( self.config.policyengine_target_period or self.config.policyengine_dataset_year @@ -3629,9 +3629,33 @@ def _resolve_policyengine_calibration_targets( materialization_failures, ) + def _has_policyengine_calibration_targets(self) -> bool: + if self.config.calibration_target_source == "arch": + return self.config.arch_targets_db is not None + return self.config.policyengine_targets_db is not None + + def _resolve_calibration_target_provider(self): + if self.config.calibration_target_source == "arch": + if self.config.arch_targets_db is None: + raise ValueError( + "arch_targets_db is required when calibration_target_source='arch'" + ) + return ( + resolve_arch_sqlite_target_provider(self.config.arch_targets_db), + "arch", + ) + if self.config.policyengine_targets_db is None: + raise ValueError( + "policyengine_targets_db is required for PolicyEngine DB calibration" + ) + return ( + PolicyEngineUSDBTargetProvider(self.config.policyengine_targets_db), + "policyengine", + ) + def _load_policyengine_target_set( self, - provider: PolicyEngineUSDBTargetProvider, + provider: Any, *, bindings: dict[str, 
PolicyEngineUSVariableBinding], period: int, diff --git a/src/microplex_us/policyengine/target_profiles.py b/src/microplex_us/policyengine/target_profiles.py index 0b54941..2fd9b8c 100644 --- a/src/microplex_us/policyengine/target_profiles.py +++ b/src/microplex_us/policyengine/target_profiles.py @@ -23,17 +23,57 @@ def to_provider_filter(self) -> dict[str, str | None]: } +PolicyEngineUSTargetCellKey = tuple[str, str | None, str | None, str | None] + + +def _target_cell_key(cell: PolicyEngineUSTargetCell) -> PolicyEngineUSTargetCellKey: + return ( + cell.variable, + cell.geo_level, + cell.domain_variable, + cell.geographic_id, + ) + + PE_NATIVE_BROAD_TARGET_CELLS: tuple[PolicyEngineUSTargetCell, ...] = ( - PolicyEngineUSTargetCell("aca_ptc", geo_level="national", domain_variable="aca_ptc"), + PolicyEngineUSTargetCell( + "aca_ptc", geo_level="national", domain_variable="aca_ptc" + ), PolicyEngineUSTargetCell("adjusted_gross_income", geo_level="national"), + PolicyEngineUSTargetCell( + "adjusted_gross_income", + geo_level="national", + domain_variable="adjusted_gross_income", + ), + PolicyEngineUSTargetCell( + "adjusted_gross_income", + geo_level="national", + domain_variable="adjusted_gross_income,filing_status,income_tax_before_credits", + ), + PolicyEngineUSTargetCell( + "adjusted_gross_income", + geo_level="national", + domain_variable="adjusted_gross_income,income_tax_before_credits", + ), PolicyEngineUSTargetCell("alimony_expense", geo_level="national"), PolicyEngineUSTargetCell("alimony_income", geo_level="national"), PolicyEngineUSTargetCell("charitable_deduction", geo_level="national"), + PolicyEngineUSTargetCell("childcare_expenses", geo_level="national"), PolicyEngineUSTargetCell("child_support_expense", geo_level="national"), PolicyEngineUSTargetCell("child_support_received", geo_level="national"), - PolicyEngineUSTargetCell("dividend_income", geo_level="national", domain_variable="dividend_income"), + 
PolicyEngineUSTargetCell("deductible_mortgage_interest", geo_level="national"), + PolicyEngineUSTargetCell("dividend_income", geo_level="national"), + PolicyEngineUSTargetCell( + "dividend_income", geo_level="national", domain_variable="dividend_income" + ), + PolicyEngineUSTargetCell("employment_income", geo_level="national"), + PolicyEngineUSTargetCell( + "employment_income", geo_level="national", domain_variable="employment_income" + ), PolicyEngineUSTargetCell("eitc", geo_level="national"), - PolicyEngineUSTargetCell("eitc", geo_level="national", domain_variable="eitc_child_count"), + PolicyEngineUSTargetCell( + "eitc", geo_level="national", domain_variable="eitc_child_count" + ), PolicyEngineUSTargetCell( "eitc", geo_level="national", @@ -43,7 +83,14 @@ def to_provider_filter(self) -> dict[str, str | None]: "health_insurance_premiums_without_medicare_part_b", geo_level="national", ), - PolicyEngineUSTargetCell("income_tax", geo_level="national", domain_variable="income_tax"), + PolicyEngineUSTargetCell( + "household_count", + geo_level="national", + domain_variable="spm_unit_energy_subsidy_reported", + ), + PolicyEngineUSTargetCell( + "income_tax", geo_level="national", domain_variable="income_tax" + ), PolicyEngineUSTargetCell( "income_tax_before_credits", geo_level="national", @@ -58,16 +105,43 @@ def to_provider_filter(self) -> dict[str, str | None]: geo_level="national", domain_variable="medical_expense_deduction", ), + PolicyEngineUSTargetCell( + "medical_expense_deduction", + geo_level="national", + domain_variable="medical_expense_deduction,tax_unit_itemizes", + ), PolicyEngineUSTargetCell("medicare_part_b_premiums", geo_level="national"), - PolicyEngineUSTargetCell("net_capital_gains", geo_level="national", domain_variable="net_capital_gains"), + PolicyEngineUSTargetCell( + "net_capital_gains", geo_level="national", domain_variable="net_capital_gains" + ), PolicyEngineUSTargetCell("net_worth", geo_level="national"), + PolicyEngineUSTargetCell( + 
"non_refundable_ctc", + geo_level="national", + domain_variable="adjusted_gross_income,non_refundable_ctc", + ), + PolicyEngineUSTargetCell( + "non_refundable_ctc", + geo_level="national", + domain_variable="non_refundable_ctc", + ), PolicyEngineUSTargetCell("other_medical_expenses", geo_level="national"), PolicyEngineUSTargetCell("over_the_counter_health_expenses", geo_level="national"), - PolicyEngineUSTargetCell("person_count", geo_level="national", domain_variable="aca_ptc"), - PolicyEngineUSTargetCell("person_count", geo_level="national", domain_variable="age"), - PolicyEngineUSTargetCell("person_count", geo_level="national", domain_variable="medicaid"), - PolicyEngineUSTargetCell("person_count", geo_level="national", domain_variable="ssn_card_type"), - PolicyEngineUSTargetCell("qualified_business_income_deduction", geo_level="national"), + PolicyEngineUSTargetCell( + "person_count", geo_level="national", domain_variable="aca_ptc" + ), + PolicyEngineUSTargetCell( + "person_count", geo_level="national", domain_variable="age" + ), + PolicyEngineUSTargetCell( + "person_count", geo_level="national", domain_variable="medicaid" + ), + PolicyEngineUSTargetCell( + "person_count", geo_level="national", domain_variable="ssn_card_type" + ), + PolicyEngineUSTargetCell( + "qualified_business_income_deduction", geo_level="national" + ), PolicyEngineUSTargetCell( "qualified_business_income_deduction", geo_level="national", @@ -84,12 +158,38 @@ def to_provider_filter(self) -> dict[str, str | None]: geo_level="national", domain_variable="real_estate_taxes", ), - PolicyEngineUSTargetCell("refundable_ctc", geo_level="national", domain_variable="refundable_ctc"), + PolicyEngineUSTargetCell( + "real_estate_taxes", + geo_level="national", + domain_variable="real_estate_taxes,tax_unit_itemizes", + ), + PolicyEngineUSTargetCell( + "refundable_ctc", + geo_level="national", + domain_variable="adjusted_gross_income,refundable_ctc", + ), + PolicyEngineUSTargetCell( + "refundable_ctc", 
geo_level="national", domain_variable="refundable_ctc" + ), PolicyEngineUSTargetCell("rent", geo_level="national"), - PolicyEngineUSTargetCell("rental_income", geo_level="national", domain_variable="rental_income"), + PolicyEngineUSTargetCell("rental_income", geo_level="national"), + PolicyEngineUSTargetCell( + "rental_income", geo_level="national", domain_variable="rental_income" + ), + PolicyEngineUSTargetCell("roth_401k_contributions", geo_level="national"), PolicyEngineUSTargetCell("roth_ira_contributions", geo_level="national"), PolicyEngineUSTargetCell("salt", geo_level="national", domain_variable="salt"), + PolicyEngineUSTargetCell( + "salt", geo_level="national", domain_variable="salt,tax_unit_itemizes" + ), PolicyEngineUSTargetCell("salt_deduction", geo_level="national"), + PolicyEngineUSTargetCell( + "self_employed_pension_contribution_ald", geo_level="national" + ), + PolicyEngineUSTargetCell( + "self_employment_income", + geo_level="national", + ), PolicyEngineUSTargetCell( "self_employment_income", geo_level="national", @@ -102,20 +202,58 @@ def to_provider_filter(self) -> dict[str, str | None]: PolicyEngineUSTargetCell("social_security_retirement", geo_level="national"), PolicyEngineUSTargetCell("social_security_survivors", geo_level="national"), PolicyEngineUSTargetCell("spm_unit_capped_housing_subsidy", geo_level="national"), - PolicyEngineUSTargetCell("spm_unit_capped_work_childcare_expenses", geo_level="national"), + PolicyEngineUSTargetCell( + "spm_unit_capped_work_childcare_expenses", geo_level="national" + ), + PolicyEngineUSTargetCell( + "spm_unit_count", geo_level="national", domain_variable="tanf" + ), PolicyEngineUSTargetCell("ssi", geo_level="national"), PolicyEngineUSTargetCell("tanf", geo_level="national"), + PolicyEngineUSTargetCell("tanf", geo_level="national", domain_variable="tanf"), PolicyEngineUSTargetCell( "tax_exempt_interest_income", geo_level="national", domain_variable="tax_exempt_interest_income", ), - 
PolicyEngineUSTargetCell("tax_unit_count", geo_level="national", domain_variable="aca_ptc"), + PolicyEngineUSTargetCell( + "tax_unit_count", geo_level="national", domain_variable="aca_ptc" + ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable="adjusted_gross_income", + ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable="adjusted_gross_income,filing_status,income_tax_before_credits", + ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable="adjusted_gross_income,income_tax_before_credits", + ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable="adjusted_gross_income,non_refundable_ctc", + ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable="adjusted_gross_income,refundable_ctc", + ), PolicyEngineUSTargetCell( "tax_unit_count", geo_level="national", domain_variable="dividend_income", ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable="employment_income", + ), PolicyEngineUSTargetCell( "tax_unit_count", geo_level="national", @@ -126,7 +264,9 @@ def to_provider_filter(self) -> dict[str, str | None]: geo_level="national", domain_variable="adjusted_gross_income,eitc,eitc_child_count", ), - PolicyEngineUSTargetCell("tax_unit_count", geo_level="national", domain_variable="income_tax"), + PolicyEngineUSTargetCell( + "tax_unit_count", geo_level="national", domain_variable="income_tax" + ), PolicyEngineUSTargetCell( "tax_unit_count", geo_level="national", @@ -137,11 +277,21 @@ def to_provider_filter(self) -> dict[str, str | None]: geo_level="national", domain_variable="medical_expense_deduction", ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable="medical_expense_deduction,tax_unit_itemizes", + ), PolicyEngineUSTargetCell( "tax_unit_count", geo_level="national", 
domain_variable="net_capital_gains", ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable="non_refundable_ctc", + ), PolicyEngineUSTargetCell( "tax_unit_count", geo_level="national", @@ -157,13 +307,27 @@ def to_provider_filter(self) -> dict[str, str | None]: geo_level="national", domain_variable="real_estate_taxes", ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable="real_estate_taxes,tax_unit_itemizes", + ), PolicyEngineUSTargetCell( "tax_unit_count", geo_level="national", domain_variable="refundable_ctc", ), - PolicyEngineUSTargetCell("tax_unit_count", geo_level="national", domain_variable="rental_income"), - PolicyEngineUSTargetCell("tax_unit_count", geo_level="national", domain_variable="salt"), + PolicyEngineUSTargetCell( + "tax_unit_count", geo_level="national", domain_variable="rental_income" + ), + PolicyEngineUSTargetCell( + "tax_unit_count", geo_level="national", domain_variable="salt" + ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable="salt,tax_unit_itemizes", + ), PolicyEngineUSTargetCell( "tax_unit_count", geo_level="national", @@ -199,6 +363,11 @@ def to_provider_filter(self) -> dict[str, str | None]: geo_level="national", domain_variable="taxable_social_security", ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable="total_self_employment_income", + ), PolicyEngineUSTargetCell( "tax_unit_count", geo_level="national", @@ -230,6 +399,12 @@ def to_provider_filter(self) -> dict[str, str | None]: domain_variable="taxable_social_security", ), PolicyEngineUSTargetCell("tip_income", geo_level="national"), + PolicyEngineUSTargetCell( + "total_self_employment_income", + geo_level="national", + domain_variable="total_self_employment_income", + ), + PolicyEngineUSTargetCell("traditional_401k_contributions", geo_level="national"), PolicyEngineUSTargetCell("traditional_ira_contributions", 
geo_level="national"), PolicyEngineUSTargetCell("unemployment_compensation", geo_level="national"), PolicyEngineUSTargetCell( @@ -237,12 +412,30 @@ def to_provider_filter(self) -> dict[str, str | None]: geo_level="national", domain_variable="unemployment_compensation", ), + PolicyEngineUSTargetCell("aca_ptc", geo_level="state", domain_variable=None), PolicyEngineUSTargetCell("aca_ptc", geo_level="state", domain_variable="aca_ptc"), PolicyEngineUSTargetCell("adjusted_gross_income", geo_level="state"), - PolicyEngineUSTargetCell("dividend_income", geo_level="state", domain_variable="dividend_income"), - PolicyEngineUSTargetCell("eitc", geo_level="state", domain_variable="eitc_child_count"), - PolicyEngineUSTargetCell("household_count", geo_level="state", domain_variable="snap"), - PolicyEngineUSTargetCell("income_tax", geo_level="state", domain_variable="income_tax"), + PolicyEngineUSTargetCell( + "adjusted_gross_income", + geo_level="state", + domain_variable="adjusted_gross_income", + ), + PolicyEngineUSTargetCell( + "dividend_income", geo_level="state", domain_variable="dividend_income" + ), + PolicyEngineUSTargetCell("employment_income", geo_level="state"), + PolicyEngineUSTargetCell( + "employment_income", geo_level="state", domain_variable="employment_income" + ), + PolicyEngineUSTargetCell( + "eitc", geo_level="state", domain_variable="eitc_child_count" + ), + PolicyEngineUSTargetCell( + "household_count", geo_level="state", domain_variable="snap" + ), + PolicyEngineUSTargetCell( + "income_tax", geo_level="state", domain_variable="income_tax" + ), PolicyEngineUSTargetCell( "income_tax_before_credits", geo_level="state", @@ -253,11 +446,37 @@ def to_provider_filter(self) -> dict[str, str | None]: geo_level="state", domain_variable="medical_expense_deduction", ), - PolicyEngineUSTargetCell("net_capital_gains", geo_level="state", domain_variable="net_capital_gains"), - PolicyEngineUSTargetCell("person_count", geo_level="state", domain_variable="aca_ptc"), - 
PolicyEngineUSTargetCell("person_count", geo_level="state", domain_variable="adjusted_gross_income"), + PolicyEngineUSTargetCell( + "medical_expense_deduction", + geo_level="state", + domain_variable="medical_expense_deduction,tax_unit_itemizes", + ), + PolicyEngineUSTargetCell( + "net_capital_gains", geo_level="state", domain_variable="net_capital_gains" + ), + PolicyEngineUSTargetCell( + "non_refundable_ctc", + geo_level="state", + domain_variable="non_refundable_ctc", + ), + PolicyEngineUSTargetCell( + "person_count", geo_level="state", domain_variable="aca_ptc" + ), + PolicyEngineUSTargetCell( + "person_count", + geo_level="state", + domain_variable="aca_ptc,is_aca_ptc_eligible", + ), + PolicyEngineUSTargetCell( + "person_count", geo_level="state", domain_variable="adjusted_gross_income" + ), PolicyEngineUSTargetCell("person_count", geo_level="state", domain_variable="age"), - PolicyEngineUSTargetCell("person_count", geo_level="state", domain_variable="medicaid_enrolled"), + PolicyEngineUSTargetCell( + "person_count", geo_level="state", domain_variable="is_pregnant" + ), + PolicyEngineUSTargetCell( + "person_count", geo_level="state", domain_variable="medicaid_enrolled" + ), PolicyEngineUSTargetCell( "qualified_business_income_deduction", geo_level="state", @@ -268,27 +487,62 @@ def to_provider_filter(self) -> dict[str, str | None]: geo_level="state", domain_variable="qualified_dividend_income", ), - PolicyEngineUSTargetCell("real_estate_taxes", geo_level="state", domain_variable="real_estate_taxes"), - PolicyEngineUSTargetCell("refundable_ctc", geo_level="state", domain_variable="refundable_ctc"), - PolicyEngineUSTargetCell("rental_income", geo_level="state", domain_variable="rental_income"), + PolicyEngineUSTargetCell( + "real_estate_taxes", geo_level="state", domain_variable="real_estate_taxes" + ), + PolicyEngineUSTargetCell( + "real_estate_taxes", + geo_level="state", + domain_variable="real_estate_taxes,tax_unit_itemizes", + ), + PolicyEngineUSTargetCell( 
+ "refundable_ctc", geo_level="state", domain_variable="refundable_ctc" + ), + PolicyEngineUSTargetCell( + "rental_income", geo_level="state", domain_variable="rental_income" + ), PolicyEngineUSTargetCell("salt", geo_level="state", domain_variable="salt"), + PolicyEngineUSTargetCell( + "salt", geo_level="state", domain_variable="salt,tax_unit_itemizes" + ), + PolicyEngineUSTargetCell( + "self_employment_income", + geo_level="state", + ), PolicyEngineUSTargetCell( "self_employment_income", geo_level="state", domain_variable="self_employment_income", ), PolicyEngineUSTargetCell("snap", geo_level="state", domain_variable="snap"), + PolicyEngineUSTargetCell( + "spm_unit_count", geo_level="state", domain_variable="tanf" + ), PolicyEngineUSTargetCell("state_income_tax", geo_level="state"), + PolicyEngineUSTargetCell("tanf", geo_level="state", domain_variable="tanf"), PolicyEngineUSTargetCell( "tax_exempt_interest_income", geo_level="state", domain_variable="tax_exempt_interest_income", ), - PolicyEngineUSTargetCell("tax_unit_count", geo_level="state", domain_variable="aca_ptc"), - PolicyEngineUSTargetCell("tax_unit_count", geo_level="state", domain_variable="adjusted_gross_income"), - PolicyEngineUSTargetCell("tax_unit_count", geo_level="state", domain_variable="dividend_income"), - PolicyEngineUSTargetCell("tax_unit_count", geo_level="state", domain_variable="eitc_child_count"), - PolicyEngineUSTargetCell("tax_unit_count", geo_level="state", domain_variable="income_tax"), + PolicyEngineUSTargetCell( + "tax_unit_count", geo_level="state", domain_variable="aca_ptc" + ), + PolicyEngineUSTargetCell( + "tax_unit_count", geo_level="state", domain_variable="adjusted_gross_income" + ), + PolicyEngineUSTargetCell( + "tax_unit_count", geo_level="state", domain_variable="dividend_income" + ), + PolicyEngineUSTargetCell( + "tax_unit_count", geo_level="state", domain_variable="employment_income" + ), + PolicyEngineUSTargetCell( + "tax_unit_count", geo_level="state", 
domain_variable="eitc_child_count" + ), + PolicyEngineUSTargetCell( + "tax_unit_count", geo_level="state", domain_variable="income_tax" + ), PolicyEngineUSTargetCell( "tax_unit_count", geo_level="state", @@ -299,7 +553,17 @@ def to_provider_filter(self) -> dict[str, str | None]: geo_level="state", domain_variable="medical_expense_deduction", ), - PolicyEngineUSTargetCell("tax_unit_count", geo_level="state", domain_variable="net_capital_gains"), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="state", + domain_variable="medical_expense_deduction,tax_unit_itemizes", + ), + PolicyEngineUSTargetCell( + "tax_unit_count", geo_level="state", domain_variable="net_capital_gains" + ), + PolicyEngineUSTargetCell( + "tax_unit_count", geo_level="state", domain_variable="non_refundable_ctc" + ), PolicyEngineUSTargetCell( "tax_unit_count", geo_level="state", @@ -310,10 +574,33 @@ def to_provider_filter(self) -> dict[str, str | None]: geo_level="state", domain_variable="qualified_dividend_income", ), - PolicyEngineUSTargetCell("tax_unit_count", geo_level="state", domain_variable="real_estate_taxes"), - PolicyEngineUSTargetCell("tax_unit_count", geo_level="state", domain_variable="refundable_ctc"), - PolicyEngineUSTargetCell("tax_unit_count", geo_level="state", domain_variable="rental_income"), - PolicyEngineUSTargetCell("tax_unit_count", geo_level="state", domain_variable="salt"), + PolicyEngineUSTargetCell( + "tax_unit_count", geo_level="state", domain_variable="real_estate_taxes" + ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="state", + domain_variable="real_estate_taxes,tax_unit_itemizes", + ), + PolicyEngineUSTargetCell( + "tax_unit_count", geo_level="state", domain_variable="refundable_ctc" + ), + PolicyEngineUSTargetCell( + "tax_unit_count", geo_level="state", domain_variable="rental_income" + ), + PolicyEngineUSTargetCell( + "tax_unit_count", geo_level="state", domain_variable="salt" + ), + PolicyEngineUSTargetCell( + "tax_unit_count", + 
geo_level="state", + domain_variable="salt,tax_unit_itemizes", + ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="state", + domain_variable="selected_marketplace_plan_benchmark_ratio,used_aca_ptc", + ), PolicyEngineUSTargetCell( "tax_unit_count", geo_level="state", @@ -349,11 +636,19 @@ def to_provider_filter(self) -> dict[str, str | None]: geo_level="state", domain_variable="taxable_social_security", ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="state", + domain_variable="total_self_employment_income", + ), PolicyEngineUSTargetCell( "tax_unit_count", geo_level="state", domain_variable="unemployment_compensation", ), + PolicyEngineUSTargetCell( + "tax_unit_count", geo_level="state", domain_variable="used_aca_ptc" + ), PolicyEngineUSTargetCell( "tax_unit_partnership_s_corp_income", geo_level="state", @@ -379,6 +674,11 @@ def to_provider_filter(self) -> dict[str, str | None]: geo_level="state", domain_variable="taxable_social_security", ), + PolicyEngineUSTargetCell( + "total_self_employment_income", + geo_level="state", + domain_variable="total_self_employment_income", + ), PolicyEngineUSTargetCell( "unemployment_compensation", geo_level="state", @@ -388,26 +688,199 @@ def to_provider_filter(self) -> dict[str, str | None]: _PE_NATIVE_BROAD_NO_STATE_ACA_EXCLUDED_CELLS = frozenset( { + ("aca_ptc", "state", None, None), ("aca_ptc", "state", "aca_ptc", None), + ("person_count", "state", "aca_ptc", None), + ("person_count", "state", "aca_ptc,is_aca_ptc_eligible", None), ("tax_unit_count", "state", "aca_ptc", None), + ( + "tax_unit_count", + "state", + "selected_marketplace_plan_benchmark_ratio,used_aca_ptc", + None, + ), + ("tax_unit_count", "state", "used_aca_ptc", None), } ) PE_NATIVE_BROAD_NO_STATE_ACA_TARGET_CELLS: tuple[PolicyEngineUSTargetCell, ...] 
= tuple( cell for cell in PE_NATIVE_BROAD_TARGET_CELLS - if ( - cell.variable, - cell.geo_level, - cell.domain_variable, - cell.geographic_id, - ) - not in _PE_NATIVE_BROAD_NO_STATE_ACA_EXCLUDED_CELLS + if _target_cell_key(cell) not in _PE_NATIVE_BROAD_NO_STATE_ACA_EXCLUDED_CELLS +) + +PE_NATIVE_BROAD_SOURCE_BACKED_EXCLUDED_CELL_REASONS: dict[ + PolicyEngineUSTargetCellKey, + str, +] = { + ( + "adjusted_gross_income", + "national", + "adjusted_gross_income,filing_status,income_tax_before_credits", + None, + ): ( + "SOI source packages currently loaded by Arch do not publish adjusted " + "gross income jointly by AGI band, filing status, and returns with " + "positive income tax before credits." + ), + ( + "adjusted_gross_income", + "national", + "adjusted_gross_income,income_tax_before_credits", + None, + ): ( + "SOI source packages currently loaded by Arch publish AGI bands and " + "income-tax-before-credits returns separately, not AGI amounts " + "restricted to returns with positive income tax before credits." + ), + ( + "tax_unit_count", + "national", + "adjusted_gross_income,filing_status,income_tax_before_credits", + None, + ): ( + "SOI Historic Table 2 does not provide the full AGI by filing-status " + "by positive-income-tax-before-credits joint count required by this " + "PolicyEngine cell." + ), + ( + "person_count", + "national", + "ssn_card_type", + None, + ): ( + "PolicyEngine ssn_card_type is a modeled legal-status input; no " + "accepted primary aggregate source mapping is encoded for Arch." + ), + ( + "person_count", + "state", + "is_pregnant", + None, + ): ( + "The PolicyEngine cell is a pregnancy stock by state; live births are " + "a flow and are not a defensible direct source fact for this target." + ), + ( + "alimony_expense", + "national", + None, + None, + ): ( + "No accepted primary source mapping is encoded for this " + "survey/model-input expense variable." 
+ ), + ( + "child_support_expense", + "national", + None, + None, + ): ( + "No accepted primary source mapping is encoded for this " + "survey/model-input expense variable." + ), + ( + "child_support_received", + "national", + None, + None, + ): ( + "No accepted primary source mapping is encoded for this " + "survey/model-input receipt variable." + ), + ( + "childcare_expenses", + "national", + None, + None, + ): ( + "IRS child-care credit expenses and W-2 dependent-care benefits are " + "narrower tax concepts than PolicyEngine childcare_expenses, so they " + "are not treated as source-equivalent." + ), + ( + "health_insurance_premiums_without_medicare_part_b", + "national", + None, + None, + ): ( + "This premium component is a modeled/survey input; no accepted primary " + "aggregate source mapping is encoded for Arch." + ), + ( + "other_medical_expenses", + "national", + None, + None, + ): ( + "This out-of-pocket medical expense component is a survey/model input " + "without an accepted primary aggregate source mapping." + ), + ( + "over_the_counter_health_expenses", + "national", + None, + None, + ): ( + "This out-of-pocket medical expense component is a survey/model input " + "without an accepted primary aggregate source mapping." + ), + ( + "rent", + "national", + None, + None, + ): ( + "PolicyEngine rent is a household survey/model input; ACS rent tables " + "do not provide a direct aggregate source fact for this exact variable." + ), + ( + "spm_unit_capped_housing_subsidy", + "national", + None, + None, + ): ( + "This is a capped SPM model amount rather than a direct publisher " + "source fact." + ), + ( + "spm_unit_capped_work_childcare_expenses", + "national", + None, + None, + ): ( + "This is a capped SPM model amount rather than a direct publisher " + "source fact." + ), +} + +PE_NATIVE_BROAD_SOURCE_BACKED_TARGET_CELLS: tuple[ + PolicyEngineUSTargetCell, ... 
+] = tuple( + cell + for cell in PE_NATIVE_BROAD_TARGET_CELLS + if _target_cell_key(cell) + not in PE_NATIVE_BROAD_SOURCE_BACKED_EXCLUDED_CELL_REASONS ) _TARGET_PROFILES: dict[str, tuple[PolicyEngineUSTargetCell, ...]] = { "pe_native_broad": PE_NATIVE_BROAD_TARGET_CELLS, "pe_native_broad_no_state_aca": PE_NATIVE_BROAD_NO_STATE_ACA_TARGET_CELLS, + "pe_native_broad_source_backed": PE_NATIVE_BROAD_SOURCE_BACKED_TARGET_CELLS, +} + +_TARGET_PROFILE_EXCLUSION_REASONS: dict[ + str, + dict[PolicyEngineUSTargetCellKey, str], +] = { + "pe_native_broad": {}, + "pe_native_broad_no_state_aca": { + cell_key: "State ACA cells are excluded from this profile variant." + for cell_key in _PE_NATIVE_BROAD_NO_STATE_ACA_EXCLUDED_CELLS + }, + "pe_native_broad_source_backed": ( + PE_NATIVE_BROAD_SOURCE_BACKED_EXCLUDED_CELL_REASONS + ), } @@ -425,3 +898,14 @@ def resolve_policyengine_us_target_profile( raise ValueError( f"Unknown PolicyEngine US target profile '{name}'. Known profiles: {known}" ) from exc + + +def policyengine_us_target_profile_exclusion_reasons( + name: str, +) -> dict[PolicyEngineUSTargetCellKey, str]: + if name not in _TARGET_PROFILES: + known = ", ".join(policyengine_us_target_profile_names()) + raise ValueError( + f"Unknown PolicyEngine US target profile '{name}'. 
Known profiles: {known}" + ) + return dict(_TARGET_PROFILE_EXCLUSION_REASONS.get(name, {})) diff --git a/src/microplex_us/targets/__init__.py b/src/microplex_us/targets/__init__.py index d65d841..33769e1 100644 --- a/src/microplex_us/targets/__init__.py +++ b/src/microplex_us/targets/__init__.py @@ -1,10 +1,54 @@ """US-specific target mappings.""" +from microplex_us.targets.aca_ptc import ( + ACA_AVERAGE_MONTHLY_APTC_CONCEPT, + ACA_MARKETPLACE_EFFECTUATED_ENROLLMENT_CONCEPT, + ACAPTCBaseAPTCPolicy, + ACAPTCMultiplierInput, + ACAPTCMultiplierRow, + aca_ptc_multiplier_inputs_from_arch_consumer_facts, + build_aca_ptc_multiplier_rows, + load_arch_consumer_fact_jsonl_rows, + write_policyengine_aca_ptc_multiplier_csv, +) from microplex_us.targets.adapters import ( POLICYENGINE_US_COUNT_ENTITIES, policyengine_db_target_to_canonical_spec, policyengine_db_targets_to_canonical_set, ) +from microplex_us.targets.arch import ( + ArchCompositeSQLiteTargetProvider, + ArchConsumerFactJSONLTargetProvider, + ArchFactSQLiteTargetProvider, + ArchSQLiteTargetProvider, + ArchTargetCellCoverage, + ArchTargetGapQueueReport, + ArchTargetGapQueueRow, + ArchTargetParityReport, + ArchTargetParityRow, + ArchTargetProfileCoverageReport, + ArchTargetRecord, + SOIAgingFactors, + arch_target_record_to_canonical_spec, + resolve_arch_sqlite_target_provider, + summarize_arch_target_gap_queue, + summarize_arch_target_parity, + summarize_arch_target_profile_coverage, +) +from microplex_us.targets.census_blocks import ( + CENSUS_BLOCK_GEOGRAPHY_YEAR, + CENSUS_BLOCK_POPULATION_GEO_LEVELS, + CENSUS_BLOCK_POPULATION_ROLLUPS, + CENSUS_BLOCK_POPULATION_SOURCE, + CENSUS_BLOCK_POPULATION_UNITS, + CENSUS_BLOCK_POPULATION_VARIABLE, + CENSUS_BLOCK_SOURCE_YEAR, + CENSUS_BLOCK_TARGET_PERIOD, + DEFAULT_CENSUS_BLOCK_POPULATION_GEO_LEVELS, + CensusBlockPopulationRollup, + CensusBlockPopulationTargetProvider, + build_census_block_population_targets, +) from microplex_us.targets.rac_mapping import ( MICRODATA_TO_RAC, 
POLICYENGINE_TO_RAC, @@ -16,9 +60,47 @@ ) __all__ = [ + "ArchTargetCellCoverage", + "ArchTargetProfileCoverageReport", + "ArchSQLiteTargetProvider", + "ArchCompositeSQLiteTargetProvider", + "ArchConsumerFactJSONLTargetProvider", + "ArchFactSQLiteTargetProvider", + "ArchTargetRecord", + "ArchTargetGapQueueReport", + "ArchTargetGapQueueRow", + "ArchTargetParityReport", + "ArchTargetParityRow", "POLICYENGINE_US_COUNT_ENTITIES", + "CENSUS_BLOCK_GEOGRAPHY_YEAR", + "CENSUS_BLOCK_POPULATION_GEO_LEVELS", + "CENSUS_BLOCK_POPULATION_ROLLUPS", + "CENSUS_BLOCK_POPULATION_SOURCE", + "CENSUS_BLOCK_POPULATION_UNITS", + "CENSUS_BLOCK_POPULATION_VARIABLE", + "CENSUS_BLOCK_SOURCE_YEAR", + "CENSUS_BLOCK_TARGET_PERIOD", + "DEFAULT_CENSUS_BLOCK_POPULATION_GEO_LEVELS", + "CensusBlockPopulationRollup", + "CensusBlockPopulationTargetProvider", + "SOIAgingFactors", + "arch_target_record_to_canonical_spec", + "build_census_block_population_targets", + "summarize_arch_target_gap_queue", + "summarize_arch_target_parity", + "summarize_arch_target_profile_coverage", "policyengine_db_target_to_canonical_spec", "policyengine_db_targets_to_canonical_set", + "resolve_arch_sqlite_target_provider", + "ACA_AVERAGE_MONTHLY_APTC_CONCEPT", + "ACA_MARKETPLACE_EFFECTUATED_ENROLLMENT_CONCEPT", + "ACAPTCBaseAPTCPolicy", + "ACAPTCMultiplierInput", + "ACAPTCMultiplierRow", + "aca_ptc_multiplier_inputs_from_arch_consumer_facts", + "build_aca_ptc_multiplier_rows", + "load_arch_consumer_fact_jsonl_rows", + "write_policyengine_aca_ptc_multiplier_csv", "RACVariable", "RAC_VARIABLE_MAP", "POLICYENGINE_TO_RAC", diff --git a/src/microplex_us/targets/aca_ptc.py b/src/microplex_us/targets/aca_ptc.py new file mode 100644 index 0000000..9f61623 --- /dev/null +++ b/src/microplex_us/targets/aca_ptc.py @@ -0,0 +1,465 @@ +"""ACA PTC target-construction helpers for US target sources.""" + +from __future__ import annotations + +import argparse +import csv +from collections.abc import Iterable, Mapping +from dataclasses import 
dataclass +from pathlib import Path +from typing import Any, Literal + +from microplex.targets import ( + arch_consumer_fact_concept, + arch_consumer_fact_numeric_value, + arch_consumer_fact_period, + arch_consumer_fact_source_record_id, + load_arch_consumer_fact_jsonl_rows, +) + +ACAPTCBaseAPTCPolicy = Literal[ + "oep", + "effectuated", + "oep_with_effectuated_fallback", +] + +ACA_MARKETPLACE_EFFECTUATED_ENROLLMENT_CONCEPT = ( + "cms_aca.marketplace_effectuated_enrollment" +) +ACA_AVERAGE_MONTHLY_APTC_CONCEPT = "cms_aca.average_monthly_aptc" + + +@dataclass(frozen=True) +class ACAPTCMultiplierInput: + """Publisher-source inputs for one state's ACA PTC multiplier row.""" + + state: str + enroll_base: float + enroll_target: float + aptc_base: float + aptc_target: float + base_year: int = 2022 + target_year: int = 2024 + enroll_base_source_record_id: str | None = None + enroll_target_source_record_id: str | None = None + aptc_base_source_record_id: str | None = None + aptc_target_source_record_id: str | None = None + aptc_base_source_kind: str | None = None + aptc_target_source_kind: str | None = None + + +@dataclass(frozen=True) +class ACAPTCMultiplierRow: + """PE-compatible ACA PTC multiplier row for one state.""" + + state: str + enroll_base: float + enroll_target: float + vol_mult: float + aptc_base: float + aptc_target: float + val_mult: float + base_year: int = 2022 + target_year: int = 2024 + enroll_base_source_record_id: str | None = None + enroll_target_source_record_id: str | None = None + aptc_base_source_record_id: str | None = None + aptc_target_source_record_id: str | None = None + aptc_base_source_kind: str | None = None + aptc_target_source_kind: str | None = None + + @property + def amount_mult(self) -> float: + """Multiplier PE applies to the ACA PTC amount target.""" + + return self.vol_mult * self.val_mult + + def target_factors(self) -> dict[str, float]: + """Return the variable factors consumed by PE's state uprating path.""" + + return { + 
"tax_unit_count": self.vol_mult, + "aca_ptc": self.amount_mult, + } + + def to_policyengine_csv_row(self) -> dict[str, float | int | str]: + """Return a row with PE's incumbent ACA multiplier CSV column names.""" + + return { + "state": self.state, + f"enroll_{self.base_year}": _source_csv_number(self.enroll_base), + f"enroll_{self.target_year}": _source_csv_number(self.enroll_target), + "vol_mult": self.vol_mult, + f"aptc_{self.base_year}": _source_csv_number(self.aptc_base), + f"aptc_{self.target_year}": _source_csv_number(self.aptc_target), + "val_mult": self.val_mult, + } + + +@dataclass(frozen=True) +class _ACAStateFact: + state: str + period: int + value: float + concept: str + source_record_id: str | None + source_kind: str | None + + +def build_aca_ptc_multiplier_rows( + inputs: Iterable[ACAPTCMultiplierInput], +) -> tuple[ACAPTCMultiplierRow, ...]: + """Build state ACA PTC multiplier rows from explicit source inputs.""" + + rows = [] + for item in inputs: + _validate_positive_source_value(item.enroll_base, "enroll_base", item.state) + _validate_positive_source_value(item.enroll_target, "enroll_target", item.state) + _validate_positive_source_value(item.aptc_base, "aptc_base", item.state) + _validate_positive_source_value(item.aptc_target, "aptc_target", item.state) + rows.append( + ACAPTCMultiplierRow( + state=item.state, + base_year=item.base_year, + target_year=item.target_year, + enroll_base=item.enroll_base, + enroll_target=item.enroll_target, + vol_mult=item.enroll_target / item.enroll_base, + aptc_base=item.aptc_base, + aptc_target=item.aptc_target, + val_mult=item.aptc_target / item.aptc_base, + enroll_base_source_record_id=item.enroll_base_source_record_id, + enroll_target_source_record_id=item.enroll_target_source_record_id, + aptc_base_source_record_id=item.aptc_base_source_record_id, + aptc_target_source_record_id=item.aptc_target_source_record_id, + aptc_base_source_kind=item.aptc_base_source_kind, + 
aptc_target_source_kind=item.aptc_target_source_kind, + ) + ) + return tuple(sorted(rows, key=lambda row: row.state)) + + +def aca_ptc_multiplier_inputs_from_arch_consumer_facts( + rows: Iterable[Mapping[str, Any]], + *, + base_year: int = 2022, + target_year: int = 2024, + base_aptc_policy: ACAPTCBaseAPTCPolicy = "oep_with_effectuated_fallback", +) -> tuple[ACAPTCMultiplierInput, ...]: + """Collect PE-style ACA PTC multiplier inputs from Arch consumer facts. + + The publisher-source recipe uses KFF full-year effectuated enrollment for + the volume ratio, CMS OEP average APTC where available for the base-year + value ratio base, CMS full-year 2022 APTC as the fallback for missing OEP + state values, and CMS OEP average APTC for the target-year value ratio. + """ + + enrollment: dict[tuple[int, str], _ACAStateFact] = {} + oep_aptc: dict[tuple[int, str], _ACAStateFact] = {} + effectuated_aptc: dict[tuple[int, str], _ACAStateFact] = {} + + for row in rows: + fact = _aca_state_fact_from_arch_consumer_fact(row) + if fact is None: + continue + key = (fact.period, fact.state) + if fact.concept == ACA_MARKETPLACE_EFFECTUATED_ENROLLMENT_CONCEPT: + enrollment[key] = fact + elif fact.concept == ACA_AVERAGE_MONTHLY_APTC_CONCEPT: + if fact.source_kind == "oep": + oep_aptc[key] = fact + elif fact.source_kind == "effectuated": + effectuated_aptc[key] = fact + + states = sorted( + { + state + for period, state in enrollment + if period == base_year and (target_year, state) in enrollment + } + ) + inputs = [] + missing: list[str] = [] + for state in states: + enroll_base = enrollment[(base_year, state)] + enroll_target = enrollment[(target_year, state)] + aptc_base = _select_base_aptc_fact( + state, + base_year=base_year, + policy=base_aptc_policy, + oep_aptc=oep_aptc, + effectuated_aptc=effectuated_aptc, + ) + aptc_target = oep_aptc.get((target_year, state)) + if aptc_base is None: + missing.append(f"{state} {base_year} average APTC") + continue + if aptc_target is None: + 
missing.append(f"{state} {target_year} OEP average APTC") + continue + inputs.append( + ACAPTCMultiplierInput( + state=state, + base_year=base_year, + target_year=target_year, + enroll_base=enroll_base.value, + enroll_target=enroll_target.value, + aptc_base=aptc_base.value, + aptc_target=aptc_target.value, + enroll_base_source_record_id=enroll_base.source_record_id, + enroll_target_source_record_id=enroll_target.source_record_id, + aptc_base_source_record_id=aptc_base.source_record_id, + aptc_target_source_record_id=aptc_target.source_record_id, + aptc_base_source_kind=aptc_base.source_kind, + aptc_target_source_kind=aptc_target.source_kind, + ) + ) + + if missing: + preview = ", ".join(missing[:5]) + suffix = "" if len(missing) <= 5 else f", and {len(missing) - 5} more" + raise ValueError(f"Missing ACA PTC source facts: {preview}{suffix}") + return tuple(inputs) + + +def write_policyengine_aca_ptc_multiplier_csv( + rows: Iterable[ACAPTCMultiplierRow], + path: str | Path, +) -> None: + """Write PE-compatible ACA PTC multiplier rows.""" + + rows = tuple(rows) + if not rows: + raise ValueError("Cannot write ACA PTC multiplier CSV with no rows.") + year_pairs = {(row.base_year, row.target_year) for row in rows} + if len(year_pairs) != 1: + raise ValueError("ACA PTC multiplier CSV rows must use one year pair.") + base_year, target_year = next(iter(year_pairs)) + fieldnames = [ + "state", + f"enroll_{base_year}", + f"enroll_{target_year}", + "vol_mult", + f"aptc_{base_year}", + f"aptc_{target_year}", + "val_mult", + ] + with Path(path).open("w", newline="") as file: + writer = csv.DictWriter(file, fieldnames=fieldnames) + writer.writeheader() + for row in rows: + writer.writerow(row.to_policyengine_csv_row()) + + +def main(argv: list[str] | None = None) -> int: + """Build a PE-compatible ACA PTC multiplier CSV from Arch consumer facts.""" + + parser = argparse.ArgumentParser( + description=( + "Build a PE-compatible ACA PTC multiplier CSV from Arch " + 
"consumer_facts.jsonl files." + ) + ) + parser.add_argument( + "consumer_facts", + nargs="+", + help="Arch consumer_facts.jsonl path(s) containing ACA source facts.", + ) + parser.add_argument( + "--out", + required=True, + help="Output CSV path.", + ) + parser.add_argument( + "--base-year", + type=int, + default=2022, + help="Source year for the multiplier denominator.", + ) + parser.add_argument( + "--target-year", + type=int, + default=2024, + help="Target year for the multiplier numerator.", + ) + parser.add_argument( + "--base-aptc-policy", + choices=("oep", "effectuated", "oep_with_effectuated_fallback"), + default="oep_with_effectuated_fallback", + help="Source selection policy for base-year average monthly APTC.", + ) + args = parser.parse_args(argv) + + consumer_fact_rows = load_arch_consumer_fact_jsonl_rows(args.consumer_facts) + inputs = aca_ptc_multiplier_inputs_from_arch_consumer_facts( + consumer_fact_rows, + base_year=args.base_year, + target_year=args.target_year, + base_aptc_policy=args.base_aptc_policy, + ) + rows = build_aca_ptc_multiplier_rows(inputs) + write_policyengine_aca_ptc_multiplier_csv(rows, args.out) + print(f"Wrote {len(rows)} ACA PTC multiplier rows to {args.out}") + return 0 + + +def _select_base_aptc_fact( + state: str, + *, + base_year: int, + policy: ACAPTCBaseAPTCPolicy, + oep_aptc: Mapping[tuple[int, str], _ACAStateFact], + effectuated_aptc: Mapping[tuple[int, str], _ACAStateFact], +) -> _ACAStateFact | None: + key = (base_year, state) + if policy == "oep": + return oep_aptc.get(key) + if policy == "effectuated": + return effectuated_aptc.get(key) + if policy == "oep_with_effectuated_fallback": + return oep_aptc.get(key) or effectuated_aptc.get(key) + raise ValueError(f"Unsupported ACA PTC base APTC policy: {policy}") + + +def _aca_state_fact_from_arch_consumer_fact( + row: Mapping[str, Any], +) -> _ACAStateFact | None: + concept = _arch_consumer_fact_concept(row) + if concept not in { + 
ACA_MARKETPLACE_EFFECTUATED_ENROLLMENT_CONCEPT, + ACA_AVERAGE_MONTHLY_APTC_CONCEPT, + }: + return None + geography = _mapping(row.get("geography")) + if str(geography.get("level") or "").lower() != "state": + return None + state = _arch_consumer_fact_state(row, geography) + if not state: + return None + return _ACAStateFact( + state=state, + period=_arch_consumer_fact_period(row), + value=_json_numeric_value(row.get("value")), + concept=concept, + source_record_id=_arch_consumer_fact_source_record_id(row), + source_kind=_aca_source_kind(row), + ) + + +def _arch_consumer_fact_concept(row: Mapping[str, Any]) -> str | None: + return arch_consumer_fact_concept(row) + + +def _arch_consumer_fact_period(row: Mapping[str, Any]) -> int: + return arch_consumer_fact_period(row) + + +def _arch_consumer_fact_state( + row: Mapping[str, Any], + geography: Mapping[str, Any], +) -> str | None: + name = geography.get("name") + if name: + return str(name) + source_record_id = _arch_consumer_fact_source_record_id(row) or "" + for token in source_record_id.split("."): + state = _STATE_ABBR_TO_NAME.get(token.lower()) + if state is not None: + return state + return None + + +def _arch_consumer_fact_source_record_id(row: Mapping[str, Any]) -> str | None: + source_record_id = arch_consumer_fact_source_record_id(row) + if source_record_id is not None: + return source_record_id + fallback = row.get("source_record_id") + return str(fallback) if fallback else None + + +def _aca_source_kind(row: Mapping[str, Any]) -> str | None: + source_record_id = (_arch_consumer_fact_source_record_id(row) or "").lower() + if ".oep" in source_record_id: + return "oep" + if ".effectuated_enrollment." 
in source_record_id: + return "effectuated" + source = _mapping(row.get("source")) + source_table = str(source.get("source_table") or "").lower() + if "open enrollment" in source_table or "oep" in source_table: + return "oep" + if "effectuated enrollment" in source_table: + return "effectuated" + return None + + +def _validate_positive_source_value(value: float, label: str, state: str) -> None: + if value <= 0: + raise ValueError(f"{state} {label} must be positive; got {value}.") + + +def _json_numeric_value(value: Any) -> float: + return arch_consumer_fact_numeric_value(value) + + +def _source_csv_number(value: float) -> float | int: + numeric = float(value) + return int(numeric) if numeric.is_integer() else numeric + + +def _mapping(value: Any) -> Mapping[str, Any]: + return value if isinstance(value, Mapping) else {} + + +_STATE_ABBR_TO_NAME = { + "ak": "Alaska", + "al": "Alabama", + "ar": "Arkansas", + "az": "Arizona", + "ca": "California", + "co": "Colorado", + "ct": "Connecticut", + "dc": "District of Columbia", + "de": "Delaware", + "fl": "Florida", + "ga": "Georgia", + "hi": "Hawaii", + "ia": "Iowa", + "id": "Idaho", + "il": "Illinois", + "in": "Indiana", + "ks": "Kansas", + "ky": "Kentucky", + "la": "Louisiana", + "ma": "Massachusetts", + "md": "Maryland", + "me": "Maine", + "mi": "Michigan", + "mn": "Minnesota", + "mo": "Missouri", + "ms": "Mississippi", + "mt": "Montana", + "nc": "North Carolina", + "nd": "North Dakota", + "ne": "Nebraska", + "nh": "New Hampshire", + "nj": "New Jersey", + "nm": "New Mexico", + "nv": "Nevada", + "ny": "New York", + "oh": "Ohio", + "ok": "Oklahoma", + "or": "Oregon", + "pa": "Pennsylvania", + "ri": "Rhode Island", + "sc": "South Carolina", + "sd": "South Dakota", + "tn": "Tennessee", + "tx": "Texas", + "ut": "Utah", + "va": "Virginia", + "vt": "Vermont", + "wa": "Washington", + "wi": "Wisconsin", + "wv": "West Virginia", + "wy": "Wyoming", +} diff --git a/src/microplex_us/targets/adapters.py 
b/src/microplex_us/targets/adapters.py index 35f4b6d..ef8fe80 100644 --- a/src/microplex_us/targets/adapters.py +++ b/src/microplex_us/targets/adapters.py @@ -14,7 +14,11 @@ TargetSpec as CanonicalTargetSpec, ) -from microplex_us.policyengine.us import PolicyEngineUSDBTarget +from microplex_us.microdata_roles import policyengine_us_variable_role +from microplex_us.policyengine.us import ( + PolicyEngineUSConstraint, + PolicyEngineUSDBTarget, +) POLICYENGINE_US_COUNT_ENTITIES: dict[str, EntityType] = { "household_count": EntityType.HOUSEHOLD, @@ -24,6 +28,8 @@ "family_count": EntityType.FAMILY, } +POLICYENGINE_US_ACTUAL_ACA_PTC_VARIABLE = "assigned_aca_ptc" + def policyengine_db_target_to_canonical_spec( target: PolicyEngineUSDBTarget, @@ -47,13 +53,11 @@ def policyengine_db_target_to_canonical_spec( if target.variable.endswith("_count") else TargetAggregation.SUM ) - measure = None if aggregation is TargetAggregation.COUNT else target.variable + measure_variable = _policyengine_db_target_measure_variable(target) + measure = None if aggregation is TargetAggregation.COUNT else measure_variable + model_variable = measure_variable if measure is not None else target.variable filters = tuple( - TargetFilter( - feature=constraint.variable, - operator=constraint.operation, - value=constraint.value, - ) + _policyengine_db_constraint_to_target_filter(target, constraint) for constraint in target.constraints ) @@ -80,11 +84,43 @@ def policyengine_db_target_to_canonical_spec( "geographic_id": target.geographic_id, "domain_variable": target.domain_variable, "domain_variables": target.domain_variables, + "model_variable_role": policyengine_us_variable_role(model_variable).value, + "target_semantic": ( + "count" if aggregation is TargetAggregation.COUNT else "amount" + ), "constraint_count": len(target.constraints), }, ) +def _policyengine_db_target_uses_aca_ptc(target: PolicyEngineUSDBTarget) -> bool: + return ( + target.variable == "aca_ptc" + or "aca_ptc" in 
target.domain_variables + or any(constraint.variable == "aca_ptc" for constraint in target.constraints) + ) + + +def _policyengine_db_target_measure_variable(target: PolicyEngineUSDBTarget) -> str: + if target.variable == "aca_ptc": + return POLICYENGINE_US_ACTUAL_ACA_PTC_VARIABLE + return target.variable + + +def _policyengine_db_constraint_to_target_filter( + target: PolicyEngineUSDBTarget, + constraint: PolicyEngineUSConstraint, +) -> TargetFilter: + feature = constraint.variable + if feature == "aca_ptc" and _policyengine_db_target_uses_aca_ptc(target): + feature = POLICYENGINE_US_ACTUAL_ACA_PTC_VARIABLE + return TargetFilter( + feature=feature, + operator=constraint.operation, + value=constraint.value, + ) + + def policyengine_db_targets_to_canonical_set( targets: Iterable[PolicyEngineUSDBTarget], *, diff --git a/src/microplex_us/targets/arch.py b/src/microplex_us/targets/arch.py new file mode 100644 index 0000000..62155d0 --- /dev/null +++ b/src/microplex_us/targets/arch.py @@ -0,0 +1,6212 @@ +"""Adapters from Arch target records to core Microplex target specs.""" + +from __future__ import annotations + +import json +import sqlite3 +from collections import Counter +from dataclasses import dataclass, replace +from hashlib import sha1 +from pathlib import Path +from typing import Any + +from microplex.core import EntityType +from microplex.targets import ( + TargetAggregation, + TargetFilter, + TargetQuery, + TargetSet, + apply_target_query, + arch_consumer_fact_concept, + arch_consumer_fact_numeric_value, + arch_consumer_fact_period, + arch_consumer_fact_source_record_id, + load_arch_consumer_fact_jsonl_rows, +) +from microplex.targets import ( + TargetSpec as CanonicalTargetSpec, +) + +from microplex_us.geography import ( + US_STATE_ABBR_BY_FIPS, + normalize_state_legislative_district_id, +) +from microplex_us.microdata_roles import policyengine_us_variable_role +from microplex_us.policyengine.target_profiles import ( + PolicyEngineUSTargetCell, + 
resolve_policyengine_us_target_profile, +) + +ARCH_SOURCE_ALIASES = { + "bea": "BEA", + "bea-nipa": "BEA", + "bea-regional": "BEA", + "census-decennial": "CENSUS_DECENNIAL", + "irs-soi": "IRS_SOI", + "census-acs": "CENSUS_ACS", + "census-pep": "CENSUS_PEP", + "census-stc": "CENSUS_STC", + "usda-snap": "USDA_SNAP", + "cms-aca": "CMS_ACA", + "cms-medicare": "CMS_MEDICARE", + "cms-medicaid": "CMS_MEDICAID", + "federal-reserve": "FEDERAL_RESERVE", + "hhs-acf-liheap": "HHS_ACF_LIHEAP", + "hhs-acf-tanf": "HHS_ACF_TANF", +} + +ARCH_CONSTRAINT_VARIABLE_ALIASES = { + "eitc_qualifying_children": "eitc_child_count", + "is_tax_filer": "tax_unit_is_filer", +} + +ARCH_POSITIVE_CONSTRAINT_ALIASES = { + "aca": "aca_ptc", + "aca_marketplace": "aca_ptc", + "aca_ptc": "aca_ptc", + "is_aca_ptc_eligible": "aca_ptc", + "selected_marketplace_plan_benchmark_ratio": "aca_ptc", + "total_self_employment_income": "self_employment_income", + "used_aca_ptc": "aca_ptc", + "is_medicaid": "medicaid_enrolled", + "medicaid": "medicaid_enrolled", + "medicaid_enrolled": "medicaid_enrolled", + "snap": "snap", +} + +ARCH_CONSTRAINT_OPERATOR_ALIASES = { + "=": "==", + "eq": "==", + "<>": "!=", + "ne": "!=", + "neq": "!=", +} + +ARCH_AMOUNT_VARIABLE_ALIASES = { + "adjusted_gross_income": "adjusted_gross_income", + "income_tax_liability": "income_tax", + "income_tax_before_credits_amount": "income_tax_before_credits", + "eitc_amount": "eitc", + "ctc_amount": "non_refundable_ctc", + "actc_amount": "refundable_ctc", + "taxable_interest_amount": "taxable_interest_income", + "tax_exempt_interest_amount": "tax_exempt_interest_income", + "alimony_received_amount": "alimony_income", + "personal_dividend_income_amount": "dividend_income", + "ordinary_dividends_amount": "dividend_income", + "qualified_dividends_amount": "qualified_dividend_income", + "long_term_capital_gains_amount": "long_term_capital_gains", + "short_term_capital_gains_amount": "short_term_capital_gains", + "wages_salaries_amount": 
"employment_income", + "net_capital_gains_amount": "net_capital_gains", + "taxable_ira_distributions_amount": "taxable_ira_distributions", + "traditional_ira_contributions": "traditional_ira_contributions", + "roth_ira_contributions": "roth_ira_contributions", + "taxable_pension_income_amount": "taxable_pension_income", + "taxable_social_security_amount": "taxable_social_security", + "unemployment_insurance_benefits": "unemployment_compensation", + "unemployment_compensation_amount": "unemployment_compensation", + "tip_income": "tip_income", + "rental_income_amount": "rental_income", + "rental_royalty_income_amount": "rental_income", + "partnership_scorp_income_amount": "tax_unit_partnership_s_corp_income", + "schedule_c_income_amount": "self_employment_income", + "state_local_refunds_amount": "salt_refund_income", + "qbi_amount": "qualified_business_income_deduction", + "salt_amount": "salt", + "limited_state_local_taxes_amount": "salt_deduction", + "charitable_amount": "charitable_deduction", + "mortgage_interest_amount": "deductible_mortgage_interest", + "mortgage_interest_paid_amount": "deductible_mortgage_interest", + "home_mortgage_personal_seller_amount": "deductible_mortgage_interest", + "deductible_points_amount": "deductible_mortgage_interest", + "investment_interest_paid_amount": "investment_interest_expense", + "interest_paid_deduction_amount": "interest_deduction", + "medical_amount": "medical_expense_deduction", + "medical_dental_expense_amount": "medical_expense_deduction", + "real_estate_taxes_amount": "real_estate_taxes", + "aca_aptc_amount": "aca_ptc", + "medicaid_benefits": "medicaid", + "social_security_benefits": "social_security", + "social_security_dependents_benefits": "social_security_dependents", + "social_security_disability_benefits": "social_security_disability", + "social_security_retirement_benefits": "social_security_retirement", + "social_security_survivors_benefits": "social_security_survivors", + "snap_benefits": "snap", + 
"state_individual_income_tax_collections": "state_income_tax", + "ssi_payments": "ssi", + "ssi_total_payments": "ssi", + "tanf_cash_assistance": "tanf", + "medicare_part_b_premiums": "medicare_part_b_premiums", + "net_worth": "net_worth", +} + +ARCH_SELF_DOMAIN_AMOUNT_VARIABLES = frozenset( + set(ARCH_AMOUNT_VARIABLE_ALIASES.values()) - {"adjusted_gross_income"} +) + +ARCH_IRS_SOI_ITEMIZED_DEDUCTION_AMOUNT_VARIABLES = frozenset( + { + "medical_amount", + "medical_dental_expense_amount", + "real_estate_taxes_amount", + "salt_amount", + } +) + +ARCH_IRS_SOI_ITEMIZED_DEDUCTION_COUNT_VARIABLES = frozenset( + { + "medical_claims", + "real_estate_taxes_claims", + "salt_claims", + } +) + +ARCH_IRS_SOI_ITEMIZED_DEDUCTION_TABLE_MARKERS = ( + "itemized", + "historic table 2", + "table 2.", +) + +ARCH_IRS_SOI_CREDIT_AGI_DOMAIN_VARIABLES = frozenset( + { + "actc_amount", + "actc_claims", + "ctc_amount", + "ctc_claims", + } +) + +ARCH_STATE_TO_NATIONAL_ROLLUP_VARIABLES = frozenset( + { + "aca_aptc_amount", + "ctc_amount", + "ctc_claims", + } +) + +ARCH_NATIONAL_ROLLUP_STATE_FIPS = frozenset( + state_fips for state_fips in US_STATE_ABBR_BY_FIPS if state_fips != "72" +) + +ARCH_POSITIVE_AMOUNT_FILTER_VARIABLES = frozenset( + { + # SOI Table 1.4's taxable net capital gains amount is paired with + # returns with taxable net capital gains; PolicyEngine's variable can be + # negative, so the amount target must use the same positive domain. 
+ "net_capital_gains", + } +) + +ARCH_TARGET_CELL_VARIABLE_ALIASES = { + "income_tax": frozenset({"income_tax_positive"}), + "self_employment_income": frozenset({"total_self_employment_income"}), +} + +ARCH_BROAD_BUSINESS_INCOME_SELF_EMPLOYMENT_BLOCKLIST = frozenset( + { + "bea_nipa.proprietors_income_with_inventory_valuation_and_capital_consumption_adjustments", + "bea_nipa.a041rc_proprietors_income_with_inventory_valuation_and_capital_consumption_adjustments", + "bea_regional.proprietors_income", + "bea_regional.sainc5n_line_70_proprietors_income", + "cbo.income_source:net_business_income", + } +) + +ARCH_COUNT_VARIABLE_ALIASES = { + "tax_unit_count": ("tax_unit_count", EntityType.TAX_UNIT, None), + "income_tax_liability_returns": ( + "tax_unit_count", + EntityType.TAX_UNIT, + "income_tax", + ), + "income_tax_before_credits_returns": ( + "tax_unit_count", + EntityType.TAX_UNIT, + "income_tax_before_credits", + ), + "household_count": ("household_count", EntityType.HOUSEHOLD, None), + "population": ("person_count", EntityType.PERSON, None), + "tax_filer_individual_count": ("person_count", EntityType.PERSON, None), + "snap_household_count": ("household_count", EntityType.HOUSEHOLD, "snap"), + "snap_participant_count": ("person_count", EntityType.PERSON, "snap"), + "aca_marketplace_enrollment": ( + "person_count", + EntityType.PERSON, + "aca_ptc", + ), + "aca_ptc_returns": ("tax_unit_count", EntityType.TAX_UNIT, "aca_ptc"), + "medicaid_total_enrollment": ( + "person_count", + EntityType.PERSON, + "medicaid_enrolled", + ), + "medicaid_enrollment": ("person_count", EntityType.PERSON, "medicaid_enrolled"), + "liheap_household_count": ( + "household_count", + EntityType.HOUSEHOLD, + "spm_unit_energy_subsidy_reported", + ), + "tanf_family_count": ("spm_unit_count", EntityType.SPM_UNIT, "tanf"), + "tanf_recipient_count": ("person_count", EntityType.PERSON, "tanf"), +} + +ARCH_FACT_CONCEPT_TO_TARGET = { + "irs_soi.individual_income_tax_returns": ("tax_unit_count", 
"COUNT"), + "irs_soi.returns_with_total_wages": ("wages_salaries_returns", "COUNT"), + "irs_soi.returns_with_taxable_net_capital_gains": ( + "net_capital_gains_returns", + "COUNT", + ), + "irs_soi.returns_with_taxable_ira_distributions": ( + "taxable_ira_distributions_returns", + "COUNT", + ), + "irs_soi.returns_with_taxable_pension_income": ( + "taxable_pension_income_returns", + "COUNT", + ), + "irs_soi.returns_with_unemployment_compensation": ( + "unemployment_compensation_returns", + "COUNT", + ), + "irs_soi.returns_with_taxable_social_security_benefits": ( + "taxable_social_security_returns", + "COUNT", + ), + "irs_soi.returns_with_income_tax_after_credits": ( + "income_tax_liability_returns", + "COUNT", + ), + "irs_soi.tax_filer_individuals": ( + "tax_filer_individual_count", + "COUNT", + ), + "irs_soi.returns_with_income_tax_before_credits": ( + "income_tax_before_credits_returns", + "COUNT", + ), + "irs_soi.income_tax_before_credits": ( + "income_tax_before_credits_amount", + "AMOUNT", + ), + "irs_soi.income_tax_after_credits": ("income_tax_liability", "AMOUNT"), + "irs_soi.returns_with_premium_tax_credit": ( + "aca_ptc_returns", + "COUNT", + ), + "irs_soi.premium_tax_credit": ("aca_aptc_amount", "AMOUNT"), + "irs_soi.returns_with_earned_income_credit": ("eitc_claims", "COUNT"), + "irs_soi.earned_income_credit": ("eitc_amount", "AMOUNT"), + "irs_soi.total_earned_income_credit": ("eitc_amount", "AMOUNT"), + "irs_soi.returns_with_total_earned_income_credit": ("eitc_claims", "COUNT"), + "irs_soi.returns_with_child_tax_credit": ("ctc_claims", "COUNT"), + "irs_soi.child_tax_credit": ("ctc_amount", "AMOUNT"), + "irs_soi.returns_with_additional_child_tax_credit": ( + "actc_claims", + "COUNT", + ), + "irs_soi.additional_child_tax_credit": ("actc_amount", "AMOUNT"), + "irs_soi.returns_with_real_estate_taxes": ( + "real_estate_taxes_claims", + "COUNT", + ), + "irs_soi.real_estate_taxes": ("real_estate_taxes_amount", "AMOUNT"), + 
"irs_soi.returns_with_limited_state_local_taxes": ( + "limited_state_local_taxes_returns", + "COUNT", + ), + "irs_soi.limited_state_local_taxes": ( + "limited_state_local_taxes_amount", + "AMOUNT", + ), + "us:statutes/26/62#adjusted_gross_income": ( + "adjusted_gross_income", + "AMOUNT", + ), + "us:statutes/26/62#input.wages": ("wages_salaries_amount", "AMOUNT"), + "irs_soi.adjusted_gross_income": ("adjusted_gross_income", "AMOUNT"), + "irs_soi.total_income_tax": ("income_tax_liability", "AMOUNT"), + "irs_soi.total_wages": ("wages_salaries_amount", "AMOUNT"), + "irs_soi.returns_with_ordinary_dividends": ( + "ordinary_dividends_returns", + "COUNT", + ), + "irs_soi.ordinary_dividends": ("ordinary_dividends_amount", "AMOUNT"), + "irs_soi.returns_with_qualified_dividends": ( + "qualified_dividends_returns", + "COUNT", + ), + "irs_soi.qualified_dividends": ("qualified_dividends_amount", "AMOUNT"), + "irs_soi.returns_with_qualified_business_income_deduction": ( + "qbi_claims", + "COUNT", + ), + "irs_soi.qualified_business_income_deduction": ("qbi_amount", "AMOUNT"), + "irs_soi.returns_with_taxable_interest": ( + "taxable_interest_returns", + "COUNT", + ), + "irs_soi.taxable_interest": ("taxable_interest_amount", "AMOUNT"), + "irs_soi.returns_with_tax_exempt_interest": ( + "tax_exempt_interest_returns", + "COUNT", + ), + "irs_soi.tax_exempt_interest": ("tax_exempt_interest_amount", "AMOUNT"), + "irs_soi.returns_with_schedule_c_income": ( + "schedule_c_income_returns", + "COUNT", + ), + "irs_soi.schedule_c_income": ("schedule_c_income_amount", "AMOUNT"), + "irs_soi.taxable_net_capital_gains": ("net_capital_gains_amount", "AMOUNT"), + "irs_soi.returns_with_partnership_scorp_income": ( + "partnership_scorp_income_returns", + "COUNT", + ), + "irs_soi.partnership_scorp_income": ( + "partnership_scorp_income_amount", + "AMOUNT", + ), + "irs_soi.returns_with_rental_royalty_income": ( + "rental_royalty_income_returns", + "COUNT", + ), + "irs_soi.rental_royalty_income": ( + 
"rental_royalty_income_amount", + "AMOUNT", + ), + "irs_soi.taxable_ira_distributions": ( + "taxable_ira_distributions_amount", + "AMOUNT", + ), + "irs_soi.taxable_pension_income": ("taxable_pension_income_amount", "AMOUNT"), + "irs_soi.unemployment_compensation": ( + "unemployment_compensation_amount", + "AMOUNT", + ), + "irs_soi.taxable_social_security_benefits": ( + "taxable_social_security_amount", + "AMOUNT", + ), + "irs_soi.total_itemized_deductions": ("itemized_deductions", "AMOUNT"), + "irs_soi.returns_with_itemized_deductions": ( + "itemized_deductions_returns", + "COUNT", + ), + "irs_soi.returns_with_medical_dental_expense_deduction": ( + "medical_claims", + "COUNT", + ), + "irs_soi.medical_dental_expense_deduction": ( + "medical_dental_expense_amount", + "AMOUNT", + ), + "irs_soi.standard_deduction": ("standard_deduction", "AMOUNT"), + "irs_soi.taxable_income": ("taxable_income", "AMOUNT"), + "irs_soi.total_income": ("total_income", "AMOUNT"), + "irs_soi.returns_with_total_income": ("total_income_returns", "COUNT"), + "irs_soi.capital_asset_net_gain_less_loss": ( + "capital_asset_net_gain_less_loss", + "AMOUNT", + ), + "irs_soi.returns_with_capital_asset_net_gain_less_loss": ( + "capital_asset_net_gain_less_loss_returns", + "COUNT", + ), + "irs_soi.tax_credits": ("tax_credits", "AMOUNT"), + "irs_soi.returns_with_tax_credits": ("tax_credits_returns", "COUNT"), + "irs_soi.returns_with_taxable_income": ("taxable_income_returns", "COUNT"), + "irs_soi.returns_with_total_income_tax": ( + "income_tax_liability_returns", + "COUNT", + ), + "irs_soi.individual_income_tax_returns_excluding_dependents": ( + "tax_unit_count", + "COUNT", + ), + "irs_soi.eic_earned_income": ("eic_earned_income", "AMOUNT"), + "irs_soi.returns_with_eic_earned_income": ( + "eic_earned_income_returns", + "COUNT", + ), + "irs_soi.eic_refundable_portion": ("eitc_refundable_portion", "AMOUNT"), + "irs_soi.returns_with_eic_refundable_portion": ( + "eitc_refundable_portion_returns", + "COUNT", 
+ ), + "irs_soi.roth_ira_contributions": ("roth_ira_contributions", "AMOUNT"), + "irs_soi.roth_ira_contributors": ("roth_ira_contributors", "COUNT"), + "irs_soi.traditional_ira_contributions": ( + "traditional_ira_contributions", + "AMOUNT", + ), + "irs_soi.traditional_ira_contributors": ( + "traditional_ira_contributors", + "COUNT", + ), + "irs_soi.form_w2_social_security_tip_income": ("tip_income", "AMOUNT"), + "irs_soi.form_w2_social_security_tip_returns": ( + "tip_income_returns", + "COUNT", + ), + "irs_soi.form_w2_social_security_tip_taxpayers": ( + "tip_income_taxpayers", + "COUNT", + ), + "irs_soi.form_w2_401k_elective_deferrals": ( + "traditional_401k_contributions", + "AMOUNT", + ), + "irs_soi.form_w2_designated_roth_401k_contributions": ( + "roth_401k_contributions", + "AMOUNT", + ), + "irs_soi.payments_to_keogh_plan": ( + "self_employed_pension_contribution_ald", + "AMOUNT", + ), + "federal_reserve.z1.households_nonprofits_net_worth": ( + "net_worth", + "AMOUNT", + ), + "cms_medicare.part_b_premium_income": ( + "medicare_part_b_premiums", + "AMOUNT", + ), + "census_decennial.resident_population": ("population", "COUNT"), + "census_decennial.occupied_housing_units": ("household_count", "COUNT"), + "census_pep.resident_population": ("population", "COUNT"), + "census_stc.individual_income_tax_collections": ( + "state_individual_income_tax_collections", + "AMOUNT", + ), + "cms_aca.marketplace_effectuated_enrollment": ( + "aca_marketplace_enrollment", + "COUNT", + ), + "cms_aca.marketplace_plan_selections": ( + "aca_marketplace_plan_selections", + "COUNT", + ), + "cms_aca.aptc_consumers": ("aca_aptc_consumers", "COUNT"), + "cms_aca.average_monthly_aptc": ("aca_average_monthly_aptc", "RATE"), + "cms_medicaid.total_medicaid_enrollment": ( + "medicaid_total_enrollment", + "COUNT", + ), + "cms_medicaid.total_medicaid_chip_enrollment": ( + "medicaid_chip_total_enrollment", + "COUNT", + ), + "cms_medicaid.total_chip_enrollment": ("chip_total_enrollment", "COUNT"), 
+ "cms_medicaid.medicaid_chip_child_enrollment": ( + "medicaid_chip_child_enrollment", + "COUNT", + ), + "cms_medicaid.total_adult_medicaid_enrollment": ( + "adult_medicaid_enrollment", + "COUNT", + ), + "cms_nhe.medicaid_title_xix_expenditures": ( + "medicaid_benefits", + "AMOUNT", + ), + "hhs_acf_tanf.cash_assistance_expenditures": ( + "tanf_cash_assistance", + "AMOUNT", + ), + "hhs_acf_tanf.average_monthly_tanf_total_recipients": ( + "tanf_recipient_count", + "COUNT", + ), + "hhs_acf_tanf.average_monthly_tanf_adult_recipients": ( + "tanf_adult_recipient_count", + "COUNT", + ), + "hhs_acf_tanf.average_monthly_tanf_child_recipients": ( + "tanf_child_recipient_count", + "COUNT", + ), + "hhs_acf_tanf.average_monthly_tanf_total_families": ( + "tanf_family_count", + "COUNT", + ), + "hhs_acf_tanf.average_monthly_tanf_one_parent_families": ( + "tanf_one_parent_family_count", + "COUNT", + ), + "hhs_acf_tanf.average_monthly_tanf_two_parent_families": ( + "tanf_two_parent_family_count", + "COUNT", + ), + "hhs_acf_tanf.average_monthly_tanf_no_parent_families": ( + "tanf_no_parent_family_count", + "COUNT", + ), + "hhs_acf_liheap.households_served_by_state_programs": ( + "liheap_household_count", + "COUNT", + ), + "bea_nipa.wages_and_salaries": ("wages_salaries_amount", "AMOUNT"), + "bea_nipa.proprietors_income_with_inventory_valuation_and_capital_consumption_adjustments": ( + "proprietors_income_amount", + "AMOUNT", + ), + "bea_nipa.rental_income_of_persons_with_capital_consumption_adjustment": ( + "rental_income_amount", + "AMOUNT", + ), + "bea_nipa.personal_interest_income": ( + "personal_interest_income_amount", + "RATE", + ), + "bea_nipa.personal_dividend_income": ( + "personal_dividend_income_amount", + "AMOUNT", + ), + "bea_nipa.supplements_to_wages_and_salaries": ( + "supplements_to_wages_and_salaries", + "RATE", + ), + "bea_nipa.employer_contributions_for_employee_pension_and_insurance_funds": ( + "employer_pension_and_insurance_contributions", + "RATE", + ), + 
"bea_nipa.employer_contributions_for_government_social_insurance": ( + "employer_government_social_insurance_contributions", + "RATE", + ), + "bea_nipa.farm_proprietors_income": ("farm_proprietors_income", "RATE"), + "bea_nipa.nonfarm_proprietors_income": ("nonfarm_proprietors_income", "RATE"), + "bea_nipa.government_social_benefits_to_persons": ( + "government_social_benefits_to_persons", + "RATE", + ), + "bea_nipa.social_security_benefits": ("social_security_benefits", "AMOUNT"), + "bea_nipa.medicare_benefits": ("medicare_benefits", "RATE"), + "bea_nipa.medicaid_benefits": ("medicaid_benefits", "AMOUNT"), + "bea_nipa.unemployment_insurance_benefits": ( + "unemployment_insurance_benefits", + "AMOUNT", + ), + "bea_nipa.veterans_benefits": ("veterans_benefits", "RATE"), + "bea_nipa.other_government_social_benefits_to_persons": ( + "other_government_social_benefits_to_persons", + "RATE", + ), + "bea_nipa.other_current_transfer_receipts_from_business_net": ( + "other_current_transfer_receipts_from_business_net", + "RATE", + ), + "bea_nipa.personal_current_transfer_receipts": ( + "personal_current_transfer_receipts", + "RATE", + ), + "bea_nipa.personal_income": ("personal_income", "RATE"), + "bea_nipa.personal_current_taxes": ("personal_current_taxes", "RATE"), + "bea_nipa.disposable_personal_income": ("disposable_personal_income", "RATE"), + "bea_nipa.personal_outlays": ("personal_outlays", "RATE"), + "bea_nipa.personal_saving": ("personal_saving", "RATE"), + "bea_nipa.personal_saving_rate": ("personal_saving_rate", "RATE"), + "bea_regional.personal_income": ("regional_personal_income", "RATE"), + "bea_regional.dividends_interest_and_rent": ( + "regional_dividends_interest_and_rent", + "RATE", + ), + "bea_regional.personal_current_transfer_receipts": ( + "regional_personal_current_transfer_receipts", + "RATE", + ), + "bea_regional.wages_and_salaries": ("wages_salaries_amount", "AMOUNT"), + "bea_regional.supplements_to_wages_and_salaries": ( + 
"regional_supplements_to_wages_and_salaries", + "RATE", + ), + "bea_regional.proprietors_income": ("proprietors_income_amount", "AMOUNT"), + "usda_snap.total_benefits": ("snap_benefits", "AMOUNT"), + "usda_snap.average_monthly_households": ("snap_household_count", "COUNT"), + "usda_snap.average_monthly_persons": ("snap_participant_count", "COUNT"), + "usda_snap.average_monthly_benefit_per_person": ( + "snap_average_monthly_benefit_per_person", + "RATE", + ), +} + +ARCH_FACT_DOMAIN_CONSTRAINTS = { + "all_individual_income_tax_returns": (("is_tax_filer", "==", "1"),), + "form_w2_items": (), + "household_balance_sheet": (), + "individual_income_tax_returns": (("is_tax_filer", "==", "1"),), + "individual_income_tax_returns_excluding_dependents": ( + ("is_dependent", "==", "0"), + ), + "individual_income_tax_returns_with_earned_income_credit": (("eitc", ">", "0"),), + "individual_income_tax_returns_with_itemized_deductions": ( + ("itemized_deductions", ">", "0"), + ), + "individual_retirement_arrangement_contributions": (), + "compensation_of_employees": (), + "households": (), + "aca_marketplace_effectuated_enrollment": (), + "aca_marketplace_qhp_selections": (), + "medicaid_chip_enrollment": (), + "medicare_financing": (), + "national_health_expenditures": (), + "personal_current_transfer_receipts": (), + "personal_income": (), + "resident_population": (), + "social_security_and_ssi_payments": (), + "state_government_tax_collections": (), + "supplemental_nutrition_assistance_program": (("snap", "==", "1"),), + "tanf_cash_assistance": (), + "tanf_caseload": (), + "liheap_state_programs": (), +} + +ARCH_FACT_CONSTRAINT_VARIABLE_ALIASES = { + "age": "age", + "us.tax.earned_income_credit_qualifying_children": "eitc_child_count", + "us_social_security_and_ssi.program_payment_type": "program_payment_type", + "us:statutes/26/62#adjusted_gross_income": "adjusted_gross_income", + "irs_soi.adjusted_gross_income": "adjusted_gross_income", +} + 
# Immutable set of constraint-variable names.
# NOTE(review): the name suggests these are skipped when fact constraints are
# interpreted; the consumer is outside this view — confirm before relying on it.
ARCH_IGNORED_FACT_CONSTRAINT_VARIABLES = frozenset(
    {
        "administering_entity",
        "amount_basis",
        "bea_nipa.series_code",
        "bea_regional.geo_name",
        "bea_regional.line_code",
        "bea_regional.table_name",
        "medicare.financing_component",
        "medicare.part",
        "program",
    }
)

# Variable name -> EntityType member (TAX_UNIT, PERSON, HOUSEHOLD or SPM_UNIT).
# NOTE(review): presumably the default entity used when no override is given —
# confirm against the code that reads this table.
ARCH_ENTITY_HINTS = {
    "adjusted_gross_income": EntityType.TAX_UNIT,
    "income_tax": EntityType.TAX_UNIT,
    "income_tax_positive": EntityType.TAX_UNIT,
    "income_tax_before_credits": EntityType.TAX_UNIT,
    "eitc": EntityType.TAX_UNIT,
    "non_refundable_ctc": EntityType.TAX_UNIT,
    "refundable_ctc": EntityType.TAX_UNIT,
    "qualified_business_income_deduction": EntityType.TAX_UNIT,
    "salt": EntityType.TAX_UNIT,
    "salt_deduction": EntityType.TAX_UNIT,
    "charitable_deduction": EntityType.TAX_UNIT,
    "deductible_mortgage_interest": EntityType.TAX_UNIT,
    "interest_deduction": EntityType.TAX_UNIT,
    "investment_interest_expense": EntityType.PERSON,
    "medical_expense_deduction": EntityType.TAX_UNIT,
    "real_estate_taxes": EntityType.TAX_UNIT,
    "tax_unit_partnership_s_corp_income": EntityType.TAX_UNIT,
    "dividend_income": EntityType.PERSON,
    "employment_income": EntityType.PERSON,
    "qualified_dividend_income": EntityType.PERSON,
    "taxable_interest_income": EntityType.PERSON,
    "tax_exempt_interest_income": EntityType.PERSON,
    "long_term_capital_gains": EntityType.PERSON,
    "short_term_capital_gains": EntityType.PERSON,
    "proprietors_income_amount": EntityType.PERSON,
    "rental_income": EntityType.PERSON,
    "roth_401k_contributions": EntityType.PERSON,
    "self_employment_income": EntityType.PERSON,
    "self_employed_pension_contribution_ald": EntityType.TAX_UNIT,
    "salt_refund_income": EntityType.PERSON,
    "state_income_tax": EntityType.TAX_UNIT,
    "taxable_ira_distributions": EntityType.PERSON,
    "traditional_ira_contributions": EntityType.PERSON,
    "roth_ira_contributions": EntityType.PERSON,
    "taxable_pension_income": EntityType.PERSON,
    "taxable_social_security": EntityType.PERSON,
    "tip_income": EntityType.PERSON,
    "traditional_401k_contributions": EntityType.PERSON,
    "unemployment_compensation": EntityType.PERSON,
    "medicare_part_b_premiums": EntityType.PERSON,
    "medicaid": EntityType.PERSON,
    "net_worth": EntityType.HOUSEHOLD,
    "social_security": EntityType.PERSON,
    "social_security_dependents": EntityType.PERSON,
    "social_security_disability": EntityType.PERSON,
    "social_security_retirement": EntityType.PERSON,
    "social_security_survivors": EntityType.PERSON,
    "snap": EntityType.HOUSEHOLD,
    "ssi": EntityType.PERSON,
    "tanf": EntityType.SPM_UNIT,
}

# AGI bracket key -> (lower, upper) dollar bounds; None marks an open end.
# The keys are the same ones used by ARCH_AGI_BRACKET_LABELS.
# NOTE(review): whether each bound is inclusive or exclusive is decided by the
# consumer, which is not visible here.
ARCH_AGI_BRACKET_FILTERS = {
    "under_1": (None, 1),
    "1_to_10k": (1, 10_000),
    "10k_to_25k": (10_000, 25_000),
    "25k_to_50k": (25_000, 50_000),
    "50k_to_75k": (50_000, 75_000),
    "75k_to_100k": (75_000, 100_000),
    "100k_to_200k": (100_000, 200_000),
    "200k_to_500k": (200_000, 500_000),
    "500k_to_1m": (500_000, 1_000_000),
    "1m_plus": (1_000_000, None),
}

# Variables that ArchSQLiteTargetProvider.latest_soi_year restricts its
# IRS_SOI period lookup to (it passes them, sorted, as SQL parameters).
ARCH_CURRENT_TAX_VARIABLES = frozenset(
    {
        "tax_unit_count",
        "adjusted_gross_income",
        "income_tax_liability",
    }
)

# Lower-case word -> upper-case acronym spelling.
# NOTE(review): presumably applied word-by-word when building display labels —
# confirm against the label-building helper.
ARCH_LABEL_WORD_OVERRIDES = {
    "aca": "ACA",
    "actc": "ACTC",
    "agi": "AGI",
    "bls": "BLS",
    "cbo": "CBO",
    "cms": "CMS",
    "ctc": "CTC",
    "eitc": "EITC",
    "irs": "IRS",
    "qbi": "QBI",
    "liheap": "LIHEAP",
    "snap": "SNAP",
    "soi": "SOI",
    "ssi": "SSI",
    "tanf": "TANF",
    "usda": "USDA",
}

# Variable name -> complete human-readable label.
# Parenthesised values are plain strings (the parentheses only wrap long lines).
ARCH_VARIABLE_LABEL_OVERRIDES = {
    "adjusted_gross_income": "Adjusted gross income",
    "income_tax_liability": "Income tax liability",
    "income_tax_liability_returns": "Returns with income tax after credits",
    "income_tax_before_credits_returns": ("Returns with income tax before credits"),
    "income_tax_before_credits_amount": "Income tax before credits amount",
    "tax_filer_individual_count": "Individuals on tax returns",
    "aca_ptc_returns": "Returns with premium tax credit",
    "aca_aptc_amount": "Premium tax credit amount",
    "eitc_claims": "Returns with earned income credit",
    "eitc_amount": "Earned income credit amount",
    "real_estate_taxes_claims": "Returns with real estate taxes",
    "real_estate_taxes_amount": "Real estate taxes amount",
    "limited_state_local_taxes_returns": ("Returns with limited state and local taxes"),
    "tax_exempt_interest_returns": "Tax-exempt interest returns",
    "tax_exempt_interest_amount": "Tax-exempt interest amount",
    "taxable_interest_amount": "Taxable interest amount",
    "wages_salaries_returns": "Returns with total wages",
    "wages_salaries_amount": "Total wages amount",
    "personal_dividend_income_amount": "Personal dividend income amount",
    "proprietors_income_amount": "Proprietors' income amount",
    "rental_income_amount": "Rental income amount",
    "net_capital_gains_returns": "Returns with taxable net capital gains",
    "net_capital_gains_amount": "Taxable net capital gains amount",
    "taxable_ira_distributions_returns": ("Returns with taxable IRA distributions"),
    "taxable_ira_distributions_amount": "Taxable IRA distributions amount",
    "taxable_pension_income_returns": "Returns with taxable pension income",
    "taxable_pension_income_amount": "Taxable pension income amount",
    "unemployment_compensation_returns": ("Returns with unemployment compensation"),
    "unemployment_compensation_amount": "Unemployment compensation amount",
    "unemployment_insurance_benefits": "Unemployment insurance benefits",
    "taxable_social_security_returns": (
        "Returns with taxable Social Security benefits"
    ),
    "taxable_social_security_amount": "Taxable Social Security benefits amount",
    "ordinary_dividends_amount": "Ordinary dividends amount",
    "qualified_dividends_returns": "Returns with qualified dividends",
    "qualified_dividends_amount": "Qualified dividends amount",
    "long_term_capital_gains_amount": "Long-term capital gains amount",
    "short_term_capital_gains_amount": "Short-term capital gains amount",
    "partnership_scorp_income_returns": "Returns with partnership and S-corp income",
    "partnership_scorp_income_amount": "Partnership and S-corp income amount",
    "schedule_c_income_returns": "Returns with Schedule C income",
    "schedule_c_income_amount": "Schedule C income amount",
    "medical_claims": "Returns with medical expense deduction",
    "medical_dental_expense_amount": "Medical and dental expense amount",
    "tax_unit_count": "Tax unit count",
    "household_count": "Household count",
    "population": "Population count",
    "snap_household_count": "SNAP household count",
    "snap_participant_count": "SNAP participant count",
    "aca_marketplace_enrollment": "ACA marketplace enrollment",
    "state_individual_income_tax_collections": (
        "State individual income tax collections"
    ),
    "limited_state_local_taxes_amount": "Limited state and local taxes amount",
    "interest_paid_deduction_amount": "Interest paid deduction amount",
    "mortgage_interest_paid_amount": "Mortgage interest paid amount",
    "home_mortgage_personal_seller_amount": (
        "Home mortgage from personal seller amount"
    ),
    "deductible_points_amount": "Deductible points amount",
    "investment_interest_paid_amount": "Investment interest paid amount",
    "medicaid_benefits": "Medicaid benefits",
    "medicaid_total_enrollment": "Medicaid enrollment",
    "medicaid_enrollment": "Medicaid enrollment",
    "liheap_household_count": "LIHEAP household count",
    "social_security_benefits": "Social Security benefits",
    "social_security_dependents_benefits": "Social Security dependent benefits",
    "social_security_disability_benefits": "Social Security disability benefits",
    "social_security_retirement_benefits": "Social Security retirement benefits",
    "social_security_survivors_benefits": "Social Security survivor benefits",
    "ssi_payments": "SSI payments",
    "tanf_cash_assistance": "TANF cash assistance",
    "tanf_family_count": "TANF family count",
    "tanf_recipient_count": "TANF recipient count",
    "tip_income": "Tip income",
    "traditional_401k_contributions": "Traditional 401(k) contributions",
    "traditional_ira_contributions": "Traditional IRA contributions",
    "roth_401k_contributions": "Roth 401(k) contributions",
    "roth_ira_contributions": "Roth IRA contributions",
    "self_employed_pension_contribution_ald": (
        "Self-employed pension contribution ALD"
    ),
}
# AGI bracket key -> short display label; same keys as ARCH_AGI_BRACKET_FILTERS.
ARCH_AGI_BRACKET_LABELS = {
    "under_1": "under $1",
    "1_to_10k": "$1-$10k",
    "10k_to_25k": "$10k-$25k",
    "25k_to_50k": "$25k-$50k",
    "50k_to_75k": "$50k-$75k",
    "75k_to_100k": "$75k-$100k",
    "100k_to_200k": "$100k-$200k",
    "200k_to_500k": "$200k-$500k",
    "500k_to_1m": "$500k-$1m",
    "1m_plus": "$1m+",
}

# Model variable -> Arch amount variable.
# Seeded by inverting ARCH_AMOUNT_VARIABLE_ALIASES (source -> model becomes
# model -> source); the explicit entries that follow win on key collisions
# because later keys in a dict display replace earlier ones.
ARCH_MODEL_AMOUNT_VARIABLE_HINTS = {
    **{
        model_variable: source_variable
        for source_variable, model_variable in ARCH_AMOUNT_VARIABLE_ALIASES.items()
    },
    "employment_income": "wages_salaries_amount",
    "income_tax_positive": "income_tax_liability",
    "income_tax_before_credits": "income_tax_before_credits_amount",
    "interest_deduction": "interest_paid_deduction_amount",
    "medicare_part_b_premiums": "medicare_part_b_premiums",
    "net_capital_gains": "net_capital_gains_amount",
    "net_worth": "net_worth",
    "real_estate_taxes": "real_estate_taxes_amount",
    "roth_401k_contributions": "roth_401k_contributions",
    "self_employed_pension_contribution_ald": (
        "self_employed_pension_contribution_ald"
    ),
    "total_self_employment_income": "schedule_c_income_amount",
    "taxable_ira_distributions": "taxable_ira_distributions_amount",
    "taxable_pension_income": "taxable_pension_income_amount",
    "taxable_social_security": "taxable_social_security_amount",
    "tip_income": "tip_income",
    "traditional_401k_contributions": "traditional_401k_contributions",
    "unemployment_compensation": "unemployment_compensation_amount",
}

# Model (domain) variable -> Arch count variable (returns / claims / unit counts).
ARCH_MODEL_COUNT_DOMAIN_VARIABLE_HINTS = {
    "adjusted_gross_income": "tax_unit_count",
    "dividend_income": "ordinary_dividends_returns",
    "employment_income": "wages_salaries_returns",
    "eitc": "eitc_claims",
    "income_tax": "income_tax_liability_returns",
    "income_tax_before_credits": "income_tax_before_credits_returns",
    "medical_expense_deduction": "medical_claims",
    "net_capital_gains": "net_capital_gains_returns",
    "non_refundable_ctc": "ctc_claims",
    "qualified_business_income_deduction": "qbi_claims",
    "qualified_dividend_income": "qualified_dividends_returns",
    "real_estate_taxes": "real_estate_taxes_claims",
    "refundable_ctc": "actc_claims",
    "rental_income": "rental_royalty_income_returns",
    "salt": "salt_claims",
    "self_employment_income": "schedule_c_income_returns",
    "total_self_employment_income": "schedule_c_income_returns",
    "tax_exempt_interest_income": "tax_exempt_interest_returns",
    "tax_unit_partnership_s_corp_income": "partnership_scorp_income_returns",
    "taxable_interest_income": "taxable_interest_returns",
    "taxable_ira_distributions": "taxable_ira_distributions_returns",
    "taxable_pension_income": "taxable_pension_income_returns",
    "taxable_social_security": "taxable_social_security_returns",
    "unemployment_compensation": "unemployment_compensation_returns",
}

# Model variables listed here are exactly the keys of
# ARCH_BEA_FULL_POP_AMOUNT_ARCH_VARIABLES below; keep the two in sync.
ARCH_BEA_FULL_POP_AMOUNT_VARIABLES = frozenset(
    {
        "dividend_income",
        "employment_income",
        "rental_income",
        "unemployment_compensation",
    }
)

# Model variable -> Arch amount variable for the BEA full-population case.
ARCH_BEA_FULL_POP_AMOUNT_ARCH_VARIABLES = {
    "dividend_income": "personal_dividend_income_amount",
    "employment_income": "wages_salaries_amount",
    "rental_income": "rental_income_amount",
    "unemployment_compensation": "unemployment_insurance_benefits",
}

# Union of the KEYS of both hint dicts (star-unpacking a dict yields its keys)
# plus the names listed explicitly. "income_tax_positive", "interest_deduction"
# and "tip_income" are already keys of ARCH_MODEL_AMOUNT_VARIABLE_HINTS, so
# repeating them is harmless but redundant.
ARCH_IRS_SOI_GAP_VARIABLES = frozenset(
    {
        *ARCH_MODEL_AMOUNT_VARIABLE_HINTS,
        *ARCH_MODEL_COUNT_DOMAIN_VARIABLE_HINTS,
        "income_tax_positive",
        "interest_deduction",
        "roth_ira_contributions",
        "tax_unit_count",
        "tip_income",
        "traditional_ira_contributions",
    }
)

# NOTE(review): by name, variables whose gaps are deprioritised because they
# come from surveys or model output — confirm with the gap-queue builder.
ARCH_DEPRIORITIZED_SURVEY_OR_MODEL_GAP_VARIABLES = frozenset(
    {
        "alimony_expense",
        "child_support_expense",
        "child_support_received",
        "health_insurance_premiums_without_medicare_part_b",
        "other_medical_expenses",
        "over_the_counter_health_expenses",
        "rent",
        "spm_unit_capped_housing_subsidy",
        "spm_unit_capped_work_childcare_expenses",
    }
)

# Domain-variable counterpart of the set above (same review caveat applies).
ARCH_DEPRIORITIZED_SURVEY_OR_MODEL_GAP_DOMAINS = frozenset(
    {
        "ssn_card_type",
    }
)

# Variable name -> free-text description of the publisher table expected to
# supply it. Keys mix Arch variable names (e.g. "eitc_claims") and model
# variable names (e.g. "employment_income").
ARCH_GAP_SOURCE_TABLE_HINTS = {
    "aca_aptc_amount": "CMS Marketplace Open Enrollment public-use files",
    "aca_marketplace_enrollment": "CMS Marketplace Open Enrollment public-use files",
    "employment_income": "IRS SOI Publication 1304 Table 1.4",
    "aca_ptc_returns": "IRS SOI Historic Table 2",
    "eitc_amount": "IRS SOI Historic Table 2",
    "eitc_claims": "IRS SOI Historic Table 2",
    "income_tax_liability": "IRS SOI Publication 1304 Table 1.1 or Historic Table 2",
    "income_tax_before_credits": "IRS SOI Publication 1304 Table 1.1",
    "income_tax_before_credits_returns": "IRS SOI Historic Table 2",
    "tax_filer_individual_count": "IRS SOI Historic Table 2",
    "interest_paid_deduction_amount": "IRS SOI Historic Table 2",
    "limited_state_local_taxes_amount": "IRS SOI Historic Table 2",
    "liheap_household_count": "HHS ACF LIHEAP National Profile",
    "medicaid_benefits": (
        "CMS National Health Expenditures by type of service and source of funds"
    ),
    "net_capital_gains": "IRS SOI Publication 1304 Table 1.4",
    "population": "Census Population Estimates Program Vintage 2024 age-sex files",
    "real_estate_taxes": "IRS SOI itemized deduction tables or ACS state files",
    "roth_ira_contributions": "IRS SOI IRA contribution tables",
    "roth_401k_contributions": "IRS SOI Form W-2 Statistics Table 4.B",
    "self_employed_pension_contribution_ald": "IRS SOI Publication 1304 Table 1.4",
    "state_individual_income_tax_collections": (
        "Census State Tax Collections item T40"
    ),
    "social_security_benefits": "SSA Annual Statistical Supplement",
    "social_security_dependents_benefits": "SSA Annual Statistical Supplement",
    "social_security_disability_benefits": "SSA Annual Statistical Supplement",
    "social_security_retirement_benefits": "SSA Annual Statistical Supplement",
    "social_security_survivors_benefits": "SSA Annual Statistical Supplement",
    "snap_benefits": "USDA FNS SNAP annual state participation and benefit workbooks",
    "snap_household_count": (
        "USDA FNS SNAP annual state participation and benefit workbooks"
    ),
    "snap_participant_count": (
        "USDA FNS SNAP annual state participation and benefit workbooks"
    ),
    "ssi_payments": "SSA Annual Statistical Supplement",
    "tanf_cash_assistance": "ACF TANF Financial Data",
    "tanf_family_count": "ACF TANF Caseload Data",
    "tanf_recipient_count": "ACF TANF Caseload Data",
    "tip_income": "IRS SOI Form W-2 Statistics",
    "traditional_ira_contributions": "IRS SOI IRA contribution tables",
    "traditional_401k_contributions": "IRS SOI Form W-2 Statistics Table 4.B",
    "taxable_ira_distributions": "IRS SOI IRA accumulation/distribution tables",
    "taxable_pension_income": "IRS SOI Publication 1304 Table 1.4",
    "taxable_social_security": "IRS SOI Publication 1304 Table 1.4",
    "unemployment_compensation": "IRS SOI Publication 1304 Table 1.4",
}


@dataclass(frozen=True)
class SOIAgingFactors:
    """Declared factors used to age SOI target records to a model year.

    Built by ``ArchSQLiteTargetProvider.get_soi_aging_factors`` and attached to
    each aged record by ``age_soi_records``.
    """

    # Year the SOI records were published for / year they are aged to.
    source_year: int
    target_year: int
    # Multipliers applied to COUNT and AMOUNT record values respectively.
    count_factor: float
    amount_factor: float
    # Provenance tags for each factor; values visible in this file include
    # "identity", "not_required", "bls_labor_force_ratio",
    # "cbo_labor_force_ratio", "soi_total_agi_ratio" and
    # "soi_total_agi_last_growth_extrapolation".
    count_method: str
    amount_method: str


@dataclass(frozen=True)
class ArchTargetRecord:
    """A source target record loaded from the Arch SQLite DB."""

    target_id: int
    stratum_id: int
    variable: str
    period: int
    value: float
    # Compared against "COUNT" and "AMOUNT" when SOI records are aged.
    target_type: str
    geographic_level: str | None
    geography_id: str | None
    source: str
    source_table: str | None
    source_url: str | None
    notes: str | None
    stratum_name: str | None
    jurisdiction: str
    # (variable, operator, value) triples describing the record's stratum.
    constraints: tuple[tuple[str, str, str], ...]
    # Set by age_soi_records: the record's original period and the factors that
    # were applied when its value was moved to the model year.
    source_period: int | None = None
    aging_factors: SOIAgingFactors | None = None
    # NOTE(review): the remaining optional fields look like fact-lineage and
    # concept metadata filled by the aggregate-facts loader — confirm there.
    aggregate_fact_key: str | None = None
    semantic_fact_key: str | None = None
    source_record_id: str | None = None
    source_cell_keys: tuple[str, ...] = ()
    source_row_keys: tuple[str, ...] = ()
    unit: str | None = None
    concept: str | None = None
    source_concept: str | None = None
    concept_relation: str | None = None
    concept_authority: str | None = None
    concept_evidence_url: str | None = None
    concept_evidence_notes: str | None = None
    legal_vintage: str | None = None
    source_db_path: str | None = None
    source_db_index: int | None = None
    source_target_id: int | None = None
    source_stratum_id: int | None = None
@dataclass(frozen=True)
class ArchTargetCellCoverage:
    """Coverage for one PolicyEngine target cell from an Arch target DB.

    A cell counts as covered when at least one target id is attached to it.
    """

    cell: dict[str, str | None]
    target_ids: tuple[int, ...]
    target_names: tuple[str, ...]
    sources: tuple[str, ...]

    @property
    def covered(self) -> bool:
        """True when at least one Arch target matched the cell."""
        return bool(self.target_ids)

    @property
    def target_count(self) -> int:
        """Number of Arch targets matched to the cell."""
        return len(self.target_ids)

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-ready payload built from copies of the stored values."""
        return {
            "cell": dict(self.cell),
            "covered": self.covered,
            "target_count": self.target_count,
            "target_ids": list(self.target_ids),
            "target_names": list(self.target_names),
            "sources": list(self.sources),
        }


@dataclass(frozen=True)
class ArchTargetProfileCoverageReport:
    """JSON-ready summary of Arch coverage for a Microplex target profile."""

    profile_name: str
    period: int
    target_cell_count: int
    covered_cell_count: int
    uncovered_cell_count: int
    coverage_rate: float
    by_geo_level: dict[str, dict[str, int]]
    by_variable: dict[str, dict[str, int]]
    cells: tuple[ArchTargetCellCoverage, ...]

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-ready payload that shares no mutable state with the report.

        The nested summary dicts are copied (outer and inner level) so that a
        caller editing the payload cannot alter this frozen report, matching
        what ``ArchTargetCellCoverage.to_dict`` already does for its fields.
        """
        return {
            "profile_name": self.profile_name,
            "period": self.period,
            "target_cell_count": self.target_cell_count,
            "covered_cell_count": self.covered_cell_count,
            "uncovered_cell_count": self.uncovered_cell_count,
            "coverage_rate": self.coverage_rate,
            "by_geo_level": {
                level: dict(summary) for level, summary in self.by_geo_level.items()
            },
            "by_variable": {
                variable: dict(summary)
                for variable, summary in self.by_variable.items()
            },
            "cells": [cell.to_dict() for cell in self.cells],
        }
@dataclass(frozen=True)
class ArchTargetGapQueueRow:
    """One target-profile cell as an Arch authoring task."""

    priority: int
    profile_name: str
    period: int
    variable: str
    geo_level: str | None
    domain_variable: str | None
    geographic_id: str | None
    covered: bool
    target_count: int
    target_ids: tuple[int, ...]
    sources: tuple[str, ...]
    expected_source: str | None
    expected_source_table: str | None
    expected_arch_variable: str | None
    expected_target_type: str | None
    expected_entity: str | None
    expected_aggregation: str | None
    expected_filters: tuple[dict[str, Any], ...]
    gap_category: str
    loader_status: str
    agent_task_kind: str
    notes: str

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-ready payload; the four cell keys are nested under "cell".

        Each filter dict is copied so that editing the payload cannot change
        the filters held by this frozen row.
        """
        return {
            "priority": self.priority,
            "profile_name": self.profile_name,
            "period": self.period,
            "cell": {
                "variable": self.variable,
                "geo_level": self.geo_level,
                "domain_variable": self.domain_variable,
                "geographic_id": self.geographic_id,
            },
            "covered": self.covered,
            "target_count": self.target_count,
            "target_ids": list(self.target_ids),
            "sources": list(self.sources),
            "expected_source": self.expected_source,
            "expected_source_table": self.expected_source_table,
            "expected_arch_variable": self.expected_arch_variable,
            "expected_target_type": self.expected_target_type,
            "expected_entity": self.expected_entity,
            "expected_aggregation": self.expected_aggregation,
            "expected_filters": [
                dict(filter_spec) for filter_spec in self.expected_filters
            ],
            "gap_category": self.gap_category,
            "loader_status": self.loader_status,
            "agent_task_kind": self.agent_task_kind,
            "notes": self.notes,
        }


@dataclass(frozen=True)
class ArchTargetGapQueueReport:
    """JSON-ready Arch authoring queue for a Microplex target profile."""

    profile_name: str
    period: int
    row_count: int
    covered_row_count: int
    uncovered_row_count: int
    by_loader_status: dict[str, int]
    by_gap_category: dict[str, int]
    rows: tuple[ArchTargetGapQueueRow, ...]

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-ready payload with copied tallies and serialised rows."""
        return {
            "profile_name": self.profile_name,
            "period": self.period,
            "row_count": self.row_count,
            "covered_row_count": self.covered_row_count,
            "uncovered_row_count": self.uncovered_row_count,
            # Copies keep payload edits from leaking back into the report.
            "by_loader_status": dict(self.by_loader_status),
            "by_gap_category": dict(self.by_gap_category),
            "rows": [row.to_dict() for row in self.rows],
        }
@dataclass(frozen=True)
class ArchTargetParityRow:
    """One canonical target identity compared across two Arch artifacts."""

    status: str
    identity: tuple[Any, ...]
    incumbent_targets: tuple[CanonicalTargetSpec, ...]
    candidate_targets: tuple[CanonicalTargetSpec, ...]
    absolute_delta: float | None
    relative_delta: float | None

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-ready payload for this comparison row.

        The identity tuple is expanded by ``_arch_target_parity_identity_dict``
        and each target is reduced by ``_target_parity_sample``; both helpers
        live elsewhere in this module.
        """
        return {
            "status": self.status,
            "identity": _arch_target_parity_identity_dict(self.identity),
            "incumbent_target_count": len(self.incumbent_targets),
            "candidate_target_count": len(self.candidate_targets),
            "absolute_delta": self.absolute_delta,
            "relative_delta": self.relative_delta,
            "incumbent_targets": [
                _target_parity_sample(target) for target in self.incumbent_targets
            ],
            "candidate_targets": [
                _target_parity_sample(target) for target in self.candidate_targets
            ],
        }


@dataclass(frozen=True)
class ArchTargetParityReport:
    """JSON-ready parity report between incumbent and candidate Arch artifacts."""

    period: int
    incumbent_artifacts: tuple[str, ...]
    candidate_artifacts: tuple[str, ...]
    value_abs_tolerance: float
    value_rel_tolerance: float
    counts: dict[str, int]
    rows: tuple[ArchTargetParityRow, ...]
    errors: tuple[dict[str, Any], ...]

    @property
    def valid(self) -> bool:
        """True when no errors were recorded."""
        return not self.errors

    def to_dict(self, *, row_limit: int | None = None) -> dict[str, Any]:
        """Return a JSON-ready payload, optionally truncating the row list.

        Args:
            row_limit: Maximum number of rows to serialise. ``None`` keeps all
                rows; negative values are treated as zero. ``"row_count"``
                always reports the full, untruncated number of rows.

        The ``counts`` dict and each error dict are copied so that editing the
        payload cannot change this frozen report.
        """
        rows = self.rows if row_limit is None else self.rows[: max(0, row_limit)]
        return {
            "valid": self.valid,
            "period": self.period,
            "incumbent_artifacts": list(self.incumbent_artifacts),
            "candidate_artifacts": list(self.candidate_artifacts),
            "value_abs_tolerance": self.value_abs_tolerance,
            "value_rel_tolerance": self.value_rel_tolerance,
            "counts": dict(self.counts),
            "row_count": len(self.rows),
            "rows": [row.to_dict() for row in rows],
            "errors": [dict(error) for error in self.errors],
        }
class ArchSQLiteTargetProvider:
    """Read Arch target records from the Arch SQLite DB.

    Records come from the ``targets``, ``strata`` and ``stratum_constraints``
    tables. When model-year composition is enabled, IRS_SOI records from the
    latest available year are optionally aged to the requested year and
    combined with the non-SOI records of that year.
    """

    def __init__(
        self,
        db_path: str | Path,
        *,
        jurisdiction: str = "us",
        compose_model_year_targets: bool = True,
        age_soi_targets: bool = True,
    ) -> None:
        """Store the DB location and the defaults used by ``load_target_set``.

        Args:
            db_path: Location of the SQLite file; existence is checked lazily.
            jurisdiction: Default jurisdiction when a query does not name one.
            compose_model_year_targets: Default for model-year composition.
            age_soi_targets: Default for ageing SOI records during composition.
        """
        self.db_path = Path(db_path)
        self.jurisdiction = jurisdiction
        self.compose_model_year_targets = compose_model_year_targets
        self.age_soi_targets = age_soi_targets

    def load_target_set(self, query: TargetQuery | None = None) -> TargetSet:
        """Load canonical targets through the core provider protocol.

        Provider filters read from ``query.provider_filters``: ``jurisdiction``,
        ``variables``, ``domain_variables``, ``sources``, ``geo_levels``,
        ``target_cells``, ``compose_model_year_targets``, ``age_soi_targets``
        and ``entity_overrides``.

        Raises:
            FileNotFoundError: If the configured DB file does not exist.
        """
        if not self.db_path.exists():
            raise FileNotFoundError(f"Arch targets DB not found: {self.db_path}")

        query = query or TargetQuery()
        provider_filters = dict(query.provider_filters)
        # Only integer periods are pushed down to SQL; anything else loads all
        # periods and leaves period filtering to apply_target_query below.
        period = query.period if isinstance(query.period, int) else None
        jurisdiction = str(provider_filters.get("jurisdiction") or self.jurisdiction)
        variables = _as_string_tuple(provider_filters.get("variables"))
        domain_variables = _as_string_tuple(provider_filters.get("domain_variables"))
        sources = _as_string_tuple(provider_filters.get("sources"))
        geo_levels = _as_string_tuple(provider_filters.get("geo_levels"))
        target_cells = _as_target_cell_filters(provider_filters.get("target_cells"))
        compose_model_year_targets = bool(
            provider_filters.get(
                "compose_model_year_targets",
                self.compose_model_year_targets,
            )
        )
        age_soi_targets = bool(
            provider_filters.get("age_soi_targets", self.age_soi_targets)
        )
        entity_overrides = provider_filters.get("entity_overrides") or {}

        # Composition needs a concrete target year, so it is skipped when the
        # query has no integer period.
        records = (
            self._compose_model_year_records(
                target_year=period,
                jurisdiction=jurisdiction,
                sources=sources,
                age_soi_targets=age_soi_targets,
            )
            if compose_model_year_targets and period is not None
            else self.load_records(
                period=period,
                jurisdiction=jurisdiction,
                sources=sources,
            )
        )
        canonical_targets = TargetSet(
            [
                target
                for record in records
                if _matches_arch_provider_filters(
                    record,
                    variables=variables,
                    domain_variables=domain_variables,
                    geo_levels=geo_levels,
                    target_cells=target_cells,
                    entity_overrides=entity_overrides,
                )
                # Single-element loop binds the conversion result so records
                # that convert to None can be dropped in the same pass.
                for target in [
                    arch_target_record_to_canonical_spec(
                        record,
                        entity_overrides=entity_overrides,
                    )
                ]
                if target is not None
            ]
        )
        return apply_target_query(
            canonical_targets,
            TargetQuery(
                period=query.period,
                entity=query.entity,
                names=query.names,
                metadata_filters=query.metadata_filters,
            ),
        )

    def load_records(
        self,
        *,
        period: int | None = None,
        jurisdiction: str | None = None,
        sources: tuple[str, ...] = (),
    ) -> list[ArchTargetRecord]:
        """Load source target records with attached stratum constraints.

        When ``strata`` has a ``parent_id`` column, constraints of every
        ancestor stratum are joined in as well, ordered root-first.

        Args:
            period: Restrict to this target period; ``None`` loads all periods.
            jurisdiction: Overrides the provider default when given.
            sources: Source names (normalised before use); empty means all.
        """
        jurisdiction = jurisdiction or self.jurisdiction
        normalized_sources = tuple(_normalize_arch_source(source) for source in sources)
        clauses = [_jurisdiction_clause(jurisdiction)]
        params: list[Any] = []
        if period is not None:
            clauses.append("t.period = ?")
            params.append(int(period))
        if normalized_sources:
            placeholders = ", ".join("?" for _ in normalized_sources)
            clauses.append(f"t.source IN ({placeholders})")
            params.extend(normalized_sources)
        where_clause = " AND ".join(clauses)
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        try:
            if _sqlite_table_has_column(conn, "strata", "parent_id"):
                # Walk each target's stratum up through parent_id so inherited
                # constraints are returned too (depth 0 is the target's own
                # stratum; ORDER BY depth DESC lists the outermost ancestor first).
                # NOTE(review): a parent_id cycle would make this recursion
                # unbounded — presumably the schema prevents that; confirm.
                sql = f"""
                    WITH target_rows AS (
                        SELECT
                            t.id AS target_id,
                            t.stratum_id,
                            t.variable,
                            t.period,
                            t.value,
                            t.target_type,
                            t.geographic_level,
                            t.source,
                            t.source_table,
                            t.source_url,
                            t.notes,
                            s.name AS stratum_name,
                            s.jurisdiction,
                            s.parent_id
                        FROM targets AS t
                        JOIN strata AS s
                            ON s.id = t.stratum_id
                        WHERE {where_clause}
                    ),
                    ancestor_strata(target_id, stratum_id, depth) AS (
                        SELECT
                            target_id,
                            stratum_id,
                            0 AS depth
                        FROM target_rows
                        UNION ALL
                        SELECT
                            a.target_id,
                            parent.id AS stratum_id,
                            a.depth + 1 AS depth
                        FROM ancestor_strata AS a
                        JOIN strata AS child
                            ON child.id = a.stratum_id
                        JOIN strata AS parent
                            ON parent.id = child.parent_id
                        WHERE child.parent_id IS NOT NULL
                    )
                    SELECT
                        tr.target_id,
                        tr.stratum_id,
                        tr.variable,
                        tr.period,
                        tr.value,
                        tr.target_type,
                        tr.geographic_level,
                        tr.source,
                        tr.source_table,
                        tr.source_url,
                        tr.notes,
                        tr.stratum_name,
                        tr.jurisdiction,
                        sc.variable AS constraint_variable,
                        sc.operator AS constraint_operator,
                        sc.value AS constraint_value
                    FROM target_rows AS tr
                    LEFT JOIN ancestor_strata AS a
                        ON a.target_id = tr.target_id
                    LEFT JOIN stratum_constraints AS sc
                        ON sc.stratum_id = a.stratum_id
                    ORDER BY
                        tr.target_id,
                        a.depth DESC,
                        sc.variable,
                        sc.operator,
                        sc.value
                """
            else:
                # Flat schema: only the target's own stratum carries constraints.
                sql = f"""
                    SELECT
                        t.id AS target_id,
                        t.stratum_id,
                        t.variable,
                        t.period,
                        t.value,
                        t.target_type,
                        t.geographic_level,
                        t.source,
                        t.source_table,
                        t.source_url,
                        t.notes,
                        s.name AS stratum_name,
                        s.jurisdiction,
                        sc.variable AS constraint_variable,
                        sc.operator AS constraint_operator,
                        sc.value AS constraint_value
                    FROM targets AS t
                    JOIN strata AS s
                        ON s.id = t.stratum_id
                    LEFT JOIN stratum_constraints AS sc
                        ON sc.stratum_id = s.id
                    WHERE {where_clause}
                    ORDER BY t.id, sc.variable, sc.operator, sc.value
                """
            rows = conn.execute(sql, params).fetchall()
        finally:
            conn.close()
        return _group_arch_target_rows(rows)

    def _compose_model_year_records(
        self,
        *,
        target_year: int,
        jurisdiction: str,
        sources: tuple[str, ...],
        age_soi_targets: bool,
    ) -> list[ArchTargetRecord]:
        """Combine target-year non-SOI records with the latest SOI records.

        If ``sources`` is non-empty and does not include IRS_SOI, only the
        target-year records are used. The result always passes through
        ``_with_state_to_national_rollup_records``.
        """
        current_records = self.load_records(
            period=target_year,
            jurisdiction=jurisdiction,
            sources=sources,
        )
        if sources and _normalize_arch_source("IRS_SOI") not in {
            _normalize_arch_source(source) for source in sources
        }:
            return _with_state_to_national_rollup_records(current_records)

        # SOI records of the target year are re-selected below together with
        # older vintages, so drop them here to avoid duplicates.
        non_soi_current_records = [
            record for record in current_records if record.source != "IRS_SOI"
        ]
        soi_records = self._latest_soi_records_by_composition(
            target_year=target_year,
            jurisdiction=jurisdiction,
        )
        if age_soi_targets:
            soi_records = self._age_soi_records_by_source_year(
                soi_records,
                target_year=target_year,
                jurisdiction=jurisdiction,
            )
        return _with_state_to_national_rollup_records(
            [*non_soi_current_records, *soi_records]
        )

    def _latest_soi_records_by_composition(
        self,
        *,
        target_year: int,
        jurisdiction: str,
    ) -> list[ArchTargetRecord]:
        """Return the latest SOI records for each target composition.

        Only records with ``period <= target_year`` are considered; for every
        composition key the records of its most recent period are kept.
        """
        records = [
            record
            for record in self.load_records(
                period=None,
                jurisdiction=jurisdiction,
                sources=("IRS_SOI",),
            )
            if record.period <= target_year
        ]
        latest_period_by_key: dict[
            tuple[str, str, str, tuple[tuple[str, str, str], ...]],
            int,
        ] = {}
        for record in records:
            key = _arch_record_composition_key(record)
            latest_period_by_key[key] = max(
                latest_period_by_key.get(key, record.period),
                record.period,
            )
        return [
            record
            for record in records
            if record.period
            == latest_period_by_key[_arch_record_composition_key(record)]
        ]

    def _age_soi_records_by_source_year(
        self,
        records: list[ArchTargetRecord],
        *,
        target_year: int,
        jurisdiction: str,
    ) -> list[ArchTargetRecord]:
        """Age records one source year at a time, in ascending year order.

        Records already at ``target_year`` are passed through untouched.
        """
        aged: list[ArchTargetRecord] = []
        source_years = sorted({record.period for record in records})
        for source_year in source_years:
            source_records = [
                record for record in records if record.period == source_year
            ]
            if source_year == target_year:
                aged.extend(source_records)
            else:
                aged.extend(
                    self.age_soi_records(
                        source_records,
                        source_year=source_year,
                        target_year=target_year,
                        jurisdiction=jurisdiction,
                    )
                )
        return aged

    def latest_soi_year(self, target_year: int, *, jurisdiction: str) -> int | None:
        """Return the latest SOI year at or before the model year.

        Only IRS_SOI targets whose variable is in ``ARCH_CURRENT_TAX_VARIABLES``
        are considered. Returns ``None`` when no such target exists.
        """
        variables = tuple(sorted(ARCH_CURRENT_TAX_VARIABLES))
        placeholders = ", ".join("?" for _ in variables)
        sql = f"""
            SELECT DISTINCT t.period
            FROM targets AS t
            JOIN strata AS s
                ON s.id = t.stratum_id
            WHERE {_jurisdiction_clause(jurisdiction)}
                AND t.source = 'IRS_SOI'
                AND t.period <= ?
                AND t.variable IN ({placeholders})
            ORDER BY t.period DESC
        """
        conn = sqlite3.connect(self.db_path)
        try:
            rows = conn.execute(sql, [int(target_year), *variables]).fetchall()
        finally:
            conn.close()
        return int(rows[0][0]) if rows else None

    def age_soi_records(
        self,
        records: list[ArchTargetRecord],
        *,
        source_year: int,
        target_year: int,
        jurisdiction: str,
    ) -> list[ArchTargetRecord]:
        """Age SOI records with declared Microplex-side factors.

        COUNT records are scaled by the count factor, AMOUNT records by the
        amount factor, and any other target type by 1.0. Aged records get
        ``period=target_year``, keep their original period in
        ``source_period`` and carry the factors used. Records whose source is
        not IRS_SOI are returned unchanged.
        """
        # Resolve only the factors that will actually be used, so a missing
        # labor-force or AGI series does not fail a run that never needs it.
        needs_count_factor = any(record.target_type == "COUNT" for record in records)
        needs_amount_factor = any(record.target_type == "AMOUNT" for record in records)
        factors = self.get_soi_aging_factors(
            source_year=source_year,
            target_year=target_year,
            jurisdiction=jurisdiction,
            needs_count_factor=needs_count_factor,
            needs_amount_factor=needs_amount_factor,
        )
        aged: list[ArchTargetRecord] = []
        for record in records:
            if record.source != "IRS_SOI":
                aged.append(record)
                continue
            if record.target_type == "COUNT":
                factor = factors.count_factor
            elif record.target_type == "AMOUNT":
                factor = factors.amount_factor
            else:
                factor = 1.0
            aged.append(
                replace(
                    record,
                    value=float(record.value) * factor,
                    period=target_year,
                    source_period=record.period,
                    aging_factors=factors,
                )
            )
        return aged

    def get_soi_aging_factors(
        self,
        *,
        source_year: int,
        target_year: int,
        jurisdiction: str,
        needs_count_factor: bool = True,
        needs_amount_factor: bool = True,
    ) -> SOIAgingFactors:
        """Resolve source-backed factors for SOI count and amount targets.

        The count factor is the target/source labor-force ratio; the amount
        factor is the target/source SOI total-AGI ratio. A factor that is not
        needed is reported as 1.0 with method ``"not_required"``.

        Raises:
            ValueError: If a required series is missing, or if the source-year
                denominator is zero.
        """
        if source_year == target_year:
            return SOIAgingFactors(
                source_year=source_year,
                target_year=target_year,
                count_factor=1.0,
                amount_factor=1.0,
                count_method="identity",
                amount_method="identity",
            )
        if needs_count_factor:
            source_labor_force = self._target_value(
                year=source_year,
                jurisdiction=jurisdiction,
                source="BLS",
                variable="labor_force_count",
            )
            target_labor_force, count_method = self._labor_force_for_year(
                year=target_year,
                jurisdiction=jurisdiction,
            )
            if source_labor_force == 0:
                raise ValueError(
                    f"BLS labor_force_count target for {source_year} is zero; "
                    "cannot derive a count aging factor."
                )
            count_factor = target_labor_force / source_labor_force
        else:
            count_factor = 1.0
            count_method = "not_required"

        if needs_amount_factor:
            source_agi = self._soi_total_agi(
                year=source_year, jurisdiction=jurisdiction
            )
            target_agi, amount_method = self._soi_total_agi_for_year(
                target_year=target_year,
                jurisdiction=jurisdiction,
            )
            if source_agi == 0:
                raise ValueError(
                    f"SOI total AGI target for {source_year} is zero; "
                    "cannot derive an amount aging factor."
                )
            amount_factor = target_agi / source_agi
        else:
            amount_factor = 1.0
            amount_method = "not_required"

        return SOIAgingFactors(
            source_year=source_year,
            target_year=target_year,
            count_factor=count_factor,
            amount_factor=amount_factor,
            count_method=count_method,
            amount_method=amount_method,
        )

    def _labor_force_for_year(
        self,
        *,
        year: int,
        jurisdiction: str,
    ) -> tuple[float, str]:
        """Return ``(labor_force, method)`` for ``year``, preferring BLS over CBO.

        Raises:
            ValueError: If neither source has a labor-force target for the year.
        """
        bls_value = self._optional_target_value(
            year=year,
            jurisdiction=jurisdiction,
            source="BLS",
            variable="labor_force_count",
        )
        if bls_value is not None:
            return bls_value, "bls_labor_force_ratio"
        cbo_value = self._optional_target_value(
            year=year,
            jurisdiction=jurisdiction,
            source="CBO",
            variable="labor_force",
        )
        if cbo_value is not None:
            return cbo_value, "cbo_labor_force_ratio"
        raise ValueError(f"No BLS/CBO labor-force target found for {year}.")

    def _soi_total_agi_for_year(
        self,
        *,
        target_year: int,
        jurisdiction: str,
    ) -> tuple[float, str]:
        """Return ``(total_agi, method)`` for ``target_year``.

        Uses the published SOI total AGI when the target year has one.
        Otherwise the two most recent published years within the previous 20
        years are used to derive a per-year growth rate, which is compounded
        from the latest published year to the target year.

        Raises:
            ValueError: If fewer than two published years are available, or if
                either anchor value is not positive.
        """
        target_agi = self._optional_soi_total_agi(
            year=target_year,
            jurisdiction=jurisdiction,
        )
        if target_agi is not None:
            return target_agi, "soi_total_agi_ratio"

        available: dict[int, float] = {}
        for year in range(target_year - 20, target_year + 1):
            value = self._optional_soi_total_agi(
                year=year,
                jurisdiction=jurisdiction,
            )
            if value is not None:
                available[year] = value
        if len(available) < 2:
            raise ValueError(
                "Need at least two SOI total AGI years to extrapolate "
                f"aggregate income to {target_year}."
            )
        latest_year = max(available)
        previous_year = max(year for year in available if year < latest_year)
        latest_value = available[latest_year]
        previous_value = available[previous_year]
        if latest_value <= 0 or previous_value <= 0:
            raise ValueError(
                "Cannot extrapolate aggregate income from non-positive SOI "
                f"total AGI in {previous_year} or {latest_year}."
            )
        # The two anchor years may be several years apart. Convert the total
        # growth between them to a per-year (geometric) rate before
        # compounding, otherwise multi-year growth would be applied once per
        # forward year. Consecutive anchors keep the plain ratio unchanged.
        total_growth = latest_value / previous_value
        year_gap = latest_year - previous_year
        annual_growth = (
            total_growth if year_gap == 1 else total_growth ** (1.0 / year_gap)
        )
        years_forward = target_year - latest_year
        return (
            latest_value * annual_growth**years_forward,
            "soi_total_agi_last_growth_extrapolation",
        )

    def _soi_total_agi(self, *, year: int, jurisdiction: str) -> float:
        """Return SOI total AGI for ``year`` or raise ``ValueError`` if absent."""
        value = self._optional_soi_total_agi(year=year, jurisdiction=jurisdiction)
        if value is None:
            raise ValueError(f"No SOI total AGI target found for {year}.")
        return value

    def _optional_soi_total_agi(self, *, year: int, jurisdiction: str) -> float | None:
        """Return SOI total AGI for ``year``, or ``None`` when not published.

        Preference order: the ``adjusted_gross_income`` record in the stratum
        named "US All Filers", then one whose only constraint is
        ``is_tax_filer == 1``.
        """
        records = self.load_records(
            period=year,
            jurisdiction=jurisdiction,
            sources=("IRS_SOI",),
        )
        for record in records:
            if (
                record.variable == "adjusted_gross_income"
                and record.stratum_name == "US All Filers"
            ):
                return float(record.value)
        for record in records:
            if record.variable == "adjusted_gross_income" and record.constraints == (
                ("is_tax_filer", "==", "1"),
            ):
                return float(record.value)
        return None

    def _target_value(
        self,
        *,
        year: int,
        jurisdiction: str,
        source: str,
        variable: str,
    ) -> float:
        """Return the target value or raise ``ValueError`` when none exists."""
        value = self._optional_target_value(
            year=year,
            jurisdiction=jurisdiction,
            source=source,
            variable=variable,
        )
        if value is None:
            raise ValueError(f"No {source} {variable} target found for {year}.")
        return value

    def _optional_target_value(
        self,
        *,
        year: int,
        jurisdiction: str,
        source: str,
        variable: str,
    ) -> float | None:
        """Return one value for ``variable`` from ``source`` in ``year``, if any.

        A single unconstrained record is preferred; otherwise the first
        matching record in load order is used.
        """
        records = self.load_records(
            period=year,
            jurisdiction=jurisdiction,
            sources=(source,),
        )
        matching = [record for record in records if record.variable == variable]
        if not matching:
            return None
        unconstrained = [record for record in matching if not record.constraints]
        if len(unconstrained) == 1:
            return float(unconstrained[0].value)
        return float(matching[0].value)
ArchFactSQLiteTargetProvider: + """Read Arch aggregate facts and expose Microplex canonical targets.""" + + def __init__( + self, + db_path: str | Path, + *, + jurisdiction: str = "us", + compose_model_year_targets: bool = True, + age_soi_targets: bool = True, + ) -> None: + self.db_path = Path(db_path) + self.jurisdiction = jurisdiction + self.compose_model_year_targets = compose_model_year_targets + self.age_soi_targets = age_soi_targets + + def load_target_set(self, query: TargetQuery | None = None) -> TargetSet: + """Load canonical targets from Arch aggregate fact tables.""" + if not self.db_path.exists(): + raise FileNotFoundError(f"Arch facts DB not found: {self.db_path}") + + query = query or TargetQuery() + provider_filters = dict(query.provider_filters) + period = query.period if isinstance(query.period, int) else None + variables = _as_string_tuple(provider_filters.get("variables")) + domain_variables = _as_string_tuple(provider_filters.get("domain_variables")) + sources = _as_string_tuple(provider_filters.get("sources")) + geo_levels = _as_string_tuple(provider_filters.get("geo_levels")) + target_cells = _as_target_cell_filters(provider_filters.get("target_cells")) + entity_overrides = provider_filters.get("entity_overrides") or {} + compose_model_year_targets = bool( + provider_filters.get( + "compose_model_year_targets", + self.compose_model_year_targets, + ) + ) + age_soi_targets = bool( + provider_filters.get("age_soi_targets", self.age_soi_targets) + ) + + records = ( + self._compose_model_year_records( + target_year=period, + sources=sources, + age_soi_targets=age_soi_targets, + ) + if compose_model_year_targets and period is not None + else self.load_records(period=period, sources=sources) + ) + canonical_targets = TargetSet( + [ + target + for record in records + if _matches_arch_provider_filters( + record, + variables=variables, + domain_variables=domain_variables, + geo_levels=geo_levels, + target_cells=target_cells, + 
entity_overrides=entity_overrides, + ) + for target in [ + arch_target_record_to_canonical_spec( + record, + entity_overrides=entity_overrides, + ) + ] + if target is not None + ] + ) + return apply_target_query( + canonical_targets, + TargetQuery( + period=query.period, + entity=query.entity, + names=query.names, + metadata_filters=query.metadata_filters, + ), + ) + + def load_records( + self, + *, + period: int | None = None, + sources: tuple[str, ...] = (), + ) -> list[ArchTargetRecord]: + """Load Arch fact rows with attached fact constraints and lineage.""" + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row + try: + clauses = ["1 = 1"] + params: list[Any] = [] + if period is not None: + clauses.append("CAST(af.period_value AS INTEGER) = ?") + params.append(int(period)) + where_clause = " AND ".join(clauses) + rows = conn.execute( + f""" + SELECT + af.fact_key, + af.source_record_id, + af.value_numeric, + af.value_text, + af.value_json, + af.period_value, + af.geography_level, + af.geography_id, + af.geography_name, + af.measure_concept, + af.measure_source_concept, + af.measure_concept_relation, + af.measure_concept_authority, + af.measure_concept_evidence_url, + af.measure_concept_evidence_notes, + af.measure_legal_vintage, + af.measure_unit, + af.aggregation_method, + af.domain, + af.filters_json, + af.label, + af.source_name, + af.source_table, + af.source_url, + af.source_method_notes, + ac.ordinal AS constraint_ordinal, + ac.variable AS constraint_variable, + ac.operator AS constraint_operator, + ac.value_text AS constraint_value_text, + ac.value_numeric AS constraint_value_numeric, + ac.value_json AS constraint_value_json + FROM aggregate_facts AS af + LEFT JOIN aggregate_constraints AS ac + ON ac.fact_key = af.fact_key + WHERE {where_clause} + ORDER BY af.fact_key, ac.ordinal + """, + params, + ).fetchall() + lineage = _load_arch_fact_lineage(conn) + finally: + conn.close() + + records = _group_arch_fact_rows(rows, lineage=lineage) + 
if sources: + normalized_sources = {_normalize_arch_source(source) for source in sources} + records = [ + record + for record in records + if _normalize_arch_source(record.source) in normalized_sources + ] + return records + + def _compose_model_year_records( + self, + *, + target_year: int, + sources: tuple[str, ...], + age_soi_targets: bool, + ) -> list[ArchTargetRecord]: + return _compose_arch_model_year_records( + self.load_records(period=None, sources=()), + target_year=target_year, + sources=sources, + age_soi_targets=age_soi_targets, + ) + + +class ArchConsumerFactJSONLTargetProvider: + """Read Arch consumer-contract JSONL facts as Microplex targets.""" + + schema_version = "arch.consumer_fact.v1" + + def __init__( + self, + path: str | Path, + *, + jurisdiction: str = "us", + compose_model_year_targets: bool = True, + age_soi_targets: bool = True, + ) -> None: + self.path = Path(path) + self.jurisdiction = jurisdiction + self.compose_model_year_targets = compose_model_year_targets + self.age_soi_targets = age_soi_targets + + def load_target_set(self, query: TargetQuery | None = None) -> TargetSet: + """Load canonical targets from Arch consumer-contract JSONL.""" + if not self.path.exists(): + raise FileNotFoundError(f"Arch consumer facts JSONL not found: {self.path}") + + query = query or TargetQuery() + provider_filters = dict(query.provider_filters) + period = query.period if isinstance(query.period, int) else None + variables = _as_string_tuple(provider_filters.get("variables")) + domain_variables = _as_string_tuple(provider_filters.get("domain_variables")) + sources = _as_string_tuple(provider_filters.get("sources")) + geo_levels = _as_string_tuple(provider_filters.get("geo_levels")) + target_cells = _as_target_cell_filters(provider_filters.get("target_cells")) + entity_overrides = provider_filters.get("entity_overrides") or {} + compose_model_year_targets = bool( + provider_filters.get( + "compose_model_year_targets", + self.compose_model_year_targets, 
+ ) + ) + age_soi_targets = bool( + provider_filters.get("age_soi_targets", self.age_soi_targets) + ) + + records = ( + self._compose_model_year_records( + target_year=period, + sources=sources, + age_soi_targets=age_soi_targets, + ) + if compose_model_year_targets and period is not None + else self.load_records(period=period, sources=sources) + ) + canonical_targets = TargetSet( + [ + target + for record in records + if _matches_arch_provider_filters( + record, + variables=variables, + domain_variables=domain_variables, + geo_levels=geo_levels, + target_cells=target_cells, + entity_overrides=entity_overrides, + ) + for target in [ + arch_target_record_to_canonical_spec( + record, + entity_overrides=entity_overrides, + ) + ] + if target is not None + ] + ) + return apply_target_query( + canonical_targets, + TargetQuery( + period=query.period, + entity=query.entity, + names=query.names, + metadata_filters=query.metadata_filters, + ), + ) + + def load_records( + self, + *, + period: int | None = None, + sources: tuple[str, ...] 
= (), + ) -> list[ArchTargetRecord]: + """Load Arch consumer-contract fact rows.""" + if not self.path.exists(): + raise FileNotFoundError(f"Arch consumer facts JSONL not found: {self.path}") + + rows = list( + load_arch_consumer_fact_jsonl_rows( + (self.path,), + period=period, + schema_version=self.schema_version, + ) + ) + + records = _consumer_fact_rows_to_records(rows) + if sources: + normalized_sources = {_normalize_arch_source(source) for source in sources} + records = [ + record + for record in records + if _normalize_arch_source(record.source) in normalized_sources + ] + return records + + def _compose_model_year_records( + self, + *, + target_year: int, + sources: tuple[str, ...], + age_soi_targets: bool, + ) -> list[ArchTargetRecord]: + return _compose_arch_model_year_records( + self.load_records(period=None, sources=()), + target_year=target_year, + sources=sources, + age_soi_targets=age_soi_targets, + ) + + +class ArchCompositeSQLiteTargetProvider: + """Compose multiple Arch SQLite artifacts into one target provider.""" + + def __init__( + self, + db_paths: tuple[str | Path, ...], + *, + jurisdiction: str = "us", + compose_model_year_targets: bool = True, + age_soi_targets: bool = True, + ) -> None: + paths = tuple(Path(path) for path in db_paths) + if not paths: + raise ValueError("At least one Arch targets DB path is required") + self.db_paths = paths + self.path = tuple(str(path) for path in paths) + self.jurisdiction = jurisdiction + self.compose_model_year_targets = compose_model_year_targets + self.age_soi_targets = age_soi_targets + self.providers = tuple( + resolve_arch_sqlite_target_provider( + path, + jurisdiction=jurisdiction, + compose_model_year_targets=compose_model_year_targets, + age_soi_targets=age_soi_targets, + ) + for path in paths + ) + + def load_target_set(self, query: TargetQuery | None = None) -> TargetSet: + """Load and renumber targets across all configured Arch artifacts.""" + query = query or TargetQuery() + 
provider_filters = dict(query.provider_filters) + period = query.period if isinstance(query.period, int) else None + variables = _as_string_tuple(provider_filters.get("variables")) + domain_variables = _as_string_tuple(provider_filters.get("domain_variables")) + sources = _as_string_tuple(provider_filters.get("sources")) + geo_levels = _as_string_tuple(provider_filters.get("geo_levels")) + target_cells = _as_target_cell_filters(provider_filters.get("target_cells")) + entity_overrides = provider_filters.get("entity_overrides") or {} + compose_model_year_targets = bool( + provider_filters.get( + "compose_model_year_targets", + self.compose_model_year_targets, + ) + ) + age_soi_targets = bool( + provider_filters.get("age_soi_targets", self.age_soi_targets) + ) + + records = self.load_records( + period=period, + sources=sources, + compose_model_year_targets=compose_model_year_targets, + age_soi_targets=age_soi_targets, + ) + stratum_ids: dict[tuple[tuple[str, str, Any], ...], int] = {} + targets: list[CanonicalTargetSpec] = [] + for record in records: + if not _matches_arch_provider_filters( + record, + variables=variables, + domain_variables=domain_variables, + geo_levels=geo_levels, + target_cells=target_cells, + entity_overrides=entity_overrides, + ): + continue + target = arch_target_record_to_canonical_spec( + record, + entity_overrides=entity_overrides, + ) + if target is None: + continue + metadata = dict(target.metadata) + metadata["stratum_id"] = stratum_ids.setdefault( + _target_filter_tuple(target), + len(stratum_ids) + 1, + ) + targets.append( + replace( + target, + name=f"arch_target_{metadata['target_id']}", + metadata=metadata, + ) + ) + return apply_target_query( + TargetSet(targets), + TargetQuery( + period=query.period, + entity=query.entity, + names=query.names, + metadata_filters=query.metadata_filters, + ), + ) + + def load_records( + self, + *, + period: int | None = None, + sources: tuple[str, ...] 
= (), + compose_model_year_targets: bool | None = None, + age_soi_targets: bool | None = None, + ) -> list[ArchTargetRecord]: + """Load and renumber raw records across configured Arch artifacts.""" + records = self._load_all_child_records() + resolved_compose = ( + self.compose_model_year_targets + if compose_model_year_targets is None + else compose_model_year_targets + ) + resolved_age_soi = ( + self.age_soi_targets if age_soi_targets is None else age_soi_targets + ) + if resolved_compose and period is not None: + records = _compose_arch_model_year_records( + records, + target_year=period, + sources=sources, + age_soi_targets=resolved_age_soi, + ) + else: + records = [ + record + for record in records + if (period is None or record.period == period) + and _record_matches_sources(record, sources) + ] + return _renumber_arch_records(records) + + def _load_all_child_records(self) -> list[ArchTargetRecord]: + records: list[ArchTargetRecord] = [] + seen_fact_keys: set[str] = set() + for source_index, (path, provider) in enumerate( + zip(self.db_paths, self.providers, strict=True), + start=1, + ): + provider_records = _load_arch_provider_raw_records( + provider, + jurisdiction=self.jurisdiction, + ) + for record in provider_records: + if record.aggregate_fact_key is not None: + if record.aggregate_fact_key in seen_fact_keys: + continue + seen_fact_keys.add(record.aggregate_fact_key) + records.append( + replace( + record, + source_db_path=str(path), + source_db_index=source_index, + source_target_id=record.source_target_id or record.target_id, + source_stratum_id=( + record.source_stratum_id or record.stratum_id + ), + ) + ) + return records + + +def _load_arch_provider_raw_records( + provider: ( + ArchSQLiteTargetProvider + | ArchFactSQLiteTargetProvider + | ArchConsumerFactJSONLTargetProvider + | ArchCompositeSQLiteTargetProvider + ), + *, + jurisdiction: str, +) -> list[ArchTargetRecord]: + if isinstance( + provider, + (ArchFactSQLiteTargetProvider, 
ArchConsumerFactJSONLTargetProvider), + ): + return provider.load_records(period=None, sources=()) + if isinstance(provider, ArchCompositeSQLiteTargetProvider): + return provider._load_all_child_records() + return provider.load_records(period=None, jurisdiction=jurisdiction, sources=()) + + +def _compose_arch_model_year_records( + records: list[ArchTargetRecord], + *, + target_year: int, + sources: tuple[str, ...], + age_soi_targets: bool, +) -> list[ArchTargetRecord]: + current_records = [ + record + for record in records + if record.period == target_year and _record_matches_sources(record, sources) + ] + normalized_sources = {_normalize_arch_source(source) for source in sources} + if sources and _normalize_arch_source("IRS_SOI") not in normalized_sources: + return _with_state_to_national_rollup_records(current_records) + + non_soi_current_records = [ + record + for record in current_records + if _normalize_arch_source(record.source) != "IRS_SOI" + ] + soi_records = _latest_soi_records_by_composition( + records, + target_year=target_year, + ) + if age_soi_targets: + soi_records = _age_arch_soi_records_by_source_year( + soi_records, + target_year=target_year, + reference_records=records, + ) + else: + soi_records = [ + _carry_forward_arch_record_to_model_year(record, target_year=target_year) + for record in soi_records + ] + return _with_state_to_national_rollup_records( + [*non_soi_current_records, *soi_records] + ) + + +def _with_state_to_national_rollup_records( + records: list[ArchTargetRecord], +) -> list[ArchTargetRecord]: + rollups = _state_to_national_rollup_records(records) + if not rollups: + return records + return [*records, *rollups] + + +def _state_to_national_rollup_records( + records: list[ArchTargetRecord], +) -> list[ArchTargetRecord]: + existing_national_keys = { + key + for record in records + if _arch_record_geo_level(record) == "national" + for key in [_state_rollup_group_key(record)] + if key is not None + } + grouped: dict[tuple[Any, ...], 
list[tuple[str, ArchTargetRecord]]] = {} + for record in records: + if _arch_record_geo_level(record) != "state": + continue + key = _state_rollup_group_key(record) + if key is None or key in existing_national_keys: + continue + state_fips = _arch_record_state_fips(record) + if state_fips is None or state_fips not in ARCH_NATIONAL_ROLLUP_STATE_FIPS: + continue + grouped.setdefault(key, []).append((state_fips, record)) + + rollups: list[ArchTargetRecord] = [] + for key, state_records in grouped.items(): + records_by_state: dict[str, ArchTargetRecord] = {} + for state_fips, record in state_records: + if state_fips in records_by_state: + records_by_state = {} + break + records_by_state[state_fips] = record + if set(records_by_state) != ARCH_NATIONAL_ROLLUP_STATE_FIPS: + continue + ordered_records = [ + records_by_state[state_fips] + for state_fips in sorted(ARCH_NATIONAL_ROLLUP_STATE_FIPS) + ] + rollups.append( + _state_records_to_national_rollup_record( + key, + ordered_records, + ) + ) + return rollups + + +def _state_rollup_group_key(record: ArchTargetRecord) -> tuple[Any, ...] 
| None: + if record.variable not in ARCH_STATE_TO_NATIONAL_ROLLUP_VARIABLES: + return None + return ( + _normalize_arch_source(record.source), + record.source_table, + record.source_url, + record.variable, + record.target_type, + record.period, + record.source_period, + record.aging_factors, + record.unit, + record.concept, + record.source_concept, + record.concept_relation, + record.concept_authority, + record.legal_vintage, + _non_state_constraints(record.constraints), + ) + + +def _state_records_to_national_rollup_record( + key: tuple[Any, ...], + records: list[ArchTargetRecord], +) -> ArchTargetRecord: + first = records[0] + digest = sha1(repr(key).encode("utf-8")).hexdigest() + source_row_keys = tuple( + dict.fromkeys( + source_row_key + for record in records + for source_row_key in ( + record.source_row_keys + or (str(record.source_target_id or record.target_id),) + ) + ) + ) + source_cell_keys = tuple( + dict.fromkeys( + source_cell_key + for record in records + for source_cell_key in record.source_cell_keys + ) + ) + notes = "Microplex national rollup from 51 state targets." 
+ if first.notes: + notes = f"{first.notes} {notes}" + return replace( + first, + target_id=-int(digest[:12], 16), + stratum_id=-int(digest[12:20], 16), + value=sum(record.value for record in records), + geographic_level=None, + geography_id=None, + stratum_name="US National Rollup", + constraints=_non_state_constraints(first.constraints), + notes=notes, + source_record_id=f"microplex_state_rollup:{digest[:16]}", + source_cell_keys=source_cell_keys, + source_row_keys=source_row_keys, + source_target_id=None, + source_stratum_id=None, + ) + + +def _non_state_constraints( + constraints: tuple[tuple[str, str, str], ...], +) -> tuple[tuple[str, str, str], ...]: + return tuple( + constraint for constraint in constraints if constraint[0] != "state_fips" + ) + + +def _arch_record_state_fips(record: ArchTargetRecord) -> str | None: + for variable, operator, value in record.constraints: + if variable != "state_fips": + continue + if _canonical_arch_constraint_operator(operator) != "==": + continue + try: + return str(int(float(value))).zfill(2) + except (TypeError, ValueError): + return str(value).zfill(2) + if _normalize_geo_level(record.geographic_level) == "state": + geography_id = record.geography_id + if geography_id is not None: + return _state_fips_from_arch_geography_id(geography_id) + return None + + +def _latest_soi_records_by_composition( + records: list[ArchTargetRecord], + *, + target_year: int, +) -> list[ArchTargetRecord]: + candidates = [ + record + for record in records + if _normalize_arch_source(record.source) == "IRS_SOI" + and record.period <= target_year + ] + latest_period_by_key: dict[ + tuple[str, str, str, tuple[tuple[str, str, str], ...]], + int, + ] = {} + for record in candidates: + key = _arch_record_composition_key(record) + latest_period_by_key[key] = max( + latest_period_by_key.get(key, record.period), + record.period, + ) + return [ + record + for record in candidates + if record.period == 
latest_period_by_key[_arch_record_composition_key(record)] + ] + + +def _age_arch_soi_records_by_source_year( + records: list[ArchTargetRecord], + *, + target_year: int, + reference_records: list[ArchTargetRecord], +) -> list[ArchTargetRecord]: + aged: list[ArchTargetRecord] = [] + for source_year in sorted({record.period for record in records}): + source_records = [record for record in records if record.period == source_year] + if source_year == target_year: + aged.extend(source_records) + continue + needs_count_factor = any( + record.target_type == "COUNT" for record in source_records + ) + needs_amount_factor = any( + record.target_type == "AMOUNT" for record in source_records + ) + factors = _arch_record_soi_aging_factors( + reference_records, + source_year=source_year, + target_year=target_year, + needs_count_factor=needs_count_factor, + needs_amount_factor=needs_amount_factor, + ) + for record in source_records: + factor = 1.0 + if record.target_type == "COUNT": + factor = factors.count_factor + elif record.target_type == "AMOUNT": + factor = factors.amount_factor + aged.append( + replace( + record, + value=float(record.value) * factor, + period=target_year, + source_period=record.period, + aging_factors=factors, + ) + ) + return aged + + +def _carry_forward_arch_record_to_model_year( + record: ArchTargetRecord, + *, + target_year: int, +) -> ArchTargetRecord: + if record.period == target_year: + return record + return replace(record, period=target_year, source_period=record.period) + + +def _arch_record_soi_aging_factors( + records: list[ArchTargetRecord], + *, + source_year: int, + target_year: int, + needs_count_factor: bool, + needs_amount_factor: bool, +) -> SOIAgingFactors: + if source_year == target_year: + return SOIAgingFactors( + source_year=source_year, + target_year=target_year, + count_factor=1.0, + amount_factor=1.0, + count_method="identity", + amount_method="identity", + ) + + if needs_count_factor: + count_factor, count_method = 
_arch_record_soi_count_aging_factor( + records, + source_year=source_year, + target_year=target_year, + ) + else: + count_factor = 1.0 + count_method = "not_required" + + if needs_amount_factor: + amount_factor, amount_method = _arch_record_soi_amount_aging_factor( + records, + source_year=source_year, + target_year=target_year, + ) + else: + amount_factor = 1.0 + amount_method = "not_required" + + return SOIAgingFactors( + source_year=source_year, + target_year=target_year, + count_factor=count_factor, + amount_factor=amount_factor, + count_method=count_method, + amount_method=amount_method, + ) + + +def _arch_record_soi_count_aging_factor( + records: list[ArchTargetRecord], + *, + source_year: int, + target_year: int, +) -> tuple[float, str]: + source_labor_force = _optional_arch_total_value( + records, + year=source_year, + source="BLS", + variable="labor_force_count", + ) + target_labor_force, labor_force_method = _optional_arch_labor_force_for_year( + records, + year=target_year, + ) + if source_labor_force is not None and target_labor_force is not None: + return target_labor_force / source_labor_force, labor_force_method + + source_count = _optional_arch_soi_total_value( + records, + year=source_year, + variable="tax_unit_count", + ) + target_count, count_method = _arch_soi_total_for_year( + records, + target_year=target_year, + variable="tax_unit_count", + exact_method="soi_total_return_count_ratio", + extrapolation_method="soi_total_return_count_last_growth_extrapolation", + ) + if source_count is not None and target_count is not None: + return target_count / source_count, count_method + return 1.0, "source_fact_carry_forward_no_count_reference" + + +def _arch_record_soi_amount_aging_factor( + records: list[ArchTargetRecord], + *, + source_year: int, + target_year: int, +) -> tuple[float, str]: + source_agi = _optional_arch_soi_total_value( + records, + year=source_year, + variable="adjusted_gross_income", + ) + target_agi, amount_method = 
_arch_soi_total_for_year( + records, + target_year=target_year, + variable="adjusted_gross_income", + exact_method="soi_total_agi_ratio", + extrapolation_method="soi_total_agi_last_growth_extrapolation", + ) + if source_agi is not None and target_agi is not None: + return target_agi / source_agi, amount_method + return 1.0, "source_fact_carry_forward_no_amount_reference" + + +def _optional_arch_labor_force_for_year( + records: list[ArchTargetRecord], + *, + year: int, +) -> tuple[float | None, str]: + bls_value = _optional_arch_total_value( + records, + year=year, + source="BLS", + variable="labor_force_count", + ) + if bls_value is not None: + return bls_value, "bls_labor_force_ratio" + cbo_value = _optional_arch_total_value( + records, + year=year, + source="CBO", + variable="labor_force", + ) + if cbo_value is not None: + return cbo_value, "cbo_labor_force_ratio" + return None, "source_fact_carry_forward_no_labor_force_reference" + + +def _arch_soi_total_for_year( + records: list[ArchTargetRecord], + *, + target_year: int, + variable: str, + exact_method: str, + extrapolation_method: str, +) -> tuple[float | None, str]: + exact = _optional_arch_soi_total_value( + records, + year=target_year, + variable=variable, + ) + if exact is not None: + return exact, exact_method + + available = { + year: value + for year in sorted({record.period for record in records}) + if year <= target_year + for value in [ + _optional_arch_soi_total_value( + records, + year=year, + variable=variable, + ) + ] + if value is not None + } + if len(available) < 2: + return None, f"source_fact_carry_forward_no_{variable}_reference" + latest_year = max(available) + previous_year = max(year for year in available if year < latest_year) + annual_growth = available[latest_year] / available[previous_year] + years_forward = target_year - latest_year + return available[latest_year] * annual_growth**years_forward, extrapolation_method + + +def _optional_arch_soi_total_value( + records: 
list[ArchTargetRecord], + *, + year: int, + variable: str, +) -> float | None: + return _optional_arch_total_value( + records, + year=year, + source="IRS_SOI", + variable=variable, + require_total_scope=True, + ) + + +def _optional_arch_total_value( + records: list[ArchTargetRecord], + *, + year: int, + source: str, + variable: str, + require_total_scope: bool = False, +) -> float | None: + matches = [ + record + for record in records + if record.period == year + and _normalize_arch_source(record.source) == _normalize_arch_source(source) + and record.variable == variable + ] + if require_total_scope: + matches = [record for record in matches if _arch_record_is_total_scope(record)] + if not matches: + return None + return float(matches[0].value) + + +def _arch_record_is_total_scope(record: ArchTargetRecord) -> bool: + if not record.constraints: + return True + if tuple(record.constraints) == (("is_tax_filer", "==", "1"),): + return True + if tuple(record.constraints) == (("tax_unit_is_filer", "==", "1"),): + return True + return str(record.stratum_name or "").lower().endswith("all filers") + + +def _record_matches_sources( + record: ArchTargetRecord, + sources: tuple[str, ...], +) -> bool: + if not sources: + return True + normalized_sources = {_normalize_arch_source(source) for source in sources} + return _normalize_arch_source(record.source) in normalized_sources + + +def _renumber_arch_records(records: list[ArchTargetRecord]) -> list[ArchTargetRecord]: + renumbered: list[ArchTargetRecord] = [] + stratum_ids: dict[tuple[tuple[str, str, str], ...], int] = {} + for record in records: + renumbered.append( + replace( + record, + target_id=len(renumbered) + 1, + stratum_id=stratum_ids.setdefault( + record.constraints, + len(stratum_ids) + 1, + ), + ) + ) + return renumbered + + +def resolve_arch_sqlite_target_provider( + db_path: str | Path | tuple[str | Path, ...], + *, + jurisdiction: str = "us", + compose_model_year_targets: bool = True, + age_soi_targets: bool = 
True, +) -> ( + ArchSQLiteTargetProvider + | ArchFactSQLiteTargetProvider + | ArchConsumerFactJSONLTargetProvider + | ArchCompositeSQLiteTargetProvider +): + """Return the Arch provider matching a source artifact's schema.""" + paths = _as_arch_db_path_tuple(db_path) + if len(paths) > 1: + return ArchCompositeSQLiteTargetProvider( + paths, + jurisdiction=jurisdiction, + compose_model_year_targets=compose_model_year_targets, + age_soi_targets=age_soi_targets, + ) + path = paths[0] + if not path.exists(): + raise FileNotFoundError(f"Arch targets DB not found: {path}") + if _looks_like_arch_consumer_fact_jsonl(path): + return ArchConsumerFactJSONLTargetProvider( + path, + jurisdiction=jurisdiction, + compose_model_year_targets=compose_model_year_targets, + age_soi_targets=age_soi_targets, + ) + conn = sqlite3.connect(path) + try: + if _sqlite_table_exists(conn, "aggregate_facts"): + return ArchFactSQLiteTargetProvider( + path, + jurisdiction=jurisdiction, + compose_model_year_targets=compose_model_year_targets, + age_soi_targets=age_soi_targets, + ) + finally: + conn.close() + return ArchSQLiteTargetProvider( + path, + jurisdiction=jurisdiction, + compose_model_year_targets=compose_model_year_targets, + age_soi_targets=age_soi_targets, + ) + + +def summarize_arch_target_profile_coverage( + provider: ( + ArchSQLiteTargetProvider + | ArchFactSQLiteTargetProvider + | ArchConsumerFactJSONLTargetProvider + | ArchCompositeSQLiteTargetProvider + ), + *, + period: int, + profile_name: str = "pe_native_broad", + target_cells: tuple[PolicyEngineUSTargetCell | dict[str, Any], ...] | None = None, + sources: tuple[str, ...] 
= (), + jurisdiction: str | None = None, + compose_model_year_targets: bool | None = None, + age_soi_targets: bool | None = None, + entity_overrides: dict[str, Any] | None = None, + provider_filters: dict[str, Any] | None = None, +) -> ArchTargetProfileCoverageReport: + """Summarize how much of a Microplex target profile Arch can satisfy.""" + + resolved_cells = ( + tuple(target_cells) + if target_cells is not None + else resolve_policyengine_us_target_profile(profile_name) + ) + cell_filters = tuple( + _target_cell_to_provider_filter(cell) for cell in resolved_cells + ) + query_filters: dict[str, Any] = dict(provider_filters or {}) + query_filters["target_profile"] = profile_name + query_filters["target_cells"] = [dict(cell) for cell in cell_filters] + if sources: + query_filters["sources"] = list(sources) + if jurisdiction is not None: + query_filters["jurisdiction"] = jurisdiction + if compose_model_year_targets is not None: + query_filters["compose_model_year_targets"] = compose_model_year_targets + if age_soi_targets is not None: + query_filters["age_soi_targets"] = age_soi_targets + if entity_overrides is not None: + query_filters["entity_overrides"] = entity_overrides + + target_set = provider.load_target_set( + TargetQuery(period=period, provider_filters=query_filters) + ) + coverage_cells = tuple( + _coverage_for_arch_target_cell(cell_filter, target_set) + for cell_filter in cell_filters + ) + target_cell_count = len(coverage_cells) + covered_cell_count = sum(1 for cell in coverage_cells if cell.covered) + uncovered_cell_count = target_cell_count - covered_cell_count + coverage_rate = covered_cell_count / target_cell_count if target_cell_count else 0.0 + return ArchTargetProfileCoverageReport( + profile_name=profile_name, + period=int(period), + target_cell_count=target_cell_count, + covered_cell_count=covered_cell_count, + uncovered_cell_count=uncovered_cell_count, + coverage_rate=coverage_rate, + by_geo_level=_summarize_arch_cell_coverage(coverage_cells, 
field="geo_level"), + by_variable=_summarize_arch_cell_coverage(coverage_cells, field="variable"), + cells=coverage_cells, + ) + + +def summarize_arch_target_gap_queue( + provider: ( + ArchSQLiteTargetProvider + | ArchFactSQLiteTargetProvider + | ArchConsumerFactJSONLTargetProvider + | ArchCompositeSQLiteTargetProvider + ), + *, + period: int, + profile_name: str = "pe_native_broad", + include_covered: bool = False, + target_cells: tuple[PolicyEngineUSTargetCell | dict[str, Any], ...] | None = None, + sources: tuple[str, ...] = (), + jurisdiction: str | None = None, + compose_model_year_targets: bool | None = None, + age_soi_targets: bool | None = None, + entity_overrides: dict[str, Any] | None = None, + provider_filters: dict[str, Any] | None = None, +) -> ArchTargetGapQueueReport: + """Build an agent-facing queue of Arch target records to add or review.""" + + coverage = summarize_arch_target_profile_coverage( + provider, + period=period, + profile_name=profile_name, + target_cells=target_cells, + sources=sources, + jurisdiction=jurisdiction, + compose_model_year_targets=compose_model_year_targets, + age_soi_targets=age_soi_targets, + entity_overrides=entity_overrides, + provider_filters=provider_filters, + ) + catalog = _arch_gap_loaded_variable_catalog( + provider, + period=period, + jurisdiction=jurisdiction, + sources=sources, + compose_model_year_targets=compose_model_year_targets, + age_soi_targets=age_soi_targets, + ) + variable_uncovered_counts = { + variable: counts["uncovered_cell_count"] + for variable, counts in coverage.by_variable.items() + } + rows = [ + _arch_gap_queue_row_for_coverage_cell( + coverage_cell, + profile_name=profile_name, + period=period, + loaded_variable_catalog=catalog, + variable_uncovered_count=variable_uncovered_counts.get( + str(coverage_cell.cell.get("variable") or ""), + 0, + ), + ) + for coverage_cell in coverage.cells + if include_covered or not coverage_cell.covered + ] + rows = [ + replace(row, priority=priority) + for 
priority, row in enumerate( + sorted(rows, key=_arch_gap_queue_sort_key), + start=1, + ) + ] + by_loader_status: dict[str, int] = {} + by_gap_category: dict[str, int] = {} + for row in rows: + by_loader_status[row.loader_status] = ( + by_loader_status.get(row.loader_status, 0) + 1 + ) + by_gap_category[row.gap_category] = by_gap_category.get(row.gap_category, 0) + 1 + covered_row_count = sum(1 for row in rows if row.covered) + return ArchTargetGapQueueReport( + profile_name=profile_name, + period=int(period), + row_count=len(rows), + covered_row_count=covered_row_count, + uncovered_row_count=len(rows) - covered_row_count, + by_loader_status=dict(sorted(by_loader_status.items())), + by_gap_category=dict(sorted(by_gap_category.items())), + rows=tuple(rows), + ) + + +def summarize_arch_target_parity( + incumbent_provider: ( + ArchSQLiteTargetProvider + | ArchFactSQLiteTargetProvider + | ArchConsumerFactJSONLTargetProvider + | ArchCompositeSQLiteTargetProvider + ), + candidate_provider: ( + ArchSQLiteTargetProvider + | ArchFactSQLiteTargetProvider + | ArchConsumerFactJSONLTargetProvider + | ArchCompositeSQLiteTargetProvider + ), + *, + period: int, + sources: tuple[str, ...] = (), + variables: tuple[str, ...] 
= (), + value_abs_tolerance: float = 1e-6, + value_rel_tolerance: float = 1e-12, +) -> ArchTargetParityReport: + """Compare canonical Microplex targets loaded from two Arch artifacts.""" + provider_filters: dict[str, Any] = {} + if sources: + provider_filters["sources"] = tuple(sources) + if variables: + provider_filters["variables"] = tuple(variables) + + query = TargetQuery(period=period, provider_filters=provider_filters) + incumbent_targets = list(incumbent_provider.load_target_set(query).targets) + candidate_targets = list(candidate_provider.load_target_set(query).targets) + rows = _arch_target_parity_rows( + incumbent_targets, + candidate_targets, + value_abs_tolerance=value_abs_tolerance, + value_rel_tolerance=value_rel_tolerance, + ) + errors = tuple( + _arch_target_parity_error(row) for row in rows if row.status != "matched" + ) + counts = { + "incumbent_target_count": len(incumbent_targets), + "candidate_target_count": len(candidate_targets), + "matched_count": sum(1 for row in rows if row.status == "matched"), + "value_mismatch_count": sum( + 1 for row in rows if row.status == "value_mismatch" + ), + "incumbent_only_count": sum( + 1 for row in rows if row.status == "incumbent_only" + ), + "candidate_only_count": sum( + 1 for row in rows if row.status == "candidate_only" + ), + "duplicate_identity_count": sum( + 1 for row in rows if row.status == "duplicate_identity" + ), + } + return ArchTargetParityReport( + period=int(period), + incumbent_artifacts=_arch_provider_artifacts(incumbent_provider), + candidate_artifacts=_arch_provider_artifacts(candidate_provider), + value_abs_tolerance=value_abs_tolerance, + value_rel_tolerance=value_rel_tolerance, + counts=counts, + rows=rows, + errors=errors, + ) + + +def main_coverage(argv: list[str] | None = None) -> int: + """CLI entrypoint for Arch target-profile coverage JSON.""" + import argparse + import json + import sys + + parser = argparse.ArgumentParser( + description="Summarize Arch target DB coverage for a 
Microplex target profile." + ) + parser.add_argument( + "--arch-targets-db", + required=True, + action="append", + help=( + "Arch targets SQLite DB path or consumer-fact JSONL path. May be " + "supplied multiple times to combine source-package artifacts." + ), + ) + parser.add_argument("--period", type=int, required=True) + parser.add_argument("--profile", default="pe_native_broad") + parser.add_argument("--jurisdiction", default="us") + parser.add_argument("--source", action="append", dest="sources", default=[]) + parser.add_argument( + "--no-compose-model-year-targets", + action="store_false", + dest="compose_model_year_targets", + default=True, + ) + parser.add_argument( + "--no-age-soi-targets", + action="store_false", + dest="age_soi_targets", + default=True, + ) + parser.add_argument("--indent", type=int, default=2) + args = parser.parse_args(argv) + + provider = resolve_arch_sqlite_target_provider( + _single_or_many_paths(args.arch_targets_db), + jurisdiction=args.jurisdiction, + compose_model_year_targets=args.compose_model_year_targets, + age_soi_targets=args.age_soi_targets, + ) + report = summarize_arch_target_profile_coverage( + provider, + period=args.period, + profile_name=args.profile, + sources=tuple(args.sources), + jurisdiction=args.jurisdiction, + compose_model_year_targets=args.compose_model_year_targets, + age_soi_targets=args.age_soi_targets, + ) + json.dump(report.to_dict(), sys.stdout, indent=args.indent, sort_keys=True) + sys.stdout.write("\n") + return 0 + + +def main_gaps(argv: list[str] | None = None) -> int: + """CLI entrypoint for Arch target-profile gap queue rows.""" + import argparse + import json + import sys + + parser = argparse.ArgumentParser( + description="Emit an agent-facing Arch target gap queue for a profile." + ) + parser.add_argument( + "--arch-targets-db", + required=True, + action="append", + help=( + "Arch targets SQLite DB path or consumer-fact JSONL path. 
May be " + "supplied multiple times to combine source-package artifacts." + ), + ) + parser.add_argument("--period", type=int, required=True) + parser.add_argument("--profile", default="pe_native_broad") + parser.add_argument("--jurisdiction", default="us") + parser.add_argument("--source", action="append", dest="sources", default=[]) + parser.add_argument("--include-covered", action="store_true") + parser.add_argument("--format", choices=["json", "csv"], default="json") + parser.add_argument("--output") + parser.add_argument( + "--no-compose-model-year-targets", + action="store_false", + dest="compose_model_year_targets", + default=True, + ) + parser.add_argument( + "--no-age-soi-targets", + action="store_false", + dest="age_soi_targets", + default=True, + ) + parser.add_argument("--indent", type=int, default=2) + args = parser.parse_args(argv) + + provider = resolve_arch_sqlite_target_provider( + _single_or_many_paths(args.arch_targets_db), + jurisdiction=args.jurisdiction, + compose_model_year_targets=args.compose_model_year_targets, + age_soi_targets=args.age_soi_targets, + ) + report = summarize_arch_target_gap_queue( + provider, + period=args.period, + profile_name=args.profile, + include_covered=args.include_covered, + sources=tuple(args.sources), + jurisdiction=args.jurisdiction, + compose_model_year_targets=args.compose_model_year_targets, + age_soi_targets=args.age_soi_targets, + ) + if args.format == "csv": + output = _arch_target_gap_queue_csv(report) + else: + output = json.dumps(report.to_dict(), indent=args.indent, sort_keys=True) + output += "\n" + + if args.output: + Path(args.output).write_text(output) + else: + sys.stdout.write(output) + return 0 + + +def main_refresh(argv: list[str] | None = None) -> int: + """Refresh Arch target coverage and gap snapshots for a target profile.""" + import argparse + import json + import sys + + parser = argparse.ArgumentParser( + description=( + "Write Arch target-profile coverage, gap queue, and summary 
artifacts." + ) + ) + parser.add_argument( + "--arch-targets-db", + action="append", + default=[], + help=( + "Arch targets SQLite DB path or consumer-fact JSONL path. May be " + "supplied multiple times. If omitted, --artifact-root is searched." + ), + ) + parser.add_argument( + "--artifact-root", + action="append", + default=[], + help=( + "Directory or file to search for Arch target artifacts when " + "--arch-targets-db is omitted. May be supplied multiple times." + ), + ) + parser.add_argument("--period", type=int, required=True) + parser.add_argument("--profile", default="pe_native_broad") + parser.add_argument("--jurisdiction", default="us") + parser.add_argument("--source", action="append", dest="sources", default=[]) + parser.add_argument( + "--output-dir", + default="artifacts/arch-target-coverage", + help="Directory for coverage JSON, gap JSON/CSV, and markdown summary.", + ) + parser.add_argument( + "--no-compose-model-year-targets", + action="store_false", + dest="compose_model_year_targets", + default=True, + ) + parser.add_argument( + "--no-age-soi-targets", + action="store_false", + dest="age_soi_targets", + default=True, + ) + parser.add_argument("--indent", type=int, default=2) + args = parser.parse_args(argv) + + artifact_paths = tuple(Path(path) for path in args.arch_targets_db) + if not artifact_paths: + discovery_roots = ( + tuple(Path(path) for path in args.artifact_root) + or _default_arch_target_artifact_roots() + ) + artifact_paths = discover_arch_target_artifacts(discovery_roots) + if not artifact_paths: + roots = args.artifact_root or [ + str(path) for path in _default_arch_target_artifact_roots() + ] + raise FileNotFoundError( + "No Arch target artifacts found. 
Pass --arch-targets-db or place " + f"consumer_facts.jsonl / Arch targets DB files under: {', '.join(roots)}" + ) + + provider = resolve_arch_sqlite_target_provider( + _single_or_many_paths([str(path) for path in artifact_paths]), + jurisdiction=args.jurisdiction, + compose_model_year_targets=args.compose_model_year_targets, + age_soi_targets=args.age_soi_targets, + ) + coverage = summarize_arch_target_profile_coverage( + provider, + period=args.period, + profile_name=args.profile, + sources=tuple(args.sources), + jurisdiction=args.jurisdiction, + compose_model_year_targets=args.compose_model_year_targets, + age_soi_targets=args.age_soi_targets, + ) + gaps = summarize_arch_target_gap_queue( + provider, + period=args.period, + profile_name=args.profile, + sources=tuple(args.sources), + jurisdiction=args.jurisdiction, + compose_model_year_targets=args.compose_model_year_targets, + age_soi_targets=args.age_soi_targets, + ) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + stem = f"{_filename_slug(args.profile)}_{int(args.period)}" + coverage_path = output_dir / f"{stem}_coverage.json" + gaps_json_path = output_dir / f"{stem}_gaps.json" + gaps_csv_path = output_dir / f"{stem}_gaps.csv" + summary_path = output_dir / f"{stem}_summary.md" + + coverage_path.write_text( + json.dumps(coverage.to_dict(), indent=args.indent, sort_keys=True) + "\n" + ) + gaps_json_path.write_text( + json.dumps(gaps.to_dict(), indent=args.indent, sort_keys=True) + "\n" + ) + gaps_csv_path.write_text(_arch_target_gap_queue_csv(gaps)) + summary_path.write_text( + _arch_target_refresh_summary_markdown( + coverage, + gaps, + artifact_paths=artifact_paths, + output_paths=( + coverage_path, + gaps_json_path, + gaps_csv_path, + summary_path, + ), + ) + ) + + json.dump( + { + "profile_name": coverage.profile_name, + "period": coverage.period, + "target_cell_count": coverage.target_cell_count, + "covered_cell_count": coverage.covered_cell_count, + 
"uncovered_cell_count": coverage.uncovered_cell_count, + "coverage_rate": coverage.coverage_rate, + "artifact_paths": [str(path) for path in artifact_paths], + "output_paths": { + "coverage": str(coverage_path), + "gaps_json": str(gaps_json_path), + "gaps_csv": str(gaps_csv_path), + "summary": str(summary_path), + }, + }, + sys.stdout, + indent=args.indent, + sort_keys=True, + ) + sys.stdout.write("\n") + return 0 + + +def main_parity(argv: list[str] | None = None) -> int: + """CLI entrypoint comparing incumbent and candidate Arch target artifacts.""" + import argparse + import json + import sys + + parser = argparse.ArgumentParser( + description=( + "Compare two Arch target artifacts after loading both through the " + "Microplex Arch provider." + ) + ) + parser.add_argument( + "--incumbent-arch-targets-db", + required=True, + action="append", + help=( + "Incumbent Arch targets SQLite DB path. May be supplied multiple " + "times to combine artifacts." + ), + ) + parser.add_argument( + "--candidate-arch-targets-db", + required=True, + action="append", + help=( + "Candidate Arch targets SQLite DB or consumer-fact JSONL path. May " + "be supplied multiple times to combine artifacts." 
+ ), + ) + parser.add_argument("--period", type=int, required=True) + parser.add_argument("--jurisdiction", default="us") + parser.add_argument("--source", action="append", dest="sources", default=[]) + parser.add_argument("--variable", action="append", dest="variables", default=[]) + parser.add_argument("--value-abs-tolerance", type=float, default=1e-6) + parser.add_argument("--value-rel-tolerance", type=float, default=1e-12) + parser.add_argument("--row-limit", type=int, default=50) + parser.add_argument( + "--no-compose-model-year-targets", + action="store_false", + dest="compose_model_year_targets", + default=True, + ) + parser.add_argument( + "--no-age-soi-targets", + action="store_false", + dest="age_soi_targets", + default=True, + ) + parser.add_argument("--indent", type=int, default=2) + args = parser.parse_args(argv) + + try: + incumbent_provider = resolve_arch_sqlite_target_provider( + _single_or_many_paths(args.incumbent_arch_targets_db), + jurisdiction=args.jurisdiction, + compose_model_year_targets=args.compose_model_year_targets, + age_soi_targets=args.age_soi_targets, + ) + candidate_provider = resolve_arch_sqlite_target_provider( + _single_or_many_paths(args.candidate_arch_targets_db), + jurisdiction=args.jurisdiction, + compose_model_year_targets=args.compose_model_year_targets, + age_soi_targets=args.age_soi_targets, + ) + payload = summarize_arch_target_parity( + incumbent_provider, + candidate_provider, + period=args.period, + sources=tuple(args.sources), + variables=tuple(args.variables), + value_abs_tolerance=args.value_abs_tolerance, + value_rel_tolerance=args.value_rel_tolerance, + ).to_dict(row_limit=args.row_limit) + except Exception as exc: # noqa: BLE001 - CLI must return JSON on failures. 
+ payload = { + "valid": False, + "period": args.period, + "incumbent_artifacts": list(args.incumbent_arch_targets_db), + "candidate_artifacts": list(args.candidate_arch_targets_db), + "counts": { + "incumbent_target_count": 0, + "candidate_target_count": 0, + "matched_count": 0, + "value_mismatch_count": 0, + "incumbent_only_count": 0, + "candidate_only_count": 0, + "duplicate_identity_count": 0, + }, + "row_count": 0, + "rows": [], + "errors": [{"code": "load_failed", "message": str(exc)}], + } + + json.dump(payload, sys.stdout, indent=args.indent, sort_keys=True) + sys.stdout.write("\n") + return 0 if payload["valid"] else 1 + + +def main_smoke(argv: list[str] | None = None) -> int: + """CLI entrypoint proving Arch artifacts load as Microplex targets.""" + import argparse + import json + import sys + + parser = argparse.ArgumentParser( + description=( + "Load an Arch target artifact, including consumer_facts.jsonl, " + "through the Microplex Arch provider and emit a JSON smoke report." + ) + ) + parser.add_argument( + "--arch-targets-db", + required=True, + action="append", + help=( + "Arch targets SQLite DB path or consumer-fact JSONL path. May be " + "supplied multiple times to combine source-package artifacts." 
+ ), + ) + parser.add_argument("--period", type=int, required=True) + parser.add_argument("--jurisdiction", default="us") + parser.add_argument("--source", action="append", dest="sources", default=[]) + parser.add_argument("--variable", action="append", dest="variables", default=[]) + parser.add_argument("--expected-target-count", type=int) + parser.add_argument("--sample-limit", type=int, default=5) + parser.add_argument( + "--no-compose-model-year-targets", + action="store_false", + dest="compose_model_year_targets", + default=True, + ) + parser.add_argument( + "--no-age-soi-targets", + action="store_false", + dest="age_soi_targets", + default=True, + ) + parser.add_argument("--indent", type=int, default=2) + args = parser.parse_args(argv) + + errors: list[dict[str, str]] = [] + targets: list[CanonicalTargetSpec] = [] + try: + provider = resolve_arch_sqlite_target_provider( + _single_or_many_paths(args.arch_targets_db), + jurisdiction=args.jurisdiction, + compose_model_year_targets=args.compose_model_year_targets, + age_soi_targets=args.age_soi_targets, + ) + provider_filters: dict[str, Any] = {} + if args.sources: + provider_filters["sources"] = tuple(args.sources) + if args.variables: + provider_filters["variables"] = tuple(args.variables) + targets = list( + provider.load_target_set( + TargetQuery( + period=args.period, + provider_filters=provider_filters, + ) + ).targets + ) + except Exception as exc: # noqa: BLE001 - CLI must return JSON on failures. + errors.append({"code": "load_failed", "message": str(exc)}) + + if ( + args.expected_target_count is not None + and len(targets) != args.expected_target_count + ): + errors.append( + { + "code": "unexpected_target_count", + "message": ( + f"Expected {args.expected_target_count} targets, " + f"loaded {len(targets)}." 
+ ), + } + ) + + payload = { + "valid": not errors, + "period": args.period, + "target_count": len(targets), + "by_variable": dict( + sorted(Counter(_target_variable(target) for target in targets).items()) + ), + "by_source": dict( + sorted(Counter(str(target.source) for target in targets).items()) + ), + "by_aggregation": dict( + sorted( + Counter( + str(getattr(target.aggregation, "value", target.aggregation)) + for target in targets + ).items() + ) + ), + "by_filter_count": { + str(key): value + for key, value in sorted( + Counter(len(target.filters) for target in targets).items() + ) + }, + "sample_targets": [ + _target_smoke_sample(target) + for target in targets[: max(0, args.sample_limit)] + ], + "errors": errors, + } + json.dump(payload, sys.stdout, indent=args.indent, sort_keys=True) + sys.stdout.write("\n") + return 0 if payload["valid"] else 1 + + +def _target_variable(target: CanonicalTargetSpec) -> str: + """Return the Microplex variable represented by a canonical target.""" + variable = target.metadata.get("variable") if target.metadata else None + return str(variable or target.measure or target.name) + + +def _target_smoke_sample(target: CanonicalTargetSpec) -> dict[str, Any]: + """Return a compact JSON sample for an Arch target smoke report.""" + return { + "name": target.name, + "variable": _target_variable(target), + "aggregation": str(getattr(target.aggregation, "value", target.aggregation)), + "measure": target.measure, + "value": target.value, + "period": target.period, + "source": str(target.source), + "filters": [ + { + "feature": target_filter.feature, + "operator": str( + getattr(target_filter.operator, "value", target_filter.operator) + ), + "value": target_filter.value, + } + for target_filter in target.filters + ], + "metadata": { + key: target.metadata[key] + for key in ( + "arch_aggregate_fact_key", + "arch_semantic_fact_key", + "arch_source_record_id", + "geo_level", + ) + if key in target.metadata + }, + } + + +def 
_arch_target_parity_rows( + incumbent_targets: list[CanonicalTargetSpec], + candidate_targets: list[CanonicalTargetSpec], + *, + value_abs_tolerance: float, + value_rel_tolerance: float, +) -> tuple[ArchTargetParityRow, ...]: + incumbent_by_identity = _index_arch_targets_by_parity_identity(incumbent_targets) + candidate_by_identity = _index_arch_targets_by_parity_identity(candidate_targets) + rows: list[ArchTargetParityRow] = [] + for identity in sorted( + set(incumbent_by_identity) | set(candidate_by_identity), + key=str, + ): + incumbent_group = tuple(incumbent_by_identity.get(identity, ())) + candidate_group = tuple(candidate_by_identity.get(identity, ())) + absolute_delta: float | None = None + relative_delta: float | None = None + if len(incumbent_group) != 1 or len(candidate_group) != 1: + status = _arch_target_parity_nonunique_status( + incumbent_group, + candidate_group, + ) + else: + incumbent_value = float(incumbent_group[0].value) + candidate_value = float(candidate_group[0].value) + absolute_delta = candidate_value - incumbent_value + relative_delta = ( + absolute_delta / incumbent_value if incumbent_value != 0 else None + ) + status = ( + "matched" + if _arch_target_values_match( + incumbent_value, + candidate_value, + abs_tolerance=value_abs_tolerance, + rel_tolerance=value_rel_tolerance, + ) + else "value_mismatch" + ) + rows.append( + ArchTargetParityRow( + status=status, + identity=identity, + incumbent_targets=incumbent_group, + candidate_targets=candidate_group, + absolute_delta=absolute_delta, + relative_delta=relative_delta, + ) + ) + return tuple(sorted(rows, key=_arch_target_parity_row_sort_key)) + + +def _index_arch_targets_by_parity_identity( + targets: list[CanonicalTargetSpec], +) -> dict[tuple[Any, ...], list[CanonicalTargetSpec]]: + indexed: dict[tuple[Any, ...], list[CanonicalTargetSpec]] = {} + for target in targets: + indexed.setdefault(_arch_target_parity_identity(target), []).append(target) + return indexed + + +def 
_arch_target_parity_identity(target: CanonicalTargetSpec) -> tuple[Any, ...]: + metadata = target.metadata or {} + return ( + str(getattr(target.entity, "value", target.entity)), + str(getattr(target.aggregation, "value", target.aggregation)), + str(target.measure or ""), + _arch_target_period_value(target.period), + str(target.source or ""), + _target_variable(target), + str(metadata.get("geo_level") or ""), + str(_arch_target_geographic_id(target) or ""), + _target_parity_filter_tuple(target), + ) + + +def _arch_target_period_value(value: int | str) -> int | str: + try: + return int(value) + except (TypeError, ValueError): + return str(value) + + +def _target_parity_filter_tuple( + target: CanonicalTargetSpec, +) -> tuple[tuple[str, str, str], ...]: + return tuple( + sorted( + ( + str(target_filter.feature), + str(getattr(target_filter.operator, "value", target_filter.operator)), + _json_scalar_text(target_filter.value), + ) + for target_filter in target.filters + ) + ) + + +def _arch_target_parity_nonunique_status( + incumbent_targets: tuple[CanonicalTargetSpec, ...], + candidate_targets: tuple[CanonicalTargetSpec, ...], +) -> str: + if len(incumbent_targets) > 1 or len(candidate_targets) > 1: + return "duplicate_identity" + if not incumbent_targets: + return "candidate_only" + if not candidate_targets: + return "incumbent_only" + return "duplicate_identity" + + +def _arch_target_values_match( + incumbent_value: float, + candidate_value: float, + *, + abs_tolerance: float, + rel_tolerance: float, +) -> bool: + delta = abs(candidate_value - incumbent_value) + if delta <= abs_tolerance: + return True + scale = max(abs(incumbent_value), abs(candidate_value), 1.0) + return delta <= rel_tolerance * scale + + +def _arch_target_parity_row_sort_key(row: ArchTargetParityRow) -> tuple[int, str]: + status_rank = { + "value_mismatch": 0, + "incumbent_only": 1, + "candidate_only": 2, + "duplicate_identity": 3, + "matched": 4, + } + return ( + status_rank.get(row.status, 99), 
+ json.dumps(_arch_target_parity_identity_dict(row.identity), sort_keys=True), + ) + + +def _arch_target_parity_error(row: ArchTargetParityRow) -> dict[str, Any]: + identity = _arch_target_parity_identity_dict(row.identity) + if row.status == "value_mismatch": + return { + "code": "value_mismatch", + "message": "Candidate target value differs from incumbent target value.", + "identity": identity, + "incumbent_value": row.incumbent_targets[0].value, + "candidate_value": row.candidate_targets[0].value, + "absolute_delta": row.absolute_delta, + "relative_delta": row.relative_delta, + } + if row.status == "incumbent_only": + return { + "code": "missing_candidate_target", + "message": "Incumbent target identity is absent from the candidate artifact.", + "identity": identity, + "incumbent_target_count": len(row.incumbent_targets), + "candidate_target_count": len(row.candidate_targets), + } + if row.status == "candidate_only": + return { + "code": "unexpected_candidate_target", + "message": "Candidate target identity is absent from the incumbent artifact.", + "identity": identity, + "incumbent_target_count": len(row.incumbent_targets), + "candidate_target_count": len(row.candidate_targets), + } + return { + "code": "duplicate_identity", + "message": "A target identity is not unique in one or both artifacts.", + "identity": identity, + "incumbent_target_count": len(row.incumbent_targets), + "candidate_target_count": len(row.candidate_targets), + } + + +def _arch_target_parity_identity_dict(identity: tuple[Any, ...]) -> dict[str, Any]: + ( + entity, + aggregation, + measure, + period, + source, + variable, + geo_level, + geographic_id, + filters, + ) = identity + return { + "entity": entity, + "aggregation": aggregation, + "measure": measure or None, + "period": period, + "source": source or None, + "variable": variable, + "geo_level": geo_level or None, + "geographic_id": geographic_id or None, + "filters": [ + {"feature": feature, "operator": operator, "value": value} + 
for feature, operator, value in filters + ], + } + + +def _target_parity_sample(target: CanonicalTargetSpec) -> dict[str, Any]: + sample = _target_smoke_sample(target) + metadata = dict(sample["metadata"]) + for key in ( + "target_id", + "source_table", + "display_label", + "arch_source_period", + "arch_model_period", + ): + if key in target.metadata: + metadata[key] = target.metadata[key] + sample["metadata"] = metadata + return sample + + +def _arch_provider_artifacts( + provider: ( + ArchSQLiteTargetProvider + | ArchFactSQLiteTargetProvider + | ArchConsumerFactJSONLTargetProvider + | ArchCompositeSQLiteTargetProvider + ), +) -> tuple[str, ...]: + if isinstance(provider, ArchCompositeSQLiteTargetProvider): + return tuple(str(path) for path in provider.db_paths) + path = getattr(provider, "db_path", None) or getattr(provider, "path", None) + if path is None: + return () + return (str(path),) + + +def arch_target_record_to_canonical_spec( + record: ArchTargetRecord, + *, + entity_overrides: dict[str, Any] | None = None, +) -> CanonicalTargetSpec | None: + """Translate one Arch target record into a canonical core target spec.""" + if record.target_type == "RATE": + return None + if _should_skip_arch_target_record(record): + return None + + filters = list(_canonical_filters_for_arch_constraints(record.constraints)) + geography_filter = _target_filter_for_arch_geography(record) + if geography_filter is not None: + filters.append(geography_filter) + entity_overrides = entity_overrides or {} + source_variable = record.variable + model_variable: str + aggregation: TargetAggregation + measure: str | None + entity: EntityType + + if record.target_type == "COUNT": + count_mapping = ARCH_COUNT_VARIABLE_ALIASES.get(source_variable) + positive_measure = _positive_measure_for_count_record(source_variable) + if count_mapping is not None: + model_variable, entity, count_filter_measure = count_mapping + if count_filter_measure is not None: + filters.append( + TargetFilter( + 
feature=count_filter_measure, + operator=">", + value=0, + ) + ) + elif positive_measure is not None: + model_variable = positive_measure + entity = EntityType.TAX_UNIT + filters.append( + TargetFilter(feature=positive_measure, operator=">", value=0) + ) + else: + return None + aggregation = TargetAggregation.COUNT + measure = None + elif record.target_type == "AMOUNT": + model_variable = ARCH_AMOUNT_VARIABLE_ALIASES.get( + source_variable, source_variable + ) + aggregation = TargetAggregation.SUM + measure = model_variable + if _is_blocked_self_employment_binding(record, model_variable): + raise ValueError( + "Broad Arch business/proprietors income cannot be exposed as " + "plain self_employment_income; use a dedicated proprietors-income " + "target or an explicit composite mapping." + ) + entity = _entity_for_measure(model_variable, entity_overrides) + if model_variable in ARCH_POSITIVE_AMOUNT_FILTER_VARIABLES: + filters.append( + TargetFilter( + feature=model_variable, + operator=">", + value=0, + ) + ) + else: + return None + + filters = list(_dedupe_target_filters(filters)) + display_label = _arch_target_display_label(record) + metadata = { + "target_id": record.target_id, + "stratum_id": record.stratum_id, + "display_label": display_label, + "variable": model_variable, + "model_variable_role": policyengine_us_variable_role(model_variable).value, + "arch_variable": record.variable, + "arch_target_type": record.target_type, + "target_semantic": record.target_type.lower(), + "source": record.source, + "source_table": record.source_table, + "source_url": record.source_url, + "notes": record.notes, + "stratum_name": record.stratum_name, + "jurisdiction": record.jurisdiction, + "geo_level": _arch_record_geo_level(record), + "geographic_level": record.geographic_level, + "geography_id": record.geography_id, + "constraint_count": len(filters), + "arch_source_period": record.source_period or record.period, + "arch_model_period": record.period, + } + if 
record.aggregate_fact_key is not None: + metadata["arch_aggregate_fact_key"] = record.aggregate_fact_key + if record.semantic_fact_key is not None: + metadata["arch_semantic_fact_key"] = record.semantic_fact_key + if record.source_record_id is not None: + metadata["arch_source_record_id"] = record.source_record_id + if record.source_cell_keys: + metadata["arch_source_cell_keys"] = list(record.source_cell_keys) + if record.source_row_keys: + metadata["arch_source_row_keys"] = list(record.source_row_keys) + if record.unit is not None: + metadata["unit"] = record.unit + if record.concept is not None: + metadata["arch_concept"] = record.concept + if record.source_concept is not None: + metadata["arch_source_concept"] = record.source_concept + if record.concept_relation is not None: + metadata["arch_concept_relation"] = record.concept_relation + if record.concept_authority is not None: + metadata["arch_concept_authority"] = record.concept_authority + if record.concept_evidence_url is not None: + metadata["arch_concept_evidence_url"] = record.concept_evidence_url + if record.concept_evidence_notes is not None: + metadata["arch_concept_evidence_notes"] = record.concept_evidence_notes + if record.legal_vintage is not None: + metadata["arch_legal_vintage"] = record.legal_vintage + if record.source_db_path is not None: + metadata["arch_source_db_path"] = record.source_db_path + if record.source_db_index is not None: + metadata["arch_source_db_index"] = record.source_db_index + if record.source_target_id is not None: + metadata["arch_source_target_id"] = record.source_target_id + if record.source_stratum_id is not None: + metadata["arch_source_stratum_id"] = record.source_stratum_id + if record.aging_factors is not None: + factors = record.aging_factors + metadata.update( + { + "arch_aged": True, + "arch_aging_source_year": factors.source_year, + "arch_aging_target_year": factors.target_year, + "arch_aging_count_factor": factors.count_factor, + "arch_aging_amount_factor": 
factors.amount_factor, + "arch_aging_count_method": factors.count_method, + "arch_aging_amount_method": factors.amount_method, + } + ) + + return CanonicalTargetSpec( + name=f"arch_target_{record.target_id}", + entity=entity, + value=record.value, + period=record.period, + measure=measure, + aggregation=aggregation, + filters=tuple(filters), + source=record.source, + description=display_label, + metadata=metadata, + ) + + +def _should_skip_arch_target_record(record: ArchTargetRecord) -> bool: + return _is_bea_regional_country_record(record) + + +def _is_blocked_self_employment_binding( + record: ArchTargetRecord, + model_variable: str, +) -> bool: + if model_variable != "self_employment_income": + return False + markers = { + str(value) + for value in ( + record.variable, + record.concept, + record.source_concept, + record.source_record_id, + ) + if value is not None + } + markers.update( + f"{variable}:{value}" + for variable, _, value in record.constraints + if value is not None + ) + return bool(markers & ARCH_BROAD_BUSINESS_INCOME_SELF_EMPLOYMENT_BLOCKLIST) + + +def _is_bea_regional_country_record(record: ArchTargetRecord) -> bool: + if not _has_bea_regional_lineage(record): + return False + if str(record.geography_id) == "0100000US": + return True + return _arch_record_geo_level(record) in {"national", "country"} + + +def _has_bea_regional_lineage(record: ArchTargetRecord) -> bool: + lineage_values = ( + record.concept, + record.source_concept, + record.source_record_id, + ) + return any( + str(value).startswith("bea_regional.") + or str(value).startswith("bea-regional.") + or ".bea-regional-" in str(value) + for value in lineage_values + if value is not None + ) + + +def _group_arch_target_rows(rows: list[sqlite3.Row]) -> list[ArchTargetRecord]: + grouped: dict[int, dict[str, Any]] = {} + for row in rows: + target_id = int(row["target_id"]) + item = grouped.setdefault( + target_id, + { + "target_id": target_id, + "stratum_id": int(row["stratum_id"]), + 
def _load_arch_fact_lineage(
    conn: sqlite3.Connection,
) -> dict[str, dict[str, tuple[str, ...]]]:
    """Collect source-cell and source-row lineage keys for every fact.

    Reads the optional ``fact_source_cells`` and ``fact_source_rows`` tables
    (each is skipped when absent) in ``fact_key, ordinal`` order.

    Args:
        conn: Open SQLite connection.  Rows are read by column name, so the
            connection presumably uses ``sqlite3.Row`` as its row factory --
            confirm against the caller.

    Returns:
        Mapping of fact key to a dict that holds ``"source_cell_keys"`` and/or
        ``"source_row_keys"`` tuples.  A key is present only when the matching
        table had at least one row for that fact.
    """
    # Accumulate into lists and freeze once at the end; the previous
    # ``(*existing, new)`` pattern copied the whole tuple on every row.
    collected: dict[str, dict[str, list[str]]] = {}

    def _append(fact_key: str, lineage_name: str, key: str) -> None:
        collected.setdefault(fact_key, {}).setdefault(lineage_name, []).append(key)

    if _sqlite_table_exists(conn, "fact_source_cells"):
        for row in conn.execute(
            """
            SELECT fact_key, source_cell_key
            FROM fact_source_cells
            ORDER BY fact_key, ordinal
            """
        ):
            _append(
                str(row["fact_key"]),
                "source_cell_keys",
                str(row["source_cell_key"]),
            )
    if _sqlite_table_exists(conn, "fact_source_rows"):
        for row in conn.execute(
            """
            SELECT fact_key, source_row_key
            FROM fact_source_rows
            ORDER BY fact_key, ordinal
            """
        ):
            _append(
                str(row["fact_key"]),
                "source_row_keys",
                str(row["source_row_key"]),
            )
    return {
        fact_key: {name: tuple(keys) for name, keys in item.items()}
        for fact_key, item in collected.items()
    }
= tuple( + dict.fromkeys( + constraint + for constraint in ( + *_arch_consumer_fact_domain_constraints(row), + *( + _arch_consumer_fact_constraint(constraint) + for constraint in _consumer_fact_universe_constraints(row).get( + "constraints", [] + ) + ), + ) + if constraint is not None + ) + ) + stratum_id = stratum_ids.setdefault(constraints, len(stratum_ids) + 1) + variable, target_type = _arch_consumer_fact_target_identity(row) + source = row.get("source") or {} + observed_measure = row.get("observed_measure") or {} + geography = row.get("geography") or {} + lineage = row.get("lineage") or {} + concept_alignment = row.get("concept_alignment") or {} + source_name = ( + source.get("source_name") or observed_measure.get("source_name") or "arch" + ) + records.append( + ArchTargetRecord( + target_id=target_id, + stratum_id=stratum_id, + variable=variable, + period=_consumer_fact_period(row), + value=_json_numeric_value(row.get("value")), + target_type=target_type, + geographic_level=_arch_consumer_fact_geographic_level(row), + geography_id=geography.get("id"), + source=_normalize_arch_source(str(source_name)), + source_table=source.get("source_table") + or observed_measure.get("source_table"), + source_url=source.get("url"), + notes=source.get("method_notes"), + stratum_name=_arch_consumer_fact_stratum_name(row), + jurisdiction="US", + constraints=constraints, + aggregate_fact_key=row.get("aggregate_fact_key"), + semantic_fact_key=row.get("semantic_fact_key"), + source_record_id=arch_consumer_fact_source_record_id(row), + source_cell_keys=tuple(lineage.get("source_cell_keys") or ()), + source_row_keys=tuple(lineage.get("source_row_keys") or ()), + unit=observed_measure.get("unit"), + concept=_arch_consumer_fact_concept(row), + source_concept=concept_alignment.get("source_concept") + or observed_measure.get("source_concept"), + concept_relation=concept_alignment.get("relation"), + concept_authority=concept_alignment.get("authority"), + 
def _ssa_payment_variable_from_consumer_fact(row: dict[str, Any]) -> str:
    """Return the SSA program payment type named by the row's universe constraints.

    Scans the ``constraints`` list of the row's universe constraints for an
    entry whose ``variable`` is
    ``"us_social_security_and_ssi.program_payment_type"`` and returns its
    ``value`` as a string.

    Raises:
        ValueError: If no such constraint carries a value.  A matching
            constraint with a missing/``None`` value is treated as absent
            instead of being returned as the literal string ``"None"``.
    """
    # ``or ()`` also tolerates an explicit ``"constraints": null`` entry.
    constraints = _consumer_fact_universe_constraints(row).get("constraints") or ()
    for constraint in constraints:
        if (
            constraint.get("variable")
            != "us_social_security_and_ssi.program_payment_type"
        ):
            continue
        payment_type = constraint.get("value")
        if payment_type is not None:
            return str(payment_type)
    raise ValueError("SSA payment fact row has no program payment type constraint.")
{variable!r}" + ) from exc + return ( + mapped_variable, + str(constraint["operator"]), + _json_scalar_text(constraint.get("value")), + ) + + +def _consumer_fact_universe_constraints(row: dict[str, Any]) -> dict[str, Any]: + universe_constraints = row.get("universe_constraints") or {} + if not isinstance(universe_constraints, dict): + raise ValueError("Arch consumer fact universe_constraints must be an object.") + return universe_constraints + + +def _arch_consumer_fact_geographic_level(row: dict[str, Any]) -> str | None: + geography = row.get("geography") or {} + return _arch_geographic_level_from_arch_level(geography.get("level")) + + +def _arch_consumer_fact_stratum_name(row: dict[str, Any]) -> str: + dimensions = row.get("dimensions") or {} + income_range = dimensions.get("income_range") + geography_name = _arch_consumer_fact_geography_name(row) + if income_range == "all": + return f"{geography_name} All Filers" + if income_range: + return f"{geography_name} Filers AGI {income_range}" + return str(row.get("label") or geography_name) + + +def _arch_consumer_fact_geography_name(row: dict[str, Any]) -> str: + geography = row.get("geography") or {} + level = str(geography.get("level") or "").lower() + if level == "country": + return "US" + return str(geography.get("name") or geography.get("id") or "US") + + +def _group_arch_fact_rows( + rows: list[sqlite3.Row], + *, + lineage: dict[str, dict[str, tuple[str, ...]]], +) -> list[ArchTargetRecord]: + grouped: dict[str, dict[str, Any]] = {} + stratum_ids: dict[tuple[tuple[str, str, str], ...], int] = {} + for row in rows: + fact_key = str(row["fact_key"]) + item = grouped.setdefault( + fact_key, + { + "row": row, + "constraints": list(_arch_fact_domain_constraints(row)), + }, + ) + if row["constraint_variable"] is not None: + constraint = _arch_fact_constraint(row) + if constraint is not None: + item["constraints"].append(constraint) + + records: list[ArchTargetRecord] = [] + for target_id, (fact_key, item) in 
def _arch_fact_target_identity(row: sqlite3.Row) -> tuple[str, str]:
    """Look up the ``(variable, target_type)`` pair for a fact row's concept.

    The row's ``measure_concept`` column is stringified and used as the key
    into ``ARCH_FACT_CONCEPT_TO_TARGET``.

    Raises:
        ValueError: If the concept has no entry in the mapping; the original
            ``KeyError`` is chained as the cause.
    """
    concept_name = str(row["measure_concept"])
    try:
        identity = ARCH_FACT_CONCEPT_TO_TARGET[concept_name]
    except KeyError as missing:
        message = f"No Microplex Arch fact mapping for concept {concept_name!r}"
        raise ValueError(message) from missing
    return identity
_arch_fact_domain_constraints_for_domain(domain) + + +def _arch_fact_domain_constraints_for_domain( + domain: str, +) -> tuple[tuple[str, str, str], ...]: + try: + return ARCH_FACT_DOMAIN_CONSTRAINTS[domain] + except KeyError as exc: + raise ValueError( + f"No Microplex Arch fact mapping for domain {domain!r}" + ) from exc + + +def _arch_fact_constraint(row: sqlite3.Row) -> tuple[str, str, str] | None: + variable = str(row["constraint_variable"]) + if variable in ARCH_IGNORED_FACT_CONSTRAINT_VARIABLES: + return None + try: + mapped_variable = ARCH_FACT_CONSTRAINT_VARIABLE_ALIASES[variable] + except KeyError as exc: + raise ValueError( + f"No Microplex Arch fact constraint mapping for variable {variable!r}" + ) from exc + return ( + mapped_variable, + str(row["constraint_operator"]), + _sqlite_json_scalar_text( + row["constraint_value_text"], + row["constraint_value_numeric"], + row["constraint_value_json"], + ), + ) + + +def _arch_fact_numeric_value(row: sqlite3.Row) -> float: + numeric = row["value_numeric"] + if numeric is not None: + return float(numeric) + return float(_sqlite_json_scalar_text(row["value_text"], None, row["value_json"])) + + +def _sqlite_json_scalar_text( + text_value: Any, + numeric_value: Any, + json_value: Any, +) -> str: + if text_value is not None: + return str(text_value) + if numeric_value is not None: + numeric = float(numeric_value) + return str(int(numeric)) if numeric.is_integer() else str(numeric) + return str(json_value) + + +def _arch_fact_geographic_level(row: sqlite3.Row) -> str | None: + return _arch_geographic_level_from_arch_level(row["geography_level"]) + + +def _arch_geographic_level_from_arch_level(level_value: Any) -> str | None: + level = str(level_value or "").lower() + if level == "country": + return "NATIONAL" + if level == "state": + return "STATE" + if level == "county": + return "COUNTY" + if level in {"congressional_district", "congressional-district"}: + return "CONGRESSIONAL_DISTRICT" + if level in { + 
"state_legislative_district_upper", + "state-legislative-district-upper", + }: + return "STATE_LEGISLATIVE_DISTRICT_UPPER" + if level in { + "state_legislative_district_lower", + "state-legislative-district-lower", + }: + return "STATE_LEGISLATIVE_DISTRICT_LOWER" + return level.upper() if level else None + + +def _json_numeric_value(value: Any) -> float: + return arch_consumer_fact_numeric_value(value) + + +def _json_scalar_text(value: Any) -> str: + if isinstance(value, float) and value.is_integer(): + return str(int(value)) + if isinstance(value, (int, float, str)): + return str(value) + return json.dumps(value, sort_keys=True) + + +def _arch_fact_stratum_name(row: sqlite3.Row) -> str: + income_range = _json_object_value(row["filters_json"], "income_range") + geography_name = row["geography_name"] or "US" + if income_range is None: + return str(geography_name) + if income_range == "all": + return f"{geography_name} All Filers" + return f"{geography_name} Filers AGI {income_range}" + + +def _arch_fact_semantic_key( + row: sqlite3.Row, + constraints: tuple[tuple[str, str, str], ...], +) -> str: + constraint_key = ",".join( + f"{variable}{operator}{value}" for variable, operator, value in constraints + ) + return "|".join( + [ + "arch.semantic_fact.v1", + str(row["measure_concept"]), + str(row["domain"]), + f"{row['period_value']}", + f"{row['geography_level']}:{row['geography_id']}", + constraint_key, + ] + ) + + +def _json_object_value(raw: Any, key: str) -> Any: + if raw is None: + return None + import json + + try: + payload = json.loads(str(raw)) + except json.JSONDecodeError: + return None + if not isinstance(payload, dict): + return None + return payload.get(key) + + +def _arch_target_display_label(record: ArchTargetRecord) -> str: + measure_label = _arch_target_measure_label(record) + scope_label = _arch_target_scope_label(record) + source_label = _humanize_arch_source(record.source) + suffix = ( + f" ({source_label}, {record.period})" if source_label else 
f" ({record.period})" + ) + if scope_label: + return f"{measure_label} for {scope_label}{suffix}" + return f"{measure_label}{suffix}" + + +def _arch_target_measure_label(record: ArchTargetRecord) -> str: + source_variable = str(record.variable) + override = ARCH_VARIABLE_LABEL_OVERRIDES.get(source_variable) + if override is not None: + return override + if record.target_type == "COUNT": + for suffix in ("_returns", "_claims", "_count"): + if source_variable.endswith(suffix): + base = source_variable.removesuffix(suffix) + return f"{_humanize_snake_label(base)} {suffix.removeprefix('_')}" + return f"{_humanize_snake_label(source_variable)} count" + if record.target_type == "AMOUNT": + if source_variable.endswith("_amount"): + return f"{_humanize_snake_label(source_variable.removesuffix('_amount'))} amount" + return f"{_humanize_snake_label(source_variable)} amount" + return _humanize_snake_label(source_variable) + + +def _arch_target_scope_label(record: ArchTargetRecord) -> str: + if record.stratum_name: + return str(record.stratum_name) + constraint_labels = [ + label + for constraint in record.constraints + for label in [_arch_constraint_display_label(constraint)] + if label + ] + if constraint_labels: + return ", ".join(constraint_labels) + jurisdiction = str(record.jurisdiction or "").strip() + return jurisdiction.upper().replace("_", " ") if jurisdiction else "" + + +def _arch_constraint_display_label( + constraint: tuple[str, str, str], +) -> str: + variable, operator, value = constraint + canonical_operator = _canonical_arch_constraint_operator(operator) + value_text = str(value) + if variable == "agi_bracket": + return f"AGI {ARCH_AGI_BRACKET_LABELS.get(value_text, value_text)}" + if variable == "is_tax_filer" and canonical_operator == "==": + if _truthy_constraint_value(value_text): + return "tax filers" + if _falsey_constraint_value(value_text): + return "non-filers" + if variable == "state_fips" and canonical_operator == "==": + return f"state FIPS 
{str(value_text).zfill(2)}" + if variable == "congressional_district" and canonical_operator == "==": + return f"congressional district {str(value_text).zfill(2)}" + if variable == "sldu_id" and canonical_operator == "==": + return f"state senate district {value_text}" + if variable == "sldl_id" and canonical_operator == "==": + return f"state house district {value_text}" + positive_feature = ARCH_POSITIVE_CONSTRAINT_ALIASES.get(variable) + if positive_feature is not None and canonical_operator == "==": + feature_label = _humanize_snake_label(positive_feature) + if _truthy_constraint_value(value_text): + return f"{feature_label} > 0" + if _falsey_constraint_value(value_text): + return f"{feature_label} = 0" + return f"{_humanize_snake_label(variable)} {canonical_operator} {value_text}" + + +def _truthy_constraint_value(value: str) -> bool: + try: + return float(str(value)) == 1.0 + except ValueError: + return str(value).strip().lower() in {"true", "yes"} + + +def _falsey_constraint_value(value: str) -> bool: + try: + return float(str(value)) == 0.0 + except ValueError: + return str(value).strip().lower() in {"false", "no"} + + +def _humanize_arch_source(source: str | None) -> str: + if not source: + return "" + return _humanize_snake_label(str(source)) + + +def _humanize_snake_label(value: str) -> str: + words = [ + ARCH_LABEL_WORD_OVERRIDES.get(word.lower(), word.lower()) + for word in str(value).replace("-", "_").split("_") + if word + ] + if not words: + return "" + label = " ".join(words) + label = label[0].upper() + label[1:] + return label.replace("Tax exempt", "Tax-exempt") + + +def _canonical_filters_for_arch_constraints( + constraints: tuple[tuple[str, str, str], ...], +) -> tuple[TargetFilter, ...]: + filters: list[TargetFilter] = [] + equalities = _constraint_equalities(constraints) + for variable, operator, value in constraints: + canonical_operator = _canonical_arch_constraint_operator(operator) + if variable == "agi_bracket": + 
def _target_filter_for_arch_geography(record: ArchTargetRecord) -> TargetFilter | None:
    """Build the equality filter that pins a record to its own geography.

    Returns ``None`` when the record has no ``geography_id`` or when its geo
    level (as reported by ``_arch_record_geo_level``) is not one of
    ``state``, ``county``, ``district``, ``sldu`` or ``sldl``.  Otherwise the
    filter compares the level's feature with the geography id normalized by
    the level-specific helper.
    """
    geography_id = record.geography_id
    if geography_id is None:
        return None

    # geo level -> (filter feature, id normalizer)
    resolvers = {
        "state": ("state_fips", _state_fips_from_arch_geography_id),
        "county": ("county_fips", _county_fips_from_arch_geography_id),
        "district": (
            "congressional_district_geoid",
            _congressional_district_from_arch_geography_id,
        ),
        "sldu": (
            "sldu_id",
            lambda raw: _state_legislative_district_from_arch_geography_id(
                raw,
                chamber="upper",
            ),
        ),
        "sldl": (
            "sldl_id",
            lambda raw: _state_legislative_district_from_arch_geography_id(
                raw,
                chamber="lower",
            ),
        ),
    }
    resolver = resolvers.get(_arch_record_geo_level(record))
    if resolver is None:
        return None
    feature, normalize = resolver
    return TargetFilter(
        feature=feature,
        operator="==",
        value=normalize(geography_id),
    )
str(geography_id) + if raw.startswith("0400000US"): + return raw[-2:] + if raw.isdigit(): + return raw.zfill(2) + return raw + + +def _county_fips_from_arch_geography_id(geography_id: Any) -> str: + raw = str(geography_id) + if raw.startswith("0500000US"): + return raw[-5:] + if raw.isdigit(): + return raw.zfill(5) + return raw + + +def _congressional_district_from_arch_geography_id(geography_id: Any) -> str: + raw = str(geography_id) + if raw.startswith("5001800US"): + return raw[-4:] + return raw + + +ARCH_STATE_ABBR_BY_FIPS = { + "01": "AL", + "02": "AK", + "04": "AZ", + "05": "AR", + "06": "CA", + "08": "CO", + "09": "CT", + "10": "DE", + "11": "DC", + "12": "FL", + "13": "GA", + "15": "HI", + "16": "ID", + "17": "IL", + "18": "IN", + "19": "IA", + "20": "KS", + "21": "KY", + "22": "LA", + "23": "ME", + "24": "MD", + "25": "MA", + "26": "MI", + "27": "MN", + "28": "MS", + "29": "MO", + "30": "MT", + "31": "NE", + "32": "NV", + "33": "NH", + "34": "NJ", + "35": "NM", + "36": "NY", + "37": "NC", + "38": "ND", + "39": "OH", + "40": "OK", + "41": "OR", + "42": "PA", + "44": "RI", + "45": "SC", + "46": "SD", + "47": "TN", + "48": "TX", + "49": "UT", + "50": "VT", + "51": "VA", + "53": "WA", + "54": "WV", + "55": "WI", + "56": "WY", + "72": "PR", +} + + +def _state_legislative_district_from_arch_geography_id( + geography_id: Any, + *, + chamber: str, +) -> str: + return normalize_state_legislative_district_id( + geography_id, chamber=chamber + ) or str(geography_id) + + +def _canonical_arch_constraint_operator(operator: str) -> str: + value = str(operator).strip() + return ARCH_CONSTRAINT_OPERATOR_ALIASES.get(value.lower(), value) + + +def _constraint_equalities( + constraints: tuple[tuple[str, str, str], ...], +) -> dict[str, str]: + return { + variable: value + for variable, operator, value in constraints + if _canonical_arch_constraint_operator(operator) == "==" + } + + +def _congressional_district_geoid( + *, + state_fips: str | None, + district: str, +) -> str | 
None: + try: + district_id = str(int(str(district))).zfill(2) + except ValueError: + district_id = str(district) + if len(district_id) >= 4: + return district_id + if state_fips is None: + return None + try: + state_id = str(int(str(state_fips))).zfill(2) + except ValueError: + state_id = str(state_fips).zfill(2) + return f"{state_id}{district_id}" + + +def _positive_support_filter_for_arch_constraint( + feature: str, + *, + operator: str, + value: str, +) -> TargetFilter: + canonical_operator = _canonical_arch_constraint_operator(operator) + if canonical_operator == "==": + try: + numeric_value = float(str(value)) + except ValueError: + numeric_value = None + if numeric_value == 1 or str(value).strip().lower() in {"true", "yes"}: + return TargetFilter(feature=feature, operator=">", value=0) + if numeric_value == 0 or str(value).strip().lower() in {"false", "no"}: + return TargetFilter(feature=feature, operator="==", value=0) + return TargetFilter(feature=feature, operator=canonical_operator, value=value) + + +def _dedupe_target_filters(filters: list[TargetFilter]) -> tuple[TargetFilter, ...]: + seen: set[tuple[str, str, Any]] = set() + deduped: list[TargetFilter] = [] + for target_filter in filters: + operator = getattr(target_filter.operator, "value", target_filter.operator) + key = (str(target_filter.feature), str(operator), str(target_filter.value)) + if key in seen: + continue + seen.add(key) + deduped.append(target_filter) + return tuple(deduped) + + +def _agi_bracket_filters(value: str) -> tuple[TargetFilter, ...]: + bounds = ARCH_AGI_BRACKET_FILTERS.get(value) + if bounds is None: + return (TargetFilter(feature="agi_bracket", operator="==", value=value),) + lower, upper = bounds + filters: list[TargetFilter] = [] + if lower is not None: + filters.append( + TargetFilter(feature="adjusted_gross_income", operator=">=", value=lower) + ) + if upper is not None: + filters.append( + TargetFilter(feature="adjusted_gross_income", operator="<", value=upper) + ) + 
def _matches_arch_provider_filters(
    record: ArchTargetRecord,
    *,
    variables: tuple[str, ...],
    domain_variables: tuple[str, ...],
    geo_levels: tuple[str, ...],
    target_cells: tuple[dict[str, Any], ...],
    entity_overrides: dict[str, Any] | None = None,
) -> bool:
    """Decide whether ``record`` passes every non-empty provider filter.

    An empty filter tuple places no restriction.  The record is converted to
    a canonical spec only when a variable, domain-variable or target-cell
    filter is present; if that conversion yields ``None`` the record is
    rejected.  Checks run in this order: variables, domain variables, geo
    levels, target cells.
    """
    spec: CanonicalTargetSpec | None = None
    if variables or domain_variables or target_cells:
        spec = arch_target_record_to_canonical_spec(
            record,
            entity_overrides=entity_overrides or {},
        )
        if spec is None:
            return False

    # From here on ``spec`` is set whenever one of the spec-based filters is.
    if variables and _arch_target_query_variables(record, spec).isdisjoint(
        variables
    ):
        return False
    if domain_variables and _arch_target_domain_variables(spec).isdisjoint(
        domain_variables
    ):
        return False
    if geo_levels:
        record_level = _arch_record_geo_level(record)
        allowed_levels = {_normalize_geo_level(str(level)) for level in geo_levels}
        if record_level not in allowed_levels:
            return False
    if not target_cells:
        return True
    return any(_matches_arch_target_cell(spec, cell) for cell in target_cells)
def _arch_target_cell_variables(target: CanonicalTargetSpec) -> set[str]:
    """Return the variable names a target-cell query may use for ``target``.

    COUNT targets answer to a single entity-based name (household, person or
    SPM unit count, falling back to ``"tax_unit_count"``).  Other targets
    answer to their measure -- or, when the measure is ``None``, the
    ``"variable"`` metadata entry -- plus any aliases listed for it in
    ``ARCH_TARGET_CELL_VARIABLE_ALIASES``.  With neither available the set is
    empty.
    """
    if target.aggregation is TargetAggregation.COUNT:
        count_names = (
            (EntityType.HOUSEHOLD, "household_count"),
            (EntityType.PERSON, "person_count"),
            (EntityType.SPM_UNIT, "spm_unit_count"),
        )
        for entity, count_name in count_names:
            if target.entity is entity:
                return {count_name}
        return {"tax_unit_count"}

    source = target.measure
    if source is None:
        source = target.metadata.get("variable")
        if source is None:
            return set()
    name = str(source)
    return {name, *ARCH_TARGET_CELL_VARIABLE_ALIASES.get(name, ())}
def _target_domain_variables_match(
    target: CanonicalTargetSpec,
    *,
    target_domain_variables: set[str],
    cell_domain_variables: set[str],
) -> bool:
    """Check whether a cell's domain variables are an accepted form of the target's.

    The cell matches when its domain-variable set equals any of:

    1. the target's own domain variables;
    2. those plus the implied ones (the "effective" set);
    3. for SUM targets whose cell variables are all self-domain amount
       variables, the effective set plus those cell variables;
    4. for SUM targets whose ``"variable"`` metadata is in the effective set,
       the effective set without that variable;
    5. for COUNT targets with ``"variable"`` metadata, the effective set
       without that variable.
    """
    if cell_domain_variables == target_domain_variables:
        return True

    effective = target_domain_variables | _arch_target_implied_domain_variables(
        target
    )
    if cell_domain_variables == effective:
        return True

    aggregation = target.aggregation
    is_sum = aggregation is TargetAggregation.SUM
    cell_variables = _arch_target_cell_variables(target)
    if (
        is_sum
        and cell_variables.issubset(ARCH_SELF_DOMAIN_AMOUNT_VARIABLES)
        and cell_domain_variables == effective | cell_variables
    ):
        return True

    # Cases 4 and 5 both need a model variable and the same reduced set.
    model_variable = str(target.metadata.get("variable") or "")
    if not model_variable:
        return False
    if cell_domain_variables != effective - {model_variable}:
        return False
    if is_sum:
        return model_variable in effective
    return aggregation is TargetAggregation.COUNT
def _arch_gap_loaded_variable_catalog(
    provider: (
        ArchSQLiteTargetProvider
        | ArchFactSQLiteTargetProvider
        | ArchConsumerFactJSONLTargetProvider
        | ArchCompositeSQLiteTargetProvider
    ),
    *,
    period: int,
    jurisdiction: str | None,
    sources: tuple[str, ...],
    compose_model_year_targets: bool | None,
    age_soi_targets: bool | None,
) -> dict[tuple[str, str], set[str]]:
    """Index the geo levels the provider loads for each ``(source, variable)``.

    Fact, consumer-fact and composite providers are queried with
    ``load_records(period=..., sources=...)``.  Any other provider either
    composes model-year records or loads them directly, depending on the
    compose flag; ``None`` for ``compose_model_year_targets`` or
    ``age_soi_targets`` means "use the provider's own attribute".

    Returns:
        Mapping of ``(record.source, record.variable)`` to the set of geo
        levels reported by ``_arch_record_geo_level`` for those records.
    """
    resolved_jurisdiction = jurisdiction or provider.jurisdiction
    period_only_providers = (
        ArchFactSQLiteTargetProvider,
        ArchConsumerFactJSONLTargetProvider,
        ArchCompositeSQLiteTargetProvider,
    )
    if isinstance(provider, period_only_providers):
        records = provider.load_records(period=period, sources=sources)
    else:
        compose = compose_model_year_targets
        if compose is None:
            compose = provider.compose_model_year_targets
        age_soi = age_soi_targets
        if age_soi is None:
            age_soi = provider.age_soi_targets
        if compose:
            records = provider._compose_model_year_records(
                target_year=period,
                jurisdiction=resolved_jurisdiction,
                sources=sources,
                age_soi_targets=age_soi,
            )
        else:
            records = provider.load_records(
                period=period,
                jurisdiction=resolved_jurisdiction,
                sources=sources,
            )

    catalog: dict[tuple[str, str], set[str]] = {}
    for record in records:
        levels = catalog.setdefault((record.source, record.variable), set())
        levels.add(_arch_record_geo_level(record))
    return catalog
str, + period: int, + loaded_variable_catalog: dict[tuple[str, str], set[str]], + variable_uncovered_count: int, +) -> ArchTargetGapQueueRow: + cell = coverage.cell + expected_source = _arch_gap_expected_source(cell) + expected_arch_variable = _arch_gap_expected_arch_variable(cell) + expected_target_type = _arch_gap_expected_target_type(cell) + expected_entity = _arch_gap_expected_entity(cell) + expected_aggregation = _arch_gap_expected_aggregation(expected_target_type) + loader_status = _arch_gap_loader_status( + coverage, + expected_source=expected_source, + expected_arch_variable=expected_arch_variable, + loaded_variable_catalog=loaded_variable_catalog, + cell=cell, + ) + gap_category = _arch_gap_category( + cell, + loader_status=loader_status, + expected_source=expected_source, + expected_arch_variable=expected_arch_variable, + ) + return ArchTargetGapQueueRow( + priority=0, + profile_name=profile_name, + period=int(period), + variable=str(cell.get("variable") or ""), + geo_level=cell.get("geo_level"), + domain_variable=cell.get("domain_variable"), + geographic_id=cell.get("geographic_id"), + covered=coverage.covered, + target_count=coverage.target_count, + target_ids=coverage.target_ids, + sources=coverage.sources, + expected_source=expected_source, + expected_source_table=_arch_gap_expected_source_table( + expected_source, + expected_arch_variable, + cell, + ), + expected_arch_variable=expected_arch_variable, + expected_target_type=expected_target_type, + expected_entity=expected_entity, + expected_aggregation=expected_aggregation, + expected_filters=_arch_gap_expected_filters(cell), + gap_category=gap_category, + loader_status=loader_status, + agent_task_kind=_arch_gap_agent_task_kind(gap_category), + notes=_arch_gap_notes( + cell, + expected_source=expected_source, + expected_arch_variable=expected_arch_variable, + gap_category=gap_category, + variable_uncovered_count=variable_uncovered_count, + ), + ) + + +def _arch_gap_queue_sort_key(row: 
ArchTargetGapQueueRow) -> tuple[Any, ...]: + source_rank = { + "IRS_SOI": 0, + "BEA": 1, + "CENSUS_ACS": 2, + "CMS_ACA": 3, + "CMS_MEDICAID": 4, + "CMS_MEDICARE": 5, + "USDA_SNAP": 6, + "SSA": 7, + "HHS_ACF_TANF": 8, + "HHS_ACF_LIHEAP": 9, + "FEDERAL_RESERVE": 10, + }.get(str(row.expected_source), 99) + return ( + row.covered, + row.loader_status == "needs_source_mapping_review", + -_arch_gap_notes_uncovered_count(row.notes), + source_rank, + str(row.variable), + str(row.geo_level or ""), + str(row.domain_variable or ""), + ) + + +def _arch_gap_notes_uncovered_count(notes: str) -> int: + if not notes.startswith("profile_variable_uncovered_count="): + return 0 + raw_count = notes.split(";", 1)[0].split("=", 1)[1] + try: + return int(raw_count) + except ValueError: + return 0 + + +def _arch_gap_expected_source(cell: dict[str, Any]) -> str | None: + variable = str(cell.get("variable") or "") + domain_variables = set( + _split_target_cell_domain_variables(cell.get("domain_variable")) + ) + if not domain_variables and variable in ARCH_BEA_FULL_POP_AMOUNT_VARIABLES: + return "BEA" + if variable == "tax_unit_count" and "aca_ptc" in domain_variables: + return "IRS_SOI" + if variable == "snap" or "snap" in domain_variables: + return "USDA_SNAP" + if variable == "tanf" or "tanf" in domain_variables: + return "HHS_ACF_TANF" + if "spm_unit_energy_subsidy_reported" in domain_variables: + return "HHS_ACF_LIHEAP" + if variable == "aca_ptc" or "aca_ptc" in domain_variables: + return "CMS_ACA" + if variable == "medicaid" or "medicaid_enrolled" in domain_variables: + return "CMS_MEDICAID" + if variable == "ssi" or variable.startswith("social_security"): + return "SSA" + if variable == "state_income_tax": + return "CENSUS_STC" + if variable == "medicare_part_b_premiums": + return "CMS_MEDICARE" + if variable == "net_worth": + return "FEDERAL_RESERVE" + if variable == "person_count": + if _normalize_geo_level(cell.get("geo_level")) in {"sldu", "sldl"}: + return "CENSUS_DECENNIAL" + if 
"adjusted_gross_income" in domain_variables: + return "IRS_SOI" + if "age" in domain_variables or not domain_variables: + return "CENSUS_PEP" + return None + if variable == "household_count": + if _normalize_geo_level(cell.get("geo_level")) in {"sldu", "sldl"}: + return "CENSUS_DECENNIAL" + if not domain_variables: + return "CENSUS_ACS" + return None + if variable in ARCH_IRS_SOI_GAP_VARIABLES: + return "IRS_SOI" + if domain_variables & ARCH_IRS_SOI_GAP_VARIABLES: + return "IRS_SOI" + return None + + +def _arch_gap_expected_arch_variable(cell: dict[str, Any]) -> str | None: + variable = str(cell.get("variable") or "") + domain_variables = tuple( + _split_target_cell_domain_variables(cell.get("domain_variable")) + ) + domain_variable = domain_variables[0] if len(domain_variables) == 1 else None + if not domain_variables and variable in ARCH_BEA_FULL_POP_AMOUNT_ARCH_VARIABLES: + return ARCH_BEA_FULL_POP_AMOUNT_ARCH_VARIABLES[variable] + if variable == "tax_unit_count": + if set(domain_variables) in ( + {"eitc_child_count"}, + {"adjusted_gross_income", "eitc", "eitc_child_count"}, + ): + return "eitc_claims" + if { + "adjusted_gross_income", + "income_tax_before_credits", + }.issubset(domain_variables): + return "income_tax_before_credits_returns" + if set(domain_variables) == {"aca_ptc"}: + return "aca_ptc_returns" + itemized_domain_variables = set(domain_variables) - {"tax_unit_itemizes"} + if ( + "tax_unit_itemizes" in domain_variables + and len(itemized_domain_variables) == 1 + ): + return ARCH_MODEL_COUNT_DOMAIN_VARIABLE_HINTS.get( + next(iter(itemized_domain_variables)) + ) + if domain_variable is None: + return "tax_unit_count" if not domain_variables else None + return ARCH_MODEL_COUNT_DOMAIN_VARIABLE_HINTS.get(domain_variable) + if variable == "household_count": + if domain_variable == "snap": + return "snap_household_count" + if domain_variable == "spm_unit_energy_subsidy_reported": + return "liheap_household_count" + return "household_count" if 
domain_variable is None else None + if variable == "spm_unit_count": + if domain_variable == "tanf": + return "tanf_family_count" + return None + if variable == "person_count": + if domain_variable == "snap": + return "snap_participant_count" + if domain_variable == "aca_ptc": + return "aca_marketplace_enrollment" + if domain_variable == "medicaid_enrolled": + return "medicaid_total_enrollment" + if domain_variable == "adjusted_gross_income": + return "tax_filer_individual_count" + if domain_variable == "age" or not domain_variables: + return "population" + return None + if variable == "snap": + return "snap_benefits" + if variable == "aca_ptc": + return "aca_aptc_amount" + if variable == "medicaid": + return "medicaid_benefits" + if variable == "tanf": + return "tanf_cash_assistance" + if variable == "state_income_tax": + return "state_individual_income_tax_collections" + return ARCH_MODEL_AMOUNT_VARIABLE_HINTS.get(variable) + + +def _arch_gap_expected_target_type(cell: dict[str, Any]) -> str | None: + variable = str(cell.get("variable") or "") + if variable in { + "household_count", + "person_count", + "spm_unit_count", + "tax_unit_count", + }: + return "COUNT" + if _arch_gap_expected_arch_variable(cell) is not None: + return "AMOUNT" + return None + + +def _arch_gap_expected_entity(cell: dict[str, Any]) -> str | None: + variable = str(cell.get("variable") or "") + if variable == "tax_unit_count": + return EntityType.TAX_UNIT.value + if variable == "person_count": + return EntityType.PERSON.value + if variable == "spm_unit_count": + return EntityType.SPM_UNIT.value + if variable in {"household_count", "snap"}: + return EntityType.HOUSEHOLD.value + entity = ARCH_ENTITY_HINTS.get(variable) + return entity.value if entity is not None else None + + +def _arch_gap_expected_aggregation(target_type: str | None) -> str | None: + if target_type == "COUNT": + return "count" + if target_type == "AMOUNT": + return "sum" + return None + + +def _arch_gap_expected_filters(cell: 
dict[str, Any]) -> tuple[dict[str, Any], ...]: + filters: list[dict[str, Any]] = [] + geo_level = _normalize_geo_level(cell.get("geo_level")) + geographic_id = cell.get("geographic_id") + if geo_level == "state": + filters.append( + { + "kind": "geography", + "feature": "state_fips", + "operator": "==", + "value": ( + _state_fips_from_arch_geography_id(geographic_id) + if geographic_id is not None + else "" + ), + } + ) + if geo_level == "sldu": + filters.append( + { + "kind": "geography", + "feature": "sldu_id", + "operator": "==", + "value": ( + _normalize_target_cell_geographic_id( + geographic_id, + geo_level=geo_level, + ) + if geographic_id is not None + else "" + ), + } + ) + if geo_level == "sldl": + filters.append( + { + "kind": "geography", + "feature": "sldl_id", + "operator": "==", + "value": ( + _normalize_target_cell_geographic_id( + geographic_id, + geo_level=geo_level, + ) + if geographic_id is not None + else "" + ), + } + ) + for domain_variable in _split_target_cell_domain_variables( + cell.get("domain_variable") + ): + filters.append( + { + "kind": "domain", + "feature": domain_variable, + "operator": ">", + "value": 0, + } + ) + return tuple(filters) + + +def _arch_gap_expected_source_table( + expected_source: str | None, + expected_arch_variable: str | None, + cell: dict[str, Any], +) -> str | None: + variable = str(cell.get("variable") or "") + if expected_source == "BEA": + geo_level = _normalize_geo_level(cell.get("geo_level")) + if geo_level == "state" and expected_arch_variable in { + "proprietors_income_amount", + "wages_salaries_amount", + }: + return "BEA Regional SAINC5N annual state personal income" + if expected_arch_variable == "wages_salaries_amount": + return "BEA NIPA annual total wages and salaries" + if expected_arch_variable in { + "medicaid_benefits", + "personal_dividend_income_amount", + "proprietors_income_amount", + "rental_income_amount", + "social_security_benefits", + "unemployment_insurance_benefits", + }: + return 
"BEA NIPA annual personal income components" + return "BEA NIPA or Regional personal income tables" + if expected_arch_variable in ARCH_GAP_SOURCE_TABLE_HINTS: + return ARCH_GAP_SOURCE_TABLE_HINTS[expected_arch_variable] + if variable in ARCH_GAP_SOURCE_TABLE_HINTS: + return ARCH_GAP_SOURCE_TABLE_HINTS[variable] + if expected_source == "IRS_SOI": + if expected_arch_variable and ( + expected_arch_variable.startswith("wages_salaries_") + or expected_arch_variable.startswith("net_capital_gains_") + or expected_arch_variable.startswith("taxable_ira_distributions_") + or expected_arch_variable.startswith("taxable_pension_income_") + or expected_arch_variable.startswith("taxable_social_security_") + or expected_arch_variable.startswith("unemployment_compensation_") + ): + return "IRS SOI Publication 1304 Table 1.4" + if expected_arch_variable and ( + expected_arch_variable.endswith("_claims") + or expected_arch_variable + in {"real_estate_taxes_amount", "real_estate_taxes_claims"} + ): + return "IRS SOI itemized deduction or credit tables" + return "IRS SOI Publication 1304" + if expected_source == "CENSUS_ACS": + return "Census ACS summary tables" + if expected_source == "CENSUS_DECENNIAL": + return "Census 2020 CD119 state legislative district summary file" + if expected_source == "CENSUS_PEP": + return "Census Population Estimates Program age-sex files" + if expected_source == "CENSUS_STC": + return "Census State Tax Collections item T40" + if expected_source == "CMS_ACA": + return "CMS Marketplace Open Enrollment public-use files" + if expected_source == "CMS_MEDICAID": + return "CMS Medicaid enrollment and expenditure reports" + if expected_source == "CMS_MEDICARE": + return "CMS Medicare Trustees Report Part B premium income" + if expected_source == "FEDERAL_RESERVE": + return "Federal Reserve Financial Accounts Z.1 household net worth" + if expected_source == "SSA": + return "SSA Annual Statistical Supplement" + if expected_source == "HHS_ACF_TANF": + return "ACF 
TANF Financial Data" + if expected_source == "HHS_ACF_LIHEAP": + return "HHS ACF LIHEAP National Profile" + return None + + +def _arch_gap_loader_status( + coverage: ArchTargetCellCoverage, + *, + expected_source: str | None, + expected_arch_variable: str | None, + loaded_variable_catalog: dict[tuple[str, str], set[str]], + cell: dict[str, Any], +) -> str: + if coverage.covered: + return "covered" + if expected_source is None or expected_arch_variable is None: + return "needs_source_mapping_review" + loaded_geo_levels = loaded_variable_catalog.get( + (expected_source, expected_arch_variable) + ) + if loaded_geo_levels: + expected_geo_level = _normalize_geo_level(cell.get("geo_level")) + if expected_geo_level not in loaded_geo_levels: + return "loaded_arch_variable_missing_geography" + return "loaded_arch_variable_missing_filter_or_adapter" + return "missing_arch_target_record" + + +def _arch_gap_category( + cell: dict[str, Any], + *, + loader_status: str, + expected_source: str | None, + expected_arch_variable: str | None, +) -> str: + if loader_status == "covered": + return "covered" + if _arch_gap_is_deprioritized_survey_or_model_input(cell): + return "survey_or_model_input_deprioritized" + if loader_status == "missing_arch_target_record": + return "ready_primary_loader" + if loader_status == "loaded_arch_variable_missing_geography": + return "ready_rollup_or_geography" + if loader_status == "loaded_arch_variable_missing_filter_or_adapter": + return "adapter_or_constraint_review" + if expected_source is None or expected_arch_variable is None: + return "source_mapping_review" + return "source_mapping_review" + + +def _arch_gap_is_deprioritized_survey_or_model_input(cell: dict[str, Any]) -> bool: + variable = str(cell.get("variable") or "") + if variable in ARCH_DEPRIORITIZED_SURVEY_OR_MODEL_GAP_VARIABLES: + return True + domain_variables = set( + _split_target_cell_domain_variables(cell.get("domain_variable")) + ) + return bool(domain_variables & 
ARCH_DEPRIORITIZED_SURVEY_OR_MODEL_GAP_DOMAINS) + + +def _arch_gap_agent_task_kind(gap_category: str) -> str: + if gap_category == "covered": + return "none" + if gap_category == "survey_or_model_input_deprioritized": + return "defer_or_review_non_primary_source" + if gap_category == "ready_rollup_or_geography": + return "add_arch_rollup_or_geography_records" + if gap_category == "adapter_or_constraint_review": + return "review_adapter_or_constraints" + if gap_category == "ready_primary_loader": + return "add_arch_source_loader_or_target_record" + return "review_source_mapping" + + +def _arch_gap_notes( + cell: dict[str, Any], + *, + expected_source: str | None, + expected_arch_variable: str | None, + gap_category: str, + variable_uncovered_count: int, +) -> str: + parts = [f"profile_variable_uncovered_count={variable_uncovered_count}"] + if gap_category == "survey_or_model_input_deprioritized": + parts.append( + "survey/model-input proxy deprioritized until primary source review" + ) + if expected_source is None: + parts.append("expected_source requires review") + if expected_arch_variable is None: + parts.append("expected Arch variable requires review") + if "," in str(cell.get("domain_variable") or ""): + parts.append("multi-domain cells may need a grouped source-record spec") + return "; ".join(parts) + + +def _arch_target_gap_queue_csv(report: ArchTargetGapQueueReport) -> str: + import csv + import io + import json + + fieldnames = [ + "priority", + "profile_name", + "period", + "variable", + "geo_level", + "domain_variable", + "geographic_id", + "covered", + "target_count", + "target_ids", + "sources", + "expected_source", + "expected_source_table", + "expected_arch_variable", + "expected_target_type", + "expected_entity", + "expected_aggregation", + "expected_filters", + "gap_category", + "loader_status", + "agent_task_kind", + "notes", + ] + buffer = io.StringIO() + writer = csv.DictWriter(buffer, fieldnames=fieldnames) + writer.writeheader() + for row in 
report.rows: + writer.writerow( + { + "priority": row.priority, + "profile_name": row.profile_name, + "period": row.period, + "variable": row.variable, + "geo_level": row.geo_level, + "domain_variable": row.domain_variable, + "geographic_id": row.geographic_id, + "covered": row.covered, + "target_count": row.target_count, + "target_ids": json.dumps(list(row.target_ids)), + "sources": json.dumps(list(row.sources)), + "expected_source": row.expected_source, + "expected_source_table": row.expected_source_table, + "expected_arch_variable": row.expected_arch_variable, + "expected_target_type": row.expected_target_type, + "expected_entity": row.expected_entity, + "expected_aggregation": row.expected_aggregation, + "expected_filters": json.dumps(list(row.expected_filters)), + "gap_category": row.gap_category, + "loader_status": row.loader_status, + "agent_task_kind": row.agent_task_kind, + "notes": row.notes, + } + ) + return buffer.getvalue() + + +def _summarize_arch_cell_coverage( + coverage_cells: tuple[ArchTargetCellCoverage, ...], + *, + field: str, +) -> dict[str, dict[str, int]]: + summary: dict[str, dict[str, int]] = {} + for coverage in coverage_cells: + raw_value = coverage.cell.get(field) + value = ( + _normalize_geo_level(raw_value) + if field == "geo_level" + else str(raw_value or "") + ) + if not value: + value = "none" + item = summary.setdefault( + value, + { + "target_cell_count": 0, + "covered_cell_count": 0, + "uncovered_cell_count": 0, + }, + ) + item["target_cell_count"] += 1 + if coverage.covered: + item["covered_cell_count"] += 1 + else: + item["uncovered_cell_count"] += 1 + return dict(sorted(summary.items())) + + +def _target_cell_to_provider_filter( + cell: PolicyEngineUSTargetCell | dict[str, Any], +) -> dict[str, str | None]: + if isinstance(cell, PolicyEngineUSTargetCell): + return cell.to_provider_filter() + return { + "variable": cell.get("variable"), + "geo_level": cell.get("geo_level"), + "domain_variable": cell.get("domain_variable"), + 
"geographic_id": cell.get("geographic_id"), + } + + +def _arch_target_geographic_id(target: CanonicalTargetSpec) -> str | None: + geo_level = str(target.metadata.get("geo_level") or "national").lower() + feature_by_level = { + "state": "state_fips", + "county": "county_fips", + "tract": "tract_geoid", + "district": "congressional_district_geoid", + "congressional_district": "congressional_district_geoid", + "sldu": "sldu_id", + "sldl": "sldl_id", + } + feature = feature_by_level.get(geo_level) + if feature is None: + return None + for target_filter in target.filters: + if str(target_filter.feature) != feature: + continue + operator = getattr(target_filter.operator, "value", target_filter.operator) + if _canonical_arch_constraint_operator(str(operator)) == "==": + return str(target_filter.value) + return None + + +def _split_target_cell_domain_variables(value: Any) -> tuple[str, ...]: + if value is None: + return () + return tuple( + _normalize_target_cell_domain_variable(part) + for part in str(value).split(",") + if part.strip() + ) + + +def _normalize_target_cell_domain_variable(value: Any) -> str: + raw = str(value).strip() + return ARCH_POSITIVE_CONSTRAINT_ALIASES.get(raw, raw) + + +def _normalize_target_cell_geographic_id( + value: Any, + *, + geo_level: str | None = None, +) -> str: + raw = str(value) + normalized_geo_level = _normalize_geo_level(geo_level) + chamber = None + if normalized_geo_level == "sldu": + chamber = "upper" + elif normalized_geo_level == "sldl": + chamber = "lower" + normalized_sld = normalize_state_legislative_district_id(raw, chamber=chamber) + if normalized_sld != raw: + return str(normalized_sld) + try: + return str(int(raw)) + except (TypeError, ValueError): + return raw + + +def _arch_record_composition_key( + record: ArchTargetRecord, +) -> tuple[str, str, str, tuple[tuple[str, str, str], ...]]: + return ( + record.variable, + record.target_type, + _arch_record_geo_level(record), + tuple(sorted(record.constraints)), + ) + + +def 
_arch_record_geo_level(record: ArchTargetRecord) -> str: + return _geo_level_for_constraints(record.constraints) or _normalize_geo_level( + record.geographic_level + ) + + +def _geo_level_for_constraints( + constraints: tuple[tuple[str, str, str], ...], +) -> str | None: + constraint_variables = {variable for variable, _, _ in constraints} + for variable, geo_level in ( + ("tract_geoid", "tract"), + ("county_fips", "county"), + ("congressional_district", "district"), + ("congressional_district_geoid", "district"), + ("sldu_id", "sldu"), + ("sldl_id", "sldl"), + ("state_fips", "state"), + ): + if variable in constraint_variables: + return geo_level + return None + + +def _normalize_arch_source(source: str) -> str: + value = str(source) + return ARCH_SOURCE_ALIASES.get(value.lower(), value.upper().replace("-", "_")) + + +def _normalize_geo_level(geo_level: str | None) -> str: + if not geo_level: + return "national" + normalized = geo_level.lower() + if normalized in {"congressional_district", "congressional-district"}: + return "district" + if normalized in { + "sldu", + "state_legislative_district_upper", + "state-legislative-district-upper", + "state_senate_district", + "state-senate-district", + }: + return "sldu" + if normalized in { + "sldl", + "state_legislative_district_lower", + "state-legislative-district-lower", + "state_house_district", + "state-house-district", + }: + return "sldl" + return normalized + + +def _sqlite_table_has_column( + conn: sqlite3.Connection, + table: str, + column: str, +) -> bool: + return column in _sqlite_table_columns(conn, table) + + +def _sqlite_table_columns(conn: sqlite3.Connection, table: str) -> set[str]: + names: set[str] = set() + for row in conn.execute(f"PRAGMA table_info({table})"): + names.add(str(row["name"] if isinstance(row, sqlite3.Row) else row[1])) + return names + + +def _sqlite_table_exists(conn: sqlite3.Connection, table: str) -> bool: + row = conn.execute( + "SELECT 1 FROM sqlite_master WHERE type = 'table' 
AND name = ?", + (table,), + ).fetchone() + return row is not None + + +def _looks_like_arch_consumer_fact_jsonl(path: Path) -> bool: + return path.suffix.lower() in {".jsonl", ".ndjson"} + + +def _as_arch_db_path_tuple( + value: str | Path | tuple[str | Path, ...], +) -> tuple[Path, ...]: + if isinstance(value, (str, Path)): + return (Path(value),) + paths = tuple(Path(path) for path in value) + if not paths: + raise ValueError("At least one Arch targets DB path is required") + return paths + + +def _single_or_many_paths(paths: list[str]) -> str | tuple[str, ...]: + return paths[0] if len(paths) == 1 else tuple(paths) + + +def _default_arch_target_artifact_roots() -> tuple[Path, ...]: + candidates = ( + Path.cwd() / "artifacts", + Path.cwd().parent / "arch", + Path("/tmp"), + ) + return tuple(path for path in candidates if path.exists()) + + +def discover_arch_target_artifacts( + roots: tuple[str | Path, ...], + *, + max_depth: int = 6, +) -> tuple[Path, ...]: + """Find local Arch target artifacts under bounded discovery roots.""" + + discovered: list[Path] = [] + seen: set[Path] = set() + for raw_root in roots: + root = Path(raw_root).expanduser() + if root.is_file(): + candidates = (root,) + elif root.is_dir(): + candidates = tuple(_walk_arch_target_artifact_candidates(root, max_depth)) + else: + continue + for candidate in candidates: + resolved = candidate.resolve() + if resolved in seen or not _is_arch_target_artifact(resolved): + continue + discovered.append(resolved) + seen.add(resolved) + return tuple(sorted(discovered, key=lambda path: str(path))) + + +def _walk_arch_target_artifact_candidates( + root: Path, max_depth: int +) -> tuple[Path, ...]: + import os + + skip_dir_names = { + ".git", + ".mypy_cache", + ".pytest_cache", + ".ruff_cache", + ".tox", + ".venv", + "__pycache__", + "node_modules", + "site-packages", + } + candidates: list[Path] = [] + root = root.resolve() + for directory, dirnames, filenames in os.walk(root): + current = Path(directory) 
+ try: + depth = len(current.relative_to(root).parts) + except ValueError: + depth = 0 + if depth >= max_depth: + dirnames[:] = [] + else: + dirnames[:] = [ + dirname for dirname in dirnames if dirname not in skip_dir_names + ] + for filename in filenames: + candidate = current / filename + if _is_arch_target_artifact_candidate_name(candidate): + candidates.append(candidate) + return tuple(candidates) + + +def _is_arch_target_artifact_candidate_name(path: Path) -> bool: + name = path.name.lower() + suffix = path.suffix.lower() + if name in {"consumer_facts.jsonl", "consumer_facts.ndjson"}: + return True + if suffix not in {".db", ".sqlite", ".sqlite3"}: + return False + return name == "targets.db" or "arch_targets" in name + + +def _is_arch_target_artifact(path: Path) -> bool: + if not path.is_file(): + return False + if path.suffix.lower() in {".jsonl", ".ndjson"}: + return _is_arch_consumer_fact_jsonl(path) + if path.suffix.lower() in {".db", ".sqlite", ".sqlite3"}: + return _is_arch_sqlite_artifact(path) + return False + + +def _is_arch_consumer_fact_jsonl(path: Path) -> bool: + try: + with path.open() as file: + for line in file: + if not line.strip(): + continue + row = json.loads(line) + schema_version = str(row.get("schema_version") or "") + return schema_version.startswith("arch.consumer_fact") or ( + "aggregate_fact_key" in row and "observed_measure" in row + ) + except (OSError, json.JSONDecodeError): + return False + return False + + +def _is_arch_sqlite_artifact(path: Path) -> bool: + try: + conn = sqlite3.connect(path) + except sqlite3.Error: + return False + try: + tables = { + row[0] + for row in conn.execute( + "SELECT name FROM sqlite_master WHERE type = 'table'" + ).fetchall() + } + if "aggregate_facts" in tables: + return True + if not {"targets", "strata", "stratum_constraints"}.issubset(tables): + return False + target_columns = _sqlite_table_columns(conn, "targets") + required_target_columns = { + "id", + "stratum_id", + "variable", + 
"period", + "value", + "target_type", + "geographic_level", + "source", + } + return required_target_columns.issubset(target_columns) + except sqlite3.Error: + return False + finally: + conn.close() + + +def _filename_slug(value: str) -> str: + slug = "".join(character if character.isalnum() else "_" for character in value) + slug = "_".join(part for part in slug.split("_") if part) + return slug.lower() or "profile" + + +def _arch_target_refresh_summary_markdown( + coverage: ArchTargetProfileCoverageReport, + gaps: ArchTargetGapQueueReport, + *, + artifact_paths: tuple[Path, ...], + output_paths: tuple[Path, ...], +) -> str: + lines = [ + "# Arch Target Coverage Snapshot", + "", + f"- Profile: `{coverage.profile_name}`", + f"- Period: `{coverage.period}`", + f"- Target cells: `{coverage.target_cell_count}`", + f"- Covered cells: `{coverage.covered_cell_count}`", + f"- Uncovered cells: `{coverage.uncovered_cell_count}`", + f"- Coverage rate: `{coverage.coverage_rate:.1%}`", + "", + "## Coverage By Geography", + "", + "| Geography | Target cells | Covered | Uncovered |", + "| --- | ---: | ---: | ---: |", + ] + for geo_level, counts in sorted(coverage.by_geo_level.items()): + lines.append( + "| {geo_level} | {target_cell_count} | {covered_cell_count} | " + "{uncovered_cell_count} |".format(geo_level=geo_level, **counts) + ) + lines.extend( + [ + "", + "## Gap Categories", + "", + "| Category | Rows |", + "| --- | ---: |", + ] + ) + for category, count in sorted(gaps.by_gap_category.items()): + lines.append(f"| `{category}` | {count} |") + lines.extend( + [ + "", + "## Inputs", + "", + *(f"- `{path}`" for path in artifact_paths), + "", + "## Outputs", + "", + *(f"- `{path}`" for path in output_paths), + "", + ] + ) + return "\n".join(lines) + + +def _target_filter_tuple( + target: CanonicalTargetSpec, +) -> tuple[tuple[str, str, str], ...]: + return tuple( + sorted( + ( + str(target_filter.feature), + str(getattr(target_filter.operator, "value", 
target_filter.operator)), + _json_scalar_text(target_filter.value), + ) + for target_filter in target.filters + ) + ) + + +def _jurisdiction_clause(jurisdiction: str) -> str: + normalized = jurisdiction.upper().replace("-", "_") + if normalized == "US": + return "upper(s.jurisdiction) LIKE 'US%'" + return f"upper(s.jurisdiction) = '{normalized}'" + + +def _as_string_tuple(value: Any) -> tuple[str, ...]: + if value is None: + return () + if isinstance(value, str): + return (value,) + return tuple(str(item) for item in value) + + +def _as_target_cell_filters(value: Any) -> tuple[dict[str, Any], ...]: + if value is None: + return () + if isinstance(value, dict): + return (dict(value),) + return tuple(dict(item) for item in value if item is not None) + + +__all__ = [ + "ArchCompositeSQLiteTargetProvider", + "ArchConsumerFactJSONLTargetProvider", + "ArchFactSQLiteTargetProvider", + "ArchTargetCellCoverage", + "ArchTargetGapQueueReport", + "ArchTargetGapQueueRow", + "ArchTargetParityReport", + "ArchTargetParityRow", + "ArchTargetProfileCoverageReport", + "ArchSQLiteTargetProvider", + "ArchTargetRecord", + "SOIAgingFactors", + "arch_target_record_to_canonical_spec", + "resolve_arch_sqlite_target_provider", + "summarize_arch_target_gap_queue", + "summarize_arch_target_parity", + "summarize_arch_target_profile_coverage", +] diff --git a/src/microplex_us/targets/census_blocks.py b/src/microplex_us/targets/census_blocks.py new file mode 100644 index 0000000..f1b53ca --- /dev/null +++ b/src/microplex_us/targets/census_blocks.py @@ -0,0 +1,363 @@ +"""Census block-derived target providers.""" + +from __future__ import annotations + +from collections.abc import Iterable +from pathlib import Path +from typing import Any + +import pandas as pd +from microplex.core import EntityType +from microplex.targets import ( + TabularRollupSpec, + TabularRollupTargetProvider, + TargetAggregation, + TargetQuery, + TargetSpec, + as_string_tuple, + build_tabular_rollup_targets, +) + +from 
microplex_us.geography import ( + load_block_probabilities, + normalize_state_legislative_district_id, +) + +CENSUS_BLOCK_POPULATION_VARIABLE = "person_count" +CENSUS_BLOCK_POPULATION_SOURCE = "Census 2020 PL 94-171" +CENSUS_BLOCK_POPULATION_UNITS = "persons" +CENSUS_BLOCK_TARGET_PERIOD = 2024 +CENSUS_BLOCK_SOURCE_YEAR = 2020 +CENSUS_BLOCK_GEOGRAPHY_YEAR = 2020 + +DEFAULT_CENSUS_BLOCK_POPULATION_GEO_LEVELS: tuple[str, ...] = ( + "national", + "state", + "county", + "cd", + "sldu", + "sldl", + "cbsa", + "spm_metro_area", +) + + +CensusBlockPopulationRollup = TabularRollupSpec + + +CENSUS_BLOCK_POPULATION_ROLLUPS: dict[str, CensusBlockPopulationRollup] = { + "national": CensusBlockPopulationRollup( + geo_level="national", + source_column=None, + filter_feature=None, + group_name="census_block_population_national", + name_prefix="census_block_population_national", + ), + "state": CensusBlockPopulationRollup( + geo_level="state", + source_column="state_fips", + filter_feature="state_fips", + group_name="census_block_population_state", + name_prefix="census_block_population_state", + ), + "county": CensusBlockPopulationRollup( + geo_level="county", + source_column="county_fips", + filter_feature="county_fips", + group_name="census_block_population_county", + name_prefix="census_block_population_county", + ), + "tract": CensusBlockPopulationRollup( + geo_level="tract", + source_column="tract_geoid", + filter_feature="tract_geoid", + group_name="census_block_population_tract", + name_prefix="census_block_population_tract", + ), + "block": CensusBlockPopulationRollup( + geo_level="block", + source_column="geoid", + filter_feature="block_geoid", + group_name="census_block_population_block", + name_prefix="census_block_population_block", + ), + "cd": CensusBlockPopulationRollup( + geo_level="cd", + source_column="cd_id", + filter_feature="cd_id", + group_name="census_block_population_cd", + name_prefix="census_block_population_cd", + ), + "sldu": CensusBlockPopulationRollup( 
+ geo_level="sldu", + source_column="sldu_id", + filter_feature="sldu_id", + group_name="census_block_population_sldu", + name_prefix="census_block_population_sldu", + ), + "sldl": CensusBlockPopulationRollup( + geo_level="sldl", + source_column="sldl_id", + filter_feature="sldl_id", + group_name="census_block_population_sldl", + name_prefix="census_block_population_sldl", + ), + "cbsa": CensusBlockPopulationRollup( + geo_level="cbsa", + source_column="cbsa_code", + filter_feature="cbsa_code", + group_name="census_block_population_cbsa", + name_prefix="census_block_population_cbsa", + ), + "spm_metro_area": CensusBlockPopulationRollup( + geo_level="spm_metro_area", + source_column="spm_metro_area", + filter_feature="spm_metro_area", + group_name="census_block_population_spm_metro_area", + name_prefix="census_block_population_spm_metro_area", + ), +} +CENSUS_BLOCK_POPULATION_GEO_LEVELS: tuple[str, ...] = tuple( + CENSUS_BLOCK_POPULATION_ROLLUPS +) + + +class CensusBlockPopulationTargetProvider(TabularRollupTargetProvider): + """Build population count targets by rolling Census blocks to parent geos.""" + + def __init__( + self, + block_probabilities: pd.DataFrame | None = None, + *, + block_probabilities_path: str | Path | None = None, + default_geo_levels: Iterable[str] = DEFAULT_CENSUS_BLOCK_POPULATION_GEO_LEVELS, + period: int = CENSUS_BLOCK_TARGET_PERIOD, + ) -> None: + super().__init__( + block_probabilities, + data_path=block_probabilities_path, + data_loader=load_block_probabilities, + prepare_data=_prepare_block_probabilities, + rollups=CENSUS_BLOCK_POPULATION_ROLLUPS, + value_column="population", + variable=CENSUS_BLOCK_POPULATION_VARIABLE, + variable_aliases=("population",), + entity=EntityType.PERSON, + aggregation=TargetAggregation.COUNT, + period=period, + source=CENSUS_BLOCK_POPULATION_SOURCE, + units=CENSUS_BLOCK_POPULATION_UNITS, + default_geo_levels=default_geo_levels, + min_value=0.0, + normalize_geographic_id=_normalize_census_block_geographic_id, 
+ base_metadata={ + "source_year": CENSUS_BLOCK_SOURCE_YEAR, + "geography_year": CENSUS_BLOCK_GEOGRAPHY_YEAR, + "source_artifact": "census_2020_pl_94_171_state_files", + "support_artifact": "block_probabilities.parquet", + "block_rollup": True, + }, + ) + + def load_target_set(self, query: TargetQuery | None = None): + """Load Census block rollup targets with US SLD ID alias support.""" + if query is None or "geographic_ids" not in query.provider_filters: + return super().load_target_set(query) + provider_filters = dict(query.provider_filters) + geo_levels = _requested_census_block_geo_levels( + provider_filters, + default_geo_levels=self.default_geo_levels, + ) + provider_filters["geographic_ids"] = _expand_census_block_geographic_ids( + provider_filters["geographic_ids"], + geo_levels=geo_levels, + ) + return super().load_target_set( + TargetQuery( + period=query.period, + entity=query.entity, + names=query.names, + metadata_filters=query.metadata_filters, + provider_filters=provider_filters, + ) + ) + + +def build_census_block_population_targets( + block_probabilities: pd.DataFrame, + *, + geo_levels: Iterable[str] = DEFAULT_CENSUS_BLOCK_POPULATION_GEO_LEVELS, + geographic_ids: Iterable[str] | None = None, + period: int = CENSUS_BLOCK_TARGET_PERIOD, +) -> list[TargetSpec]: + """Roll block-level Census population counts to canonical target specs.""" + requested_geo_levels = as_string_tuple(geo_levels) + resolved_geo_levels = ( + CENSUS_BLOCK_POPULATION_GEO_LEVELS + if requested_geo_levels == ("all",) + else requested_geo_levels + ) + return build_tabular_rollup_targets( + _prepare_block_probabilities(block_probabilities), + rollups=CENSUS_BLOCK_POPULATION_ROLLUPS, + value_column="population", + variable=CENSUS_BLOCK_POPULATION_VARIABLE, + entity=EntityType.PERSON, + aggregation=TargetAggregation.COUNT, + period=period, + source=CENSUS_BLOCK_POPULATION_SOURCE, + units=CENSUS_BLOCK_POPULATION_UNITS, + geo_levels=resolved_geo_levels, + 
geographic_ids=_expand_census_block_geographic_ids( + geographic_ids, + geo_levels=resolved_geo_levels, + ), + min_value=0.0, + normalize_geographic_id=_normalize_census_block_geographic_id, + base_metadata={ + "source_year": CENSUS_BLOCK_SOURCE_YEAR, + "geography_year": CENSUS_BLOCK_GEOGRAPHY_YEAR, + "source_artifact": "census_2020_pl_94_171_state_files", + "support_artifact": "block_probabilities.parquet", + "block_rollup": True, + }, + ) + + +def _prepare_block_probabilities(block_probabilities: pd.DataFrame) -> pd.DataFrame: + if "population" not in block_probabilities.columns: + raise ValueError("Block probabilities must include a population column") + blocks = block_probabilities.copy() + blocks["population"] = pd.to_numeric(blocks["population"], errors="coerce") + if "state_fips" in blocks.columns: + blocks["state_fips"] = _zero_pad_series(blocks["state_fips"], 2) + if "county_fips" in blocks.columns: + blocks["county_fips"] = _zero_pad_series(blocks["county_fips"], 5) + elif {"state_fips", "county"}.issubset(blocks.columns): + blocks["county_fips"] = blocks["state_fips"] + _zero_pad_series( + blocks["county"], 3 + ) + if "tract_geoid" not in blocks.columns and { + "state_fips", + "county", + "tract", + }.issubset(blocks.columns): + blocks["tract_geoid"] = ( + blocks["state_fips"] + + _zero_pad_series(blocks["county"], 3) + + _zero_pad_series(blocks["tract"], 6) + ) + if "sldu_id" in blocks.columns: + blocks["sldu_id"] = blocks["sldu_id"].map( + lambda value: ( + normalize_state_legislative_district_id( + value, + chamber="upper", + ) + or "" + ) + ) + if "sldl_id" in blocks.columns: + blocks["sldl_id"] = blocks["sldl_id"].map( + lambda value: ( + normalize_state_legislative_district_id( + value, + chamber="lower", + ) + or "" + ) + ) + for column in ( + "geoid", + "tract_geoid", + "cd_id", + "cbsa_code", + "spm_metro_area", + ): + if column in blocks.columns: + blocks[column] = blocks[column].map(_normalize_geographic_id) + return blocks + + +def 
_zero_pad_series(values: pd.Series, width: int) -> pd.Series: + text = values.astype("string").str.strip() + numeric = pd.to_numeric(text, errors="coerce") + numeric_text = numeric.round().astype("Int64").astype("string").str.zfill(width) + return text.where(numeric.isna(), numeric_text).str.zfill(width) + + +def _normalize_geographic_id(value: Any) -> str: + if pd.isna(value): + return "" + text = str(value).strip() + if not text: + return "" + if text.endswith(".0") and text[:-2].isdigit(): + return text[:-2] + return text + + +def _normalize_census_block_geographic_id(value: Any) -> str: + raw = "" if pd.isna(value) else str(value).strip() + normalized_sld = normalize_state_legislative_district_id(value) + if normalized_sld is not None and normalized_sld != raw: + return normalized_sld + return _normalize_geographic_id(value) + + +def _requested_census_block_geo_levels( + provider_filters: dict[str, Any], + *, + default_geo_levels: Iterable[str], +) -> tuple[str, ...]: + if "geo_levels" in provider_filters: + requested = as_string_tuple(provider_filters["geo_levels"]) + elif "geographic_levels" in provider_filters: + requested = as_string_tuple(provider_filters["geographic_levels"]) + else: + requested = tuple(default_geo_levels) + return ( + tuple(CENSUS_BLOCK_POPULATION_ROLLUPS) if requested == ("all",) else requested + ) + + +def _expand_census_block_geographic_ids( + geographic_ids: Iterable[str] | Any | None, + *, + geo_levels: Iterable[str], +) -> tuple[str, ...] 
| None: + if geographic_ids is None: + return None + levels = set(as_string_tuple(geo_levels)) + include_upper = "sldu" in levels + include_lower = "sldl" in levels + expanded: list[str] = [] + for value in as_string_tuple(geographic_ids): + normalized = _normalize_census_block_geographic_id(value) + if normalized: + expanded.append(normalized) + if include_upper: + upper = normalize_state_legislative_district_id(value, chamber="upper") + if upper: + expanded.append(upper) + if include_lower: + lower = normalize_state_legislative_district_id(value, chamber="lower") + if lower: + expanded.append(lower) + return tuple(dict.fromkeys(expanded)) + + +__all__ = [ + "CENSUS_BLOCK_GEOGRAPHY_YEAR", + "CENSUS_BLOCK_POPULATION_GEO_LEVELS", + "CENSUS_BLOCK_POPULATION_ROLLUPS", + "CENSUS_BLOCK_POPULATION_SOURCE", + "CENSUS_BLOCK_POPULATION_UNITS", + "CENSUS_BLOCK_POPULATION_VARIABLE", + "CENSUS_BLOCK_SOURCE_YEAR", + "CENSUS_BLOCK_TARGET_PERIOD", + "DEFAULT_CENSUS_BLOCK_POPULATION_GEO_LEVELS", + "CensusBlockPopulationRollup", + "CensusBlockPopulationTargetProvider", + "build_census_block_population_targets", +] diff --git a/tests/pipelines/test_cd_age_reweighting.py b/tests/pipelines/test_cd_age_reweighting.py new file mode 100644 index 0000000..2531fec --- /dev/null +++ b/tests/pipelines/test_cd_age_reweighting.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import sqlite3 + +import h5py +import numpy as np + +from microplex_us.pipelines.cd_age_reweighting import ( + normalize_at_large_cd_geoids, + reweight_h5_to_cd_age_targets, +) + + +def test_normalize_at_large_cd_geoids_maps_statewide_zero_to_one() -> None: + values = np.asarray([200, 201, 1000, 3601, 0], dtype=np.int64) + + normalized = normalize_at_large_cd_geoids(values) + + np.testing.assert_array_equal( + normalized, + np.asarray([201, 201, 1001, 3601, 0], dtype=np.int64), + ) + + +def test_reweight_h5_to_cd_age_targets_matches_simple_at_large_targets(tmp_path) -> None: + dataset = tmp_path / "input.h5" + 
output = tmp_path / "output.h5" + db = tmp_path / "policy_data.db" + _write_minimal_h5(dataset) + _write_cd_age_target_db(db) + + summary = reweight_h5_to_cd_age_targets( + input_dataset=dataset, + target_db=db, + output_dataset=output, + period=2024, + max_iter=100, + preserve_district_weight_sum=False, + ) + + assert summary["n_targets"] == 2 + assert summary["max_abs_relative_error_after"] < 1e-5 + with h5py.File(output, "r") as handle: + np.testing.assert_allclose( + handle["household_weight"]["2024"][:], + np.asarray([10.0, 20.0], dtype=np.float32), + rtol=1e-5, + ) + np.testing.assert_array_equal( + handle["congressional_district_geoid"]["2024"][:], + np.asarray([201, 201]), + ) + + +def _write_minimal_h5(path): + with h5py.File(path, "w") as handle: + _write_period(handle, "household_id", [1, 2]) + _write_period(handle, "household_weight", [1.0, 1.0]) + _write_period(handle, "congressional_district_geoid", [200, 200]) + _write_period(handle, "person_household_id", [1, 2]) + _write_period(handle, "age", [4, 40]) + + +def _write_period(handle, variable, values): + group = handle.create_group(variable) + group.create_dataset("2024", data=np.asarray(values)) + + +def _write_cd_age_target_db(path): + conn = sqlite3.connect(path) + try: + conn.executescript( + """ + CREATE TABLE targets ( + target_id INTEGER PRIMARY KEY, + variable TEXT, + period INTEGER, + stratum_id INTEGER, + reform_id INTEGER DEFAULT 0, + value REAL, + active INTEGER DEFAULT 1, + tolerance REAL, + source TEXT, + notes TEXT + ); + CREATE TABLE strata ( + stratum_id INTEGER PRIMARY KEY, + definition_hash TEXT, + parent_stratum_id INTEGER + ); + CREATE TABLE stratum_constraints ( + stratum_id INTEGER, + constraint_variable TEXT, + operation TEXT, + value TEXT + ); + CREATE VIEW target_overview AS + SELECT + target_id, + stratum_id, + variable, + value, + period, + active, + 'district' AS geo_level, + '201' AS geographic_id, + 'age' AS domain_variable + FROM targets; + """ + ) + 
_insert_target(conn, 1, 101, 10.0, [("age", "<", "18"), ("age", ">", "-1")]) + _insert_target(conn, 2, 102, 20.0, [("age", ">=", "18")]) + conn.commit() + finally: + conn.close() + + +def _insert_target(conn, target_id, stratum_id, value, constraints): + conn.execute( + """ + INSERT INTO targets + (target_id, variable, period, stratum_id, reform_id, value, active) + VALUES (?, 'person_count', 2024, ?, 0, ?, 1) + """, + (target_id, stratum_id, value), + ) + conn.execute("INSERT INTO strata (stratum_id) VALUES (?)", (stratum_id,)) + for constraint in [ + ("congressional_district_geoid", "==", "201"), + *constraints, + ]: + conn.execute( + """ + INSERT INTO stratum_constraints + (stratum_id, constraint_variable, operation, value) + VALUES (?, ?, ?, ?) + """, + (stratum_id, *constraint), + ) diff --git a/tests/pipelines/test_dashboard.py b/tests/pipelines/test_dashboard.py new file mode 100644 index 0000000..f9f7a70 --- /dev/null +++ b/tests/pipelines/test_dashboard.py @@ -0,0 +1,348 @@ +import json + +import numpy as np +import pytest + +from microplex_us.pipelines.dashboard import build_dashboard_payload +from microplex_us.pipelines.run_contract import RunContractWriter + + +def test_dashboard_payload_marks_missing_pe_l0_comparators(tmp_path): + artifacts = tmp_path / "artifacts" + run_dir = artifacts / "latest" + run_dir.mkdir(parents=True) + (run_dir / "scores.json").write_text( + json.dumps( + [ + { + "metric": "pe_native_broad_loss", + "period": 2024, + "summary": { + "baseline_enhanced_cps_native_loss": 0.0977, + "candidate_beats_baseline": True, + "candidate_enhanced_cps_native_loss": 0.0252, + "enhanced_cps_native_loss_delta": -0.0725, + "n_targets_kept": 2805, + "n_targets_total": 2816, + }, + "broad_loss": { + "baseline_dataset": "enhanced_cps_2024.h5", + "candidate_dataset": "pe_l0_candidate.h5", + "baseline_weight_sum": 153.8, + "candidate_weight_sum": 153.7, + }, + } + ] + ) + ) + screen_dir = artifacts / "local_screen" + screen_dir.mkdir() + (screen_dir 
/ "split_loss_summary.json").write_text( + json.dumps( + { + "candidate": "cd_age_w8", + "broad_objective_on_latest_pe_matrix_rows": 0.0262, + "latest_pe_baseline_broad_loss": 0.0977, + "cd_age_mean_abs_relative_error": 0.0155, + } + ) + ) + (screen_dir / "scores.json").write_text( + json.dumps( + [ + { + "summary": { + "baseline_enhanced_cps_native_loss": 0.0977, + "candidate_beats_baseline": True, + "candidate_enhanced_cps_native_loss": 0.0263, + "enhanced_cps_native_loss_delta": -0.0714, + } + } + ] + ) + ) + local_l0_dir = artifacts / "pe_local_area_l0_compare" + local_l0_dir.mkdir() + (local_l0_dir / "pe_local_area_l0_state_stack_vs_legacy_ecps.json").write_text( + json.dumps( + { + "metric": "enhanced_cps_native_loss_target_delta", + "from_dataset": "legacy-pe-ecps", + "to_dataset": "pe-local-area-l0-state-stack", + "state_score_count": 51, + "state_weight_sum": 121.0, + "summary": { + "n_targets": 2814, + "from_loss": 0.1747, + "to_loss": 3.0, + "loss_delta": 2.8253, + }, + } + ) + ) + microplex_l0_dir = artifacts / "microplex_actual_l0" + microplex_l0_dir.mkdir() + (microplex_l0_dir / "unified_diagnostics.csv").write_text( + "\n".join( + [ + "target,true_value,estimate,rel_error,abs_rel_error,achievable", + "a,100,90,-0.10,0.10,True", + "b,100,100,0.00,0.00,True", + ] + ) + ) + (microplex_l0_dir / "unified_run_config.json").write_text( + json.dumps({"n_clones": 10, "epochs": 300}) + ) + np.save(microplex_l0_dir / "calibration_weights.npy", np.array([1.0, 0.0, 200.0])) + target_diagnostics = artifacts / "pe_native_target_diagnostics_current.json" + target_diagnostics.write_text( + json.dumps( + { + "dataset_labels": {"from": "PE", "to": "Microplex"}, + "summary": {"n_targets": 0}, + "targets": [], + } + ) + ) + pe_repo = tmp_path / "policyengine-us-data" + for dirname, epochs, mean_error in [ + ("local_net_worth_100", 100, 5.5), + ("local_net_worth_100_e300", 300, 2.5), + ]: + model_dir = ( + pe_repo + / "policyengine_us_data" + / "storage" + / "calibration" 
+ / dirname + ) + model_dir.mkdir(parents=True) + (model_dir / "unified_run_config.json").write_text( + json.dumps( + { + "dataset": "source_imputed_stratified_extended_cps_2024.h5", + "db_path": "policy_data.db", + "n_clones": 430, + "epochs": epochs, + "n_targets": 2, + "n_records": 3_000_000, + "weight_sum": 153.0, + "weight_nonzero": 1000, + "mean_error_pct": mean_error, + } + ) + ) + (model_dir / "unified_diagnostics.csv").write_text( + "\n".join( + [ + "target,true_value,estimate,rel_error,abs_rel_error,achievable", + "a,100,95,-0.05,0.05,True", + "b,100,80,-0.20,0.20,True", + ] + ) + ) + + payload = build_dashboard_payload( + artifact_root=artifacts, + target_diagnostics_path=target_diagnostics, + policyengine_us_data_repo=pe_repo, + include_tmux=False, + ) + + assertions = payload["run_board"]["assertions"] + assert assertions["microplex_beats_legacy_ecps_latest_pe_broad"] is True + assert assertions["policyengine_small_l0_weight_package_available"] is True + assert assertions["policyengine_big_l0_weight_package_available"] is True + assert assertions["microplex_vs_small_l0_complete"] is False + assert assertions["microplex_vs_big_l0_complete"] is False + assert ( + assertions["microplex_vs_all_three_pe_models_on_both_metrics"] is False + ) + assert assertions["policyengine_materialized_l0_same_harness_available"] is True + assert assertions["apples_to_apples_groups_available"] is True + assert payload["run_board"]["score_runs"][0]["candidate_loss"] == 0.0252 + assert payload["run_board"]["local_target_screens"][0]["label"] == "cd_age_w8" + assert payload["run_board"]["local_target_screens"][0]["status"] == ( + "screen_scored_latest_pe" + ) + assert payload["run_board"]["local_target_screens"][0][ + "pe_native_broad_loss" + ] == 0.0263 + assert ( + payload["run_board"]["materialized_policyengine_l0_scores"][0][ + "candidate_loss" + ] + == 3.0 + ) + actual_l0_runs = payload["run_board"]["actual_l0_objective_runs"] + assert actual_l0_runs[0]["model_id"] == 
"microplex_actual_l0" + assert actual_l0_runs[0]["actual_l0_data_loss"] == pytest.approx(100 / (101**2)) + assert actual_l0_runs[0]["weights"]["nonzero"] == 2 + groups = {row["id"]: row for row in payload["run_board"]["apples_to_apples"]} + assert groups["latest_pe_broad"]["rows"][0]["score"] == 0.0977 + assert groups["legacy_broad"]["rows"][2]["model_id"] == ( + "policyengine_local_area_l0_state_stack" + ) + models = { + row["id"]: row for row in payload["run_board"]["policyengine_l0_models"] + } + assert models["policyengine_small_l0"]["epochs"] == 100 + assert ( + models["policyengine_big_l0"]["diagnostics"][ + "mean_abs_relative_error_pct" + ] + == 12.5 + ) + assert models["policyengine_big_l0"]["diagnostics"][ + "actual_l0_objective" + ] == "sum(((estimate - target) / (target + 1)) ** 2)" + assert models["policyengine_big_l0"]["diagnostics"][ + "actual_l0_data_loss" + ] == pytest.approx(425 / (101**2)) + + +def test_dashboard_payload_wires_materialized_pe_l0_score_jsons(tmp_path): + artifacts = tmp_path / "artifacts" + artifacts.mkdir() + latest_dir = artifacts / "latest_microplex" + legacy_dir = artifacts / "legacy_microplex" + latest_dir.mkdir() + legacy_dir.mkdir() + (latest_dir / "scores.json").write_text( + json.dumps( + [ + { + "metric": "pe_native_broad_loss", + "summary": { + "baseline_enhanced_cps_native_loss": 0.10, + "candidate_beats_baseline": True, + "candidate_enhanced_cps_native_loss": 0.03, + "n_targets_kept": 2805, + }, + "broad_loss": { + "candidate_dataset": "microplex_latest.h5", + "baseline_dataset": "enhanced_cps_2024.h5", + }, + } + ] + ) + ) + (legacy_dir / "scores.json").write_text( + json.dumps( + [ + { + "metric": "pe_native_broad_loss", + "summary": { + "baseline_enhanced_cps_native_loss": 0.17, + "candidate_beats_baseline": True, + "candidate_enhanced_cps_native_loss": 0.06, + "n_targets_kept": 2814, + }, + "broad_loss": { + "candidate_dataset": "microplex_legacy.h5", + "baseline_dataset": "enhanced_cps_2024.h5", + }, + } + ] + ) + 
) + score_dir = artifacts / "pe_l0_clone_apples_to_apples" + score_dir.mkdir() + for metric, targets, small_loss, big_loss in [ + ("legacy_targets", 2814, 0.15, 0.12), + ("new_targets", 2805, 0.09, 0.08), + ]: + for label, loss in [ + ("pe_small_l0", small_loss), + ("pe_big_l0", big_loss), + ]: + (score_dir / f"{metric}_{label}_score.json").write_text( + json.dumps( + { + "metric": "enhanced_cps_native_loss", + "candidate_dataset": f"/tmp/{label}.h5", + "baseline_dataset": "/tmp/enhanced_cps_2024.h5", + "baseline_enhanced_cps_native_loss": ( + 0.10 if metric == "new_targets" else 0.17 + ), + "candidate_beats_baseline": loss < ( + 0.10 if metric == "new_targets" else 0.17 + ), + "candidate_enhanced_cps_native_loss": loss, + "enhanced_cps_native_loss_delta": loss + - (0.10 if metric == "new_targets" else 0.17), + "n_targets_kept": targets, + "n_targets_total": targets + 10, + } + ) + ) + + pe_repo = tmp_path / "policyengine-us-data" + for dirname in ["local_net_worth_100", "local_net_worth_100_e300"]: + model_dir = ( + pe_repo + / "policyengine_us_data" + / "storage" + / "calibration" + / dirname + ) + model_dir.mkdir(parents=True) + (model_dir / "unified_run_config.json").write_text( + json.dumps({"n_targets": 2, "epochs": 100}) + ) + (model_dir / "unified_diagnostics.csv").write_text( + "\n".join( + [ + "target,true_value,estimate,rel_error,abs_rel_error,achievable", + "a,100,95,-0.05,0.05,True", + ] + ) + ) + + payload = build_dashboard_payload( + artifact_root=artifacts, + target_diagnostics_path=artifacts / "missing.json", + policyengine_us_data_repo=pe_repo, + include_tmux=False, + ) + + assertions = payload["run_board"]["assertions"] + assert assertions["microplex_vs_small_l0_complete"] is True + assert assertions["microplex_vs_big_l0_complete"] is True + assert assertions["microplex_vs_all_three_pe_models_on_both_metrics"] is True + groups = {row["id"]: row for row in payload["run_board"]["apples_to_apples"]} + latest_rows = { + row["model_id"]: row for row 
in groups["latest_pe_broad"]["rows"] + } + legacy_rows = { + row["model_id"]: row for row in groups["legacy_broad"]["rows"] + } + assert latest_rows["policyengine_small_l0"]["score"] == 0.09 + assert latest_rows["policyengine_big_l0"]["score"] == 0.08 + assert legacy_rows["policyengine_small_l0"]["score"] == 0.15 + assert legacy_rows["policyengine_big_l0"]["score"] == 0.12 + + +def test_dashboard_payload_reads_run_contract_summaries(tmp_path): + artifacts = tmp_path / "artifacts" + writer = RunContractWriter( + artifacts / "contracted_run", + run_id="contracted-run", + attempt_id="attempt-1", + ) + with writer.stage("preflight"): + pass + + payload = build_dashboard_payload( + artifact_root=artifacts, + policyengine_us_data_repo=None, + include_tmux=False, + ) + + contracts = payload["run_board"]["run_contracts"] + assert len(contracts) == 1 + assert contracts[0]["status_source"] == "contract" + assert contracts[0]["run_id"] == "contracted-run" + assert contracts[0]["status"] == "running" + assert contracts[0]["completed_stages"] == ["preflight"] diff --git a/tests/pipelines/test_pe_native_calibration_benchmark.py b/tests/pipelines/test_pe_native_calibration_benchmark.py new file mode 100644 index 0000000..622bbd6 --- /dev/null +++ b/tests/pipelines/test_pe_native_calibration_benchmark.py @@ -0,0 +1,210 @@ +"""Tests for PE-native calibration strategy benchmarking.""" + +from __future__ import annotations + +import shutil +from pathlib import Path + +import h5py +import numpy as np + +from microplex_us.pipelines.pe_native_calibration_benchmark import ( + build_policyengine_us_native_calibration_benchmark, + compute_household_weight_diagnostics, +) + + +def _write_dataset(path: Path, weights: list[float]) -> Path: + household_ids = np.arange(1, len(weights) + 1, dtype=np.int64) + with h5py.File(path, "w") as handle: + household_id = handle.create_group("household_id") + household_id.create_dataset("2024", data=household_ids) + household_weight = 
handle.create_group("household_weight") + household_weight.create_dataset( + "2024", + data=np.asarray(weights, dtype=np.float32), + ) + return path + + +def test_compute_household_weight_diagnostics_compares_reference_by_id( + tmp_path: Path, +) -> None: + candidate = _write_dataset(tmp_path / "candidate.h5", [3.0, 0.0, 9.0]) + reference = tmp_path / "reference.h5" + with h5py.File(reference, "w") as handle: + household_id = handle.create_group("household_id") + household_id.create_dataset("2024", data=np.asarray([3, 1, 2])) + household_weight = handle.create_group("household_weight") + household_weight.create_dataset( + "2024", + data=np.asarray([6.0, 2.0, 1.0], dtype=np.float32), + ) + + diagnostics = compute_household_weight_diagnostics( + candidate, + reference_dataset_path=reference, + ) + + assert diagnostics["household_count"] == 3 + assert diagnostics["positive_household_count"] == 2 + assert diagnostics["weight_sum"] == 12.0 + assert diagnostics["reference_alignment"] == "matched_by_household_id" + assert diagnostics["reference_weight_sum"] == 9.0 + assert diagnostics["weight_sum_delta"] == 3.0 + assert diagnostics["changed_household_count"] == 3 + assert np.isclose(diagnostics["effective_sample_size"], 1.6) + + +def test_build_policyengine_us_native_calibration_benchmark_scores_variants( + monkeypatch, + tmp_path: Path, +) -> None: + input_dataset = _write_dataset(tmp_path / "input.h5", [1.0, 1.0]) + baseline_dataset = _write_dataset(tmp_path / "baseline.h5", [2.0, 2.0]) + existing_dataset = _write_dataset(tmp_path / "current_weight_diff.h5", [1.2, 0.8]) + output_dir = tmp_path / "benchmark" + + def fake_extract(**_kwargs): + return { + "scaled_matrix": np.eye(2), + "scaled_target": np.asarray([1.0, 0.0]), + "initial_weights": np.asarray([1.0, 1.0]), + "metadata": { + "target_names": ["nation/fake", "state/fake"], + "skip_tax_expenditure_targets": True, + }, + } + + def fake_optimize_weights(**kwargs): + penalty = float(kwargs["l2_penalty"]) + weights = 
np.asarray([1.9, 0.1] if penalty == 0.0 else [1.4, 0.6]) + return weights, { + "initial_loss": 1.25, + "optimized_loss": 0.5 if penalty == 0.0 else 0.75, + "loss_delta": -0.75 if penalty == 0.0 else -0.5, + "initial_weight_sum": 2.0, + "optimized_weight_sum": float(weights.sum()), + "household_count": 2, + "positive_household_count": 2, + "budget": None, + "iterations": 3, + "converged": True, + } + + def fake_rewrite(**kwargs): + output_path = Path(kwargs["output_dataset_path"]) + shutil.copy2(kwargs["input_dataset_path"], output_path) + with h5py.File(output_path, "r+") as handle: + handle["household_weight"]["2024"][...] = np.asarray( + kwargs["household_weights"], + dtype=np.float32, + ) + return output_path.resolve() + + def fake_scores(**kwargs): + results = [] + for candidate_path in kwargs["candidate_dataset_paths"]: + path = Path(candidate_path).resolve() + if path.name == "input.h5": + loss = 1.0 + elif path.name == "current_weight_diff.h5": + loss = 0.8 + elif "unconstrained" in path.name: + loss = 0.4 + else: + loss = 0.6 + results.append( + { + "metric": "enhanced_cps_native_loss", + "period": 2024, + "summary": { + "candidate_enhanced_cps_native_loss": loss, + "baseline_enhanced_cps_native_loss": 0.5, + "enhanced_cps_native_loss_delta": loss - 0.5, + "candidate_beats_baseline": loss < 0.5, + "candidate_unweighted_msre": loss + 0.1, + "baseline_unweighted_msre": 0.7, + "unweighted_msre_delta": loss - 0.6, + "n_targets_total": 4, + "n_targets_kept": 3, + "n_targets_zero_dropped": 1, + "n_targets_bad_dropped": 0, + "n_national_targets": 1, + "n_state_targets": 2, + "skip_tax_expenditure_targets": True, + }, + "broad_loss": { + "metric": "enhanced_cps_native_loss", + "period": 2024, + "candidate_dataset": str(path), + "baseline_dataset": str(baseline_dataset.resolve()), + "candidate_enhanced_cps_native_loss": loss, + "baseline_enhanced_cps_native_loss": 0.5, + "enhanced_cps_native_loss_delta": loss - 0.5, + "candidate_beats_baseline": loss < 0.5, + 
"candidate_unweighted_msre": loss + 0.1, + "baseline_unweighted_msre": 0.7, + "unweighted_msre_delta": loss - 0.6, + "n_targets_total": 4, + "n_targets_kept": 3, + "n_targets_zero_dropped": 1, + "n_targets_bad_dropped": 0, + "n_national_targets": 1, + "n_state_targets": 2, + "candidate_weight_sum": 2.0, + "baseline_weight_sum": 4.0, + "skip_tax_expenditure_targets": True, + "family_breakdown": [], + }, + "family_breakdown": [], + } + ) + return results + + monkeypatch.setattr( + "microplex_us.pipelines.pe_native_calibration_benchmark." + "_extract_pe_native_loss_inputs", + fake_extract, + ) + monkeypatch.setattr( + "microplex_us.pipelines.pe_native_calibration_benchmark." + "optimize_pe_native_loss_weights", + fake_optimize_weights, + ) + monkeypatch.setattr( + "microplex_us.pipelines.pe_native_calibration_benchmark." + "rewrite_policyengine_us_dataset_weights", + fake_rewrite, + ) + monkeypatch.setattr( + "microplex_us.pipelines.pe_native_calibration_benchmark." + "compute_batch_us_pe_native_scores", + fake_scores, + ) + + payload = build_policyengine_us_native_calibration_benchmark( + input_dataset_path=input_dataset, + baseline_dataset_path=baseline_dataset, + output_dir=output_dir, + l2_penalties=(0.0, 1e-8), + max_iter=5, + target_total_weight_source="baseline", + existing_candidates={"current_weight_diff": existing_dataset}, + skip_tax_expenditure_targets=True, + ) + + assert payload["variant_count"] == 4 + assert payload["target_total_weight"] == 4.0 + assert payload["target_total_weight_resolved_from"] == "baseline" + assert payload["best_variant_label"] == "pe_native_unconstrained_baseline_total" + assert [row["label"] for row in payload["ranking"][:2]] == [ + "pe_native_unconstrained_baseline_total", + "pe_native_l2_1e-08_baseline_total", + ] + unconstrained = next( + row for row in payload["rows"] if row["label"].startswith("pe_native_unconstrained") + ) + assert unconstrained["optimization"]["l2_penalty"] == 0.0 + assert 
unconstrained["weight_diagnostics"]["reference_alignment"] == "same_order" + assert unconstrained["weight_diagnostics"]["changed_household_count"] == 2 diff --git a/tests/pipelines/test_pe_us_dataset_readiness.py b/tests/pipelines/test_pe_us_dataset_readiness.py new file mode 100644 index 0000000..8943678 --- /dev/null +++ b/tests/pipelines/test_pe_us_dataset_readiness.py @@ -0,0 +1,147 @@ +"""Tests for lightweight PE-US H5 readiness audits.""" + +from __future__ import annotations + +import json + +import h5py +import numpy as np + +from microplex_us.pipelines.pe_us_dataset_readiness import ( + DEFAULT_EXPECTED_MATERIALIZED_VARIABLES, + build_policyengine_us_dataset_readiness_audit, + write_policyengine_us_dataset_readiness_audit, +) + + +def test_build_policyengine_us_dataset_readiness_audit_passes_complete_artifact( + tmp_path, +): + artifact_dir = tmp_path / "run" + artifact_dir.mkdir() + dataset_path = artifact_dir / "policyengine_us.h5" + _write_dataset(dataset_path) + (artifact_dir / "manifest.json").write_text( + json.dumps( + { + "rows": {"calibrated": 2}, + "weights": {"total": 3.0}, + "artifacts": { + "policyengine_dataset": "policyengine_us.h5", + "source_spine_composition": "source_spine_composition.json", + }, + } + ) + ) + (artifact_dir / "source_spine_composition.json").write_text( + json.dumps( + { + "household_count": 2, + "nonzero_household_count": 2, + "total_active_weight": 3.0, + "effective_sample_size": 1.8, + "groups": [ + { + "spine": "cps_asec", + "household_count": 1, + "nonzero_household_count": 1, + "total_active_weight": 2.0, + "total_source_weight": 2.0, + }, + { + "spine": "acs_pums", + "household_count": 1, + "nonzero_household_count": 1, + "total_active_weight": 1.0, + "total_source_weight": 5.0, + }, + ], + } + ) + ) + + audit = build_policyengine_us_dataset_readiness_audit(artifact_dir, period=2024) + + assert audit["valid"] is True + assert audit["entityCounts"] == { + "household": 2, + "person": 3, + "tax_unit": 2, + 
"spm_unit": 2, + } + assert audit["variableSummaries"]["state_fips"]["entity"] == "household" + assert audit["variableSummaries"]["spm_unit_spm_threshold"]["positiveShare"] == 1.0 + assert audit["sourceSpineComposition"]["groups"][1]["spine"] == "acs_pums" + assert audit["issues"] == [] + + +def test_build_policyengine_us_dataset_readiness_audit_reports_missing_outputs( + tmp_path, +): + dataset_path = tmp_path / "policyengine_us.h5" + _write_dataset(dataset_path, omit=("snap", "county_fips")) + + audit = build_policyengine_us_dataset_readiness_audit( + dataset_path, + expected_spines=(), + ) + issues_by_variable = { + issue.get("variable"): issue for issue in audit["issues"] if issue.get("variable") + } + + assert audit["valid"] is False + assert issues_by_variable["county_fips"]["severity"] == "error" + assert issues_by_variable["snap"]["severity"] == "error" + + +def test_write_policyengine_us_dataset_readiness_audit_writes_sidecar(tmp_path): + dataset_path = tmp_path / "policyengine_us.h5" + _write_dataset(dataset_path) + + output_path = write_policyengine_us_dataset_readiness_audit( + dataset_path, + expected_spines=(), + ) + + assert output_path == tmp_path / "policyengine_us_readiness.json" + payload = json.loads(output_path.read_text()) + assert payload["valid"] is True + assert payload["expectedMaterializedVariables"] == list( + DEFAULT_EXPECTED_MATERIALIZED_VARIABLES + ) + + +def _write_dataset(path, *, omit=()): + omit = set(omit) + arrays = { + "household_id": np.array([1, 2]), + "household_weight": np.array([2.0, 1.0]), + "person_id": np.array([10, 11, 20]), + "person_household_id": np.array([1, 1, 2]), + "tax_unit_id": np.array([100, 200]), + "person_tax_unit_id": np.array([100, 100, 200]), + "spm_unit_id": np.array([500, 600]), + "person_spm_unit_id": np.array([500, 500, 600]), + "state_fips": np.array([6, 36]), + "county_fips": np.array([b"06001", b"36061"]), + "congressional_district_geoid": np.array([605, 3610]), + "spm_unit_spm_threshold": 
np.array([30_000.0, 36_000.0]), + "spm_unit_tenure_type": np.array([b"OWN_WITH_MORTGAGE", b"RENT"]), + "income_tax": np.array([100.0, 200.0]), + "income_tax_positive": np.array([100.0, 200.0]), + "eitc": np.array([0.0, 50.0]), + "ctc": np.array([1_000.0, 0.0]), + "refundable_ctc": np.array([400.0, 0.0]), + "non_refundable_ctc": np.array([600.0, 0.0]), + "snap": np.array([10.0, 0.0]), + "ssi": np.array([0.0, 100.0, 0.0]), + "tanf": np.array([0.0, 0.0]), + "medicaid": np.array([1.0, 0.0, 1.0]), + "aca_ptc": np.array([0.0, 75.0]), + } + with h5py.File(path, "w") as handle: + for variable, values in arrays.items(): + if variable in omit: + continue + group = handle.create_group(variable) + group.create_dataset("2024", data=values) diff --git a/tests/pipelines/test_r2_artifacts.py b/tests/pipelines/test_r2_artifacts.py new file mode 100644 index 0000000..5d3f779 --- /dev/null +++ b/tests/pipelines/test_r2_artifacts.py @@ -0,0 +1,183 @@ +"""Tests for R2 artifact archiving.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from microplex_us.pipelines.r2_artifacts import ( + R2_ARCHIVE_MANIFEST_FILENAME, + R2ArchiveConfig, + append_archive_index_entry, + build_archive_manifest, + build_r2_object_key, + upload_artifact_manifest_to_r2, +) + + +class MissingObjectError(Exception): + def __init__(self) -> None: + self.response = {"Error": {"Code": "404"}} + super().__init__("missing") + + +class FakeS3Client: + def __init__(self, *, existing_keys: set[str] | None = None) -> None: + self.existing_keys = existing_keys or set() + self.head_calls: list[tuple[str, str]] = [] + self.upload_calls: list[tuple[str, str, str]] = [] + + def head_object(self, *, Bucket: str, Key: str) -> dict[str, object]: + self.head_calls.append((Bucket, Key)) + if Key not in self.existing_keys: + raise MissingObjectError() + return {} + + def upload_file(self, filename: str, bucket: str, key: str) -> None: + self.upload_calls.append((filename, bucket, 
key)) + self.existing_keys.add(key) + + +def test_build_r2_object_key_normalizes_prefix() -> None: + assert ( + build_r2_object_key("/microplex-us/artifacts/", "run-a", "scores.json") + == "microplex-us/artifacts/run-a/scores.json" + ) + + +def test_build_archive_manifest_hashes_files_and_excludes_r2_sidecar( + tmp_path: Path, +) -> None: + artifact_dir = tmp_path / "run-a" + artifact_dir.mkdir() + (artifact_dir / "scores.json").write_text('{"loss": 0.1}\n') + (artifact_dir / "data").mkdir() + (artifact_dir / "data" / "weights.npy").write_bytes(b"weights") + (artifact_dir / R2_ARCHIVE_MANIFEST_FILENAME).write_text("{}") + config = R2ArchiveConfig( + bucket="microplex-artifacts", + endpoint_url="https://example.r2.cloudflarestorage.com", + prefix="experiments", + ) + + manifest = build_archive_manifest(artifact_dir, config) + + assert manifest["artifact_id"] == "run-a" + assert manifest["file_count"] == 2 + files = {entry["path"]: entry for entry in manifest["files"]} + assert files["scores.json"]["summary"] is True + assert files["scores.json"]["object_key"] == "experiments/run-a/scores.json" + assert len(files["scores.json"]["sha256"]) == 64 + assert "r2_archive_manifest.json" not in files + assert files["data/weights.npy"]["summary"] is False + + +def test_upload_artifact_manifest_to_r2_uploads_files_and_sidecar( + tmp_path: Path, +) -> None: + artifact_dir = tmp_path / "run-a" + artifact_dir.mkdir() + (artifact_dir / "scores.json").write_text('{"loss": 0.1}\n') + (artifact_dir / "summary.md").write_text("# Run\n") + config = R2ArchiveConfig( + bucket="microplex-artifacts", + endpoint_url="https://example.r2.cloudflarestorage.com", + prefix="experiments", + ) + client = FakeS3Client() + + manifest = upload_artifact_manifest_to_r2( + artifact_dir, + config, + client=client, + hash_files=False, + ) + + assert manifest["status"] == "uploaded" + assert {entry["status"] for entry in manifest["files"]} == {"uploaded"} + uploaded_keys = [key for _, _, key in 
client.upload_calls] + assert "experiments/run-a/scores.json" in uploaded_keys + assert "experiments/run-a/summary.md" in uploaded_keys + assert "experiments/run-a/r2_archive_manifest.json" in uploaded_keys + local_manifest = json.loads( + (artifact_dir / R2_ARCHIVE_MANIFEST_FILENAME).read_text() + ) + assert local_manifest["r2"]["bucket"] == "microplex-artifacts" + + +def test_upload_artifact_manifest_to_r2_skips_existing_objects( + tmp_path: Path, +) -> None: + artifact_dir = tmp_path / "run-a" + artifact_dir.mkdir() + (artifact_dir / "scores.json").write_text('{"loss": 0.1}\n') + config = R2ArchiveConfig( + bucket="microplex-artifacts", + endpoint_url="https://example.r2.cloudflarestorage.com", + prefix="experiments", + ) + client = FakeS3Client(existing_keys={"experiments/run-a/scores.json"}) + + manifest = upload_artifact_manifest_to_r2( + artifact_dir, + config, + client=client, + hash_files=False, + ) + + assert manifest["files"][0]["status"] == "already_exists" + uploaded_keys = [key for _, _, key in client.upload_calls] + assert uploaded_keys == ["experiments/run-a/r2_archive_manifest.json"] + + +def test_append_archive_index_entry_records_compact_upload( + tmp_path: Path, +) -> None: + artifact_dir = tmp_path / "run-a" + artifact_dir.mkdir() + (artifact_dir / "scores.json").write_text('{"loss": 0.1}\n') + config = R2ArchiveConfig( + bucket="microplex-artifacts", + endpoint_url="https://example.r2.cloudflarestorage.com", + prefix="experiments", + ) + manifest = build_archive_manifest(artifact_dir, config, hash_files=False) + + index_path = append_archive_index_entry( + tmp_path / "r2_archive_index.jsonl", + manifest, + pruned_local=True, + ) + + rows = [json.loads(line) for line in index_path.read_text().splitlines()] + assert rows == [ + { + "recorded_at": rows[0]["recorded_at"], + "artifact_id": "run-a", + "artifact_dir": str(artifact_dir.resolve()), + "bucket": "microplex-artifacts", + "prefix": "experiments", + "manifest_object_key": 
"experiments/run-a/r2_archive_manifest.json", + "file_count": 1, + "total_bytes": 14, + "status": None, + "pruned_local": True, + } + ] + + +def test_r2_archive_config_from_env_uses_account_endpoint( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("MICROPLEX_R2_BUCKET", "microplex-artifacts") + monkeypatch.setenv("CLOUDFLARE_ACCOUNT_ID", "abc123") + monkeypatch.setenv("R2_ACCESS_KEY_ID", "key") + monkeypatch.setenv("R2_SECRET_ACCESS_KEY", "secret") + + config = R2ArchiveConfig.from_env() + + assert config.endpoint_url == "https://abc123.r2.cloudflarestorage.com" + assert config.access_key_id == "key" + assert config.secret_access_key == "secret" diff --git a/tests/policyengine/test_target_profiles.py b/tests/policyengine/test_target_profiles.py index 1a54868..0aeaa46 100644 --- a/tests/policyengine/test_target_profiles.py +++ b/tests/policyengine/test_target_profiles.py @@ -1,6 +1,7 @@ from __future__ import annotations from microplex_us.policyengine.target_profiles import ( + policyengine_us_target_profile_exclusion_reasons, policyengine_us_target_profile_names, resolve_policyengine_us_target_profile, ) @@ -9,6 +10,250 @@ def test_policyengine_us_target_profile_names_include_no_state_aca_variant() -> None: assert "pe_native_broad" in policyengine_us_target_profile_names() assert "pe_native_broad_no_state_aca" in policyengine_us_target_profile_names() + assert "pe_native_broad_source_backed" in policyengine_us_target_profile_names() + + +def test_broad_profile_includes_soi_employment_income_cells() -> None: + broad = resolve_policyengine_us_target_profile("pe_native_broad") + broad_cells = { + (cell.variable, cell.geo_level, cell.domain_variable, cell.geographic_id) + for cell in broad + } + + assert ( + "employment_income", + "national", + "employment_income", + None, + ) in broad_cells + assert ( + "tax_unit_count", + "national", + "employment_income", + None, + ) in broad_cells + assert ( + "employment_income", + "state", + 
"employment_income", + None, + ) in broad_cells + assert ( + "tax_unit_count", + "state", + "employment_income", + None, + ) in broad_cells + + +def test_broad_profile_includes_bea_full_population_amount_cells() -> None: + broad = resolve_policyengine_us_target_profile("pe_native_broad") + broad_cells = { + (cell.variable, cell.geo_level, cell.domain_variable, cell.geographic_id) + for cell in broad + } + + assert ( + "dividend_income", + "national", + None, + None, + ) in broad_cells + assert ( + "employment_income", + "national", + None, + None, + ) in broad_cells + assert ( + "rental_income", + "national", + None, + None, + ) in broad_cells + assert ( + "self_employment_income", + "national", + None, + None, + ) in broad_cells + assert ( + "employment_income", + "state", + None, + None, + ) in broad_cells + assert ( + "self_employment_income", + "state", + None, + None, + ) in broad_cells + + +def test_broad_profile_covers_current_policyengine_target_db_cells() -> None: + broad = resolve_policyengine_us_target_profile("pe_native_broad") + broad_cells = { + (cell.variable, cell.geo_level, cell.domain_variable, cell.geographic_id) + for cell in broad + } + + added_policyengine_cells = { + ("aca_ptc", "national", "aca_ptc", None), + ("adjusted_gross_income", "national", "adjusted_gross_income", None), + ( + "adjusted_gross_income", + "national", + "adjusted_gross_income,filing_status,income_tax_before_credits", + None, + ), + ( + "adjusted_gross_income", + "national", + "adjusted_gross_income,income_tax_before_credits", + None, + ), + ("childcare_expenses", "national", None, None), + ("deductible_mortgage_interest", "national", None, None), + ("household_count", "national", "spm_unit_energy_subsidy_reported", None), + ( + "medical_expense_deduction", + "national", + "medical_expense_deduction,tax_unit_itemizes", + None, + ), + ( + "non_refundable_ctc", + "national", + "adjusted_gross_income,non_refundable_ctc", + None, + ), + ("non_refundable_ctc", "national", 
"non_refundable_ctc", None), + ( + "real_estate_taxes", + "national", + "real_estate_taxes,tax_unit_itemizes", + None, + ), + ( + "refundable_ctc", + "national", + "adjusted_gross_income,refundable_ctc", + None, + ), + ("roth_401k_contributions", "national", None, None), + ("salt", "national", "salt,tax_unit_itemizes", None), + ("self_employed_pension_contribution_ald", "national", None, None), + ("spm_unit_count", "national", "tanf", None), + ("tanf", "national", "tanf", None), + ("tax_unit_count", "national", "adjusted_gross_income", None), + ( + "tax_unit_count", + "national", + "adjusted_gross_income,filing_status,income_tax_before_credits", + None, + ), + ( + "tax_unit_count", + "national", + "adjusted_gross_income,income_tax_before_credits", + None, + ), + ( + "tax_unit_count", + "national", + "adjusted_gross_income,non_refundable_ctc", + None, + ), + ( + "tax_unit_count", + "national", + "adjusted_gross_income,refundable_ctc", + None, + ), + ( + "tax_unit_count", + "national", + "medical_expense_deduction,tax_unit_itemizes", + None, + ), + ("tax_unit_count", "national", "non_refundable_ctc", None), + ( + "tax_unit_count", + "national", + "real_estate_taxes,tax_unit_itemizes", + None, + ), + ("tax_unit_count", "national", "salt,tax_unit_itemizes", None), + ("tax_unit_count", "national", "total_self_employment_income", None), + ( + "total_self_employment_income", + "national", + "total_self_employment_income", + None, + ), + ("traditional_401k_contributions", "national", None, None), + ("aca_ptc", "state", "aca_ptc", None), + ("adjusted_gross_income", "state", "adjusted_gross_income", None), + ( + "medical_expense_deduction", + "state", + "medical_expense_deduction,tax_unit_itemizes", + None, + ), + ("non_refundable_ctc", "state", "non_refundable_ctc", None), + ("person_count", "state", "aca_ptc,is_aca_ptc_eligible", None), + ("person_count", "state", "is_pregnant", None), + ( + "real_estate_taxes", + "state", + "real_estate_taxes,tax_unit_itemizes", + None, + 
), + ("salt", "state", "salt,tax_unit_itemizes", None), + ("spm_unit_count", "state", "tanf", None), + ("tanf", "state", "tanf", None), + ( + "tax_unit_count", + "state", + "medical_expense_deduction,tax_unit_itemizes", + None, + ), + ("tax_unit_count", "state", "non_refundable_ctc", None), + ( + "tax_unit_count", + "state", + "real_estate_taxes,tax_unit_itemizes", + None, + ), + ("tax_unit_count", "state", "salt,tax_unit_itemizes", None), + ( + "tax_unit_count", + "state", + "selected_marketplace_plan_benchmark_ratio,used_aca_ptc", + None, + ), + ("tax_unit_count", "state", "total_self_employment_income", None), + ("tax_unit_count", "state", "used_aca_ptc", None), + ( + "total_self_employment_income", + "state", + "total_self_employment_income", + None, + ), + } + + assert added_policyengine_cells <= broad_cells + + +def test_broad_profile_has_no_duplicate_cells() -> None: + broad = resolve_policyengine_us_target_profile("pe_native_broad") + broad_cells = [ + (cell.variable, cell.geo_level, cell.domain_variable, cell.geographic_id) + for cell in broad + ] + + assert len(broad_cells) == len(set(broad_cells)) def test_no_state_aca_profile_excludes_only_state_aca_cells() -> None: @@ -29,7 +274,7 @@ def test_no_state_aca_profile_excludes_only_state_aca_cells() -> None: assert ( "aca_ptc", "state", - "aca_ptc", + None, None, ) in broad_cells assert ( @@ -40,10 +285,22 @@ def test_no_state_aca_profile_excludes_only_state_aca_cells() -> None: ) in broad_cells assert ( "aca_ptc", - "national", + "state", "aca_ptc", None, - ) in no_state_aca_cells + ) in broad_cells + assert ( + "person_count", + "state", + "aca_ptc,is_aca_ptc_eligible", + None, + ) in broad_cells + assert ( + "tax_unit_count", + "state", + "used_aca_ptc", + None, + ) in broad_cells assert ( "tax_unit_count", "national", @@ -53,7 +310,7 @@ def test_no_state_aca_profile_excludes_only_state_aca_cells() -> None: assert ( "aca_ptc", "state", - "aca_ptc", + None, None, ) not in no_state_aca_cells assert ( @@ 
-62,4 +319,89 @@ def test_no_state_aca_profile_excludes_only_state_aca_cells() -> None: "aca_ptc", None, ) not in no_state_aca_cells + assert ( + "aca_ptc", + "state", + "aca_ptc", + None, + ) not in no_state_aca_cells + assert ( + "person_count", + "state", + "aca_ptc", + None, + ) not in no_state_aca_cells + assert ( + "person_count", + "state", + "aca_ptc,is_aca_ptc_eligible", + None, + ) not in no_state_aca_cells + assert ( + "tax_unit_count", + "state", + "selected_marketplace_plan_benchmark_ratio,used_aca_ptc", + None, + ) not in no_state_aca_cells + assert ( + "tax_unit_count", + "state", + "used_aca_ptc", + None, + ) not in no_state_aca_cells + + +def test_source_backed_profile_excludes_only_documented_non_source_cells() -> None: + broad = resolve_policyengine_us_target_profile("pe_native_broad") + source_backed = resolve_policyengine_us_target_profile( + "pe_native_broad_source_backed" + ) + exclusion_reasons = policyengine_us_target_profile_exclusion_reasons( + "pe_native_broad_source_backed" + ) + + broad_cells = { + (cell.variable, cell.geo_level, cell.domain_variable, cell.geographic_id) + for cell in broad + } + source_backed_cells = { + (cell.variable, cell.geo_level, cell.domain_variable, cell.geographic_id) + for cell in source_backed + } + assert len(broad_cells) == 189 + assert len(exclusion_reasons) == 15 + assert all(reason for reason in exclusion_reasons.values()) + assert set(exclusion_reasons) <= broad_cells + assert len(source_backed_cells) == 174 + assert source_backed_cells == broad_cells - set(exclusion_reasons) + assert ( + "childcare_expenses", + "national", + None, + None, + ) not in source_backed_cells + assert ( + "person_count", + "state", + "is_pregnant", + None, + ) not in source_backed_cells + assert ( + "employment_income", + "national", + None, + None, + ) in source_backed_cells + assert ( + "medicare_part_b_premiums", + "national", + None, + None, + ) in source_backed_cells + assert ( + "net_worth", + "national", + None, + 
None, + ) in source_backed_cells diff --git a/tests/targets/test_aca_ptc.py b/tests/targets/test_aca_ptc.py new file mode 100644 index 0000000..1927013 --- /dev/null +++ b/tests/targets/test_aca_ptc.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +import csv +import json +from pathlib import Path +from typing import Any + +import pytest + +from microplex_us.targets import ( + ACA_AVERAGE_MONTHLY_APTC_CONCEPT, + ACA_MARKETPLACE_EFFECTUATED_ENROLLMENT_CONCEPT, + ACAPTCMultiplierInput, + aca_ptc_multiplier_inputs_from_arch_consumer_facts, + build_aca_ptc_multiplier_rows, + load_arch_consumer_fact_jsonl_rows, + write_policyengine_aca_ptc_multiplier_csv, +) +from microplex_us.targets.aca_ptc import main + + +def test_build_aca_ptc_multiplier_rows_matches_policyengine_formula() -> None: + rows = build_aca_ptc_multiplier_rows( + [ + ACAPTCMultiplierInput( + state="California", + enroll_base=1_701_375, + enroll_target=1_795_695, + aptc_base=459, + aptc_target=526, + ) + ] + ) + + row = rows[0] + assert row.vol_mult == pytest.approx(1_795_695 / 1_701_375) + assert row.val_mult == pytest.approx(526 / 459) + assert row.amount_mult == pytest.approx((1_795_695 / 1_701_375) * (526 / 459)) + assert row.target_factors() == { + "tax_unit_count": pytest.approx(1_795_695 / 1_701_375), + "aca_ptc": pytest.approx((1_795_695 / 1_701_375) * (526 / 459)), + } + + +def test_arch_consumer_fact_inputs_use_oep_with_effectuated_fallback() -> None: + facts = [ + _enrollment_fact("California", "ca", 2022, 1_701_375), + _enrollment_fact("California", "ca", 2024, 1_795_695), + _oep_aptc_fact("California", "ca", 2022, 459), + _effectuated_aptc_fact("California", "ca", 2022, 469.44), + _oep_aptc_fact("California", "ca", 2024, 526), + _enrollment_fact("Nevada", "nv", 2022, 90_397), + _enrollment_fact("Nevada", "nv", 2024, 92_949), + _effectuated_aptc_fact("Nevada", "nv", 2022, 429.75), + _oep_aptc_fact("Nevada", "nv", 2024, 438), + ] + + inputs = 
aca_ptc_multiplier_inputs_from_arch_consumer_facts(facts) + + by_state = {item.state: item for item in inputs} + assert by_state["California"].aptc_base == 459 + assert by_state["California"].aptc_base_source_kind == "oep" + assert by_state["Nevada"].aptc_base == 429.75 + assert by_state["Nevada"].aptc_base_source_kind == "effectuated" + + rows = build_aca_ptc_multiplier_rows(inputs) + nevada = {row.state: row for row in rows}["Nevada"] + assert nevada.vol_mult == pytest.approx(92_949 / 90_397) + assert nevada.val_mult == pytest.approx(438 / 429.75) + assert nevada.val_mult != pytest.approx(438 / 435) + + +def test_arch_consumer_fact_inputs_can_require_oep_base_aptc() -> None: + facts = [ + _enrollment_fact("Nevada", "nv", 2022, 90_397), + _enrollment_fact("Nevada", "nv", 2024, 92_949), + _effectuated_aptc_fact("Nevada", "nv", 2022, 429.75), + _oep_aptc_fact("Nevada", "nv", 2024, 438), + ] + + with pytest.raises(ValueError, match="Nevada 2022 average APTC"): + aca_ptc_multiplier_inputs_from_arch_consumer_facts( + facts, + base_aptc_policy="oep", + ) + + +def test_write_policyengine_aca_ptc_multiplier_csv(tmp_path: Path) -> None: + rows = build_aca_ptc_multiplier_rows( + [ + ACAPTCMultiplierInput( + state="Nevada", + enroll_base=90_397, + enroll_target=92_949, + aptc_base=429.75, + aptc_target=438, + ) + ] + ) + path = tmp_path / "aca_ptc_multipliers_2022_2024.csv" + + write_policyengine_aca_ptc_multiplier_csv(rows, path) + + with path.open() as file: + records = list(csv.DictReader(file)) + assert records[0]["state"] == "Nevada" + assert records[0]["enroll_2022"] == "90397" + assert records[0]["aptc_2024"] == "438" + assert float(records[0]["enroll_2022"]) == 90_397 + assert float(records[0]["aptc_2022"]) == 429.75 + assert float(records[0]["vol_mult"]) == pytest.approx(92_949 / 90_397) + assert float(records[0]["val_mult"]) == pytest.approx(438 / 429.75) + + +def test_main_builds_policyengine_csv_from_consumer_fact_jsonl( + tmp_path: Path, + capsys: 
pytest.CaptureFixture[str], +) -> None: + consumer_facts = tmp_path / "consumer_facts.jsonl" + _write_jsonl( + consumer_facts, + [ + _enrollment_fact("Nevada", "nv", 2022, 90_397), + _enrollment_fact("Nevada", "nv", 2024, 92_949), + _effectuated_aptc_fact("Nevada", "nv", 2022, 429.75), + _oep_aptc_fact("Nevada", "nv", 2024, 438), + ], + ) + out = tmp_path / "aca_ptc_multipliers_2022_2024.csv" + + assert main([str(consumer_facts), "--out", str(out)]) == 0 + + captured = capsys.readouterr() + assert f"Wrote 1 ACA PTC multiplier rows to {out}" in captured.out + rows = list(csv.DictReader(out.open())) + assert rows[0]["state"] == "Nevada" + assert rows[0]["aptc_2022"] == "429.75" + + +def test_load_arch_consumer_fact_jsonl_rows_rejects_non_consumer_rows( + tmp_path: Path, +) -> None: + path = tmp_path / "facts.jsonl" + path.write_text(json.dumps({"schema_version": "arch.fact.v1"}) + "\n") + + with pytest.raises(ValueError, match="Unsupported Arch consumer fact schema"): + load_arch_consumer_fact_jsonl_rows([path]) + + +def _enrollment_fact( + state: str, + state_abbr: str, + year: int, + value: float, +) -> dict[str, Any]: + return _fact( + state=state, + period=year, + value=value, + concept=ACA_MARKETPLACE_EFFECTUATED_ENROLLMENT_CONCEPT, + source_record_id=( + f"kff.marketplace_effectuated_enrollment.{year}.state." 
+ f"{state_abbr}.total_effectuated_marketplace_enrollment" + ), + ) + + +def _oep_aptc_fact( + state: str, + state_abbr: str, + year: int, + value: float, +) -> dict[str, Any]: + return _fact( + state=state, + period=year, + value=value, + concept=ACA_AVERAGE_MONTHLY_APTC_CONCEPT, + source_record_id=( + f"cms_aca.oep{year}.state_marketplace.{state_abbr}.average_monthly_aptc" + ), + ) + + +def _effectuated_aptc_fact( + state: str, + state_abbr: str, + year: int, + value: float, +) -> dict[str, Any]: + return _fact( + state=state, + period=year, + value=value, + concept=ACA_AVERAGE_MONTHLY_APTC_CONCEPT, + source_record_id=( + f"cms_aca.effectuated_enrollment.{year}.state_marketplace." + f"{state_abbr}.average_monthly_aptc" + ), + ) + + +def _fact( + *, + state: str, + period: int, + value: float, + concept: str, + source_record_id: str, +) -> dict[str, Any]: + return { + "schema_version": "arch.consumer_fact.v1", + "period": {"type": "calendar_year", "value": period}, + "geography": {"level": "state", "name": state}, + "observed_measure": {"source_concept": concept}, + "lineage": {"source_record_id": source_record_id}, + "value": value, + } + + +def _write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None: + path.write_text("\n".join(json.dumps(row, sort_keys=True) for row in rows) + "\n") diff --git a/tests/targets/test_arch.py b/tests/targets/test_arch.py new file mode 100644 index 0000000..5bbce96 --- /dev/null +++ b/tests/targets/test_arch.py @@ -0,0 +1,3976 @@ +from __future__ import annotations + +import json +import sqlite3 +from pathlib import Path + +import pytest +from microplex.core import EntityType +from microplex.targets import TargetAggregation, TargetQuery + +from microplex_us.geography import US_STATE_ABBR_BY_FIPS +from microplex_us.pipelines.us import USMicroplexBuildConfig, USMicroplexPipeline +from microplex_us.policyengine.target_profiles import PolicyEngineUSTargetCell +from microplex_us.targets import ( + 
ArchConsumerFactJSONLTargetProvider, + ArchSQLiteTargetProvider, + summarize_arch_target_gap_queue, + summarize_arch_target_profile_coverage, +) +from microplex_us.targets.arch import ( + ArchTargetRecord, + arch_target_record_to_canonical_spec, + main_gaps, + main_refresh, +) + + +def _create_arch_targets_db(path: Path) -> None: + conn = sqlite3.connect(path) + conn.executescript( + """ + CREATE TABLE strata ( + id INTEGER PRIMARY KEY, + name TEXT, + jurisdiction TEXT, + parent_id INTEGER, + definition_hash TEXT + ); + + CREATE TABLE stratum_constraints ( + id INTEGER PRIMARY KEY, + stratum_id INTEGER NOT NULL, + variable TEXT NOT NULL, + operator TEXT NOT NULL, + value TEXT NOT NULL + ); + + CREATE TABLE targets ( + id INTEGER PRIMARY KEY, + stratum_id INTEGER NOT NULL, + variable TEXT NOT NULL, + period INTEGER NOT NULL, + value REAL NOT NULL, + target_type TEXT NOT NULL, + geographic_level TEXT, + source TEXT NOT NULL, + source_table TEXT, + source_url TEXT, + notes TEXT + ); + """ + ) + conn.executemany( + """ + INSERT INTO strata (id, name, jurisdiction, definition_hash) + VALUES (?, ?, ?, ?) + """, + [ + (1, "US", "US", "root"), + (2, "US All Filers", "US", "filers"), + (3, "CA Filers AGI $50k-$75k", "US", "ca_50k_75k"), + ], + ) + conn.executemany( + """ + INSERT INTO stratum_constraints ( + stratum_id, + variable, + operator, + value + ) VALUES (?, ?, ?, ?) + """, + [ + (2, "is_tax_filer", "==", "1"), + (3, "is_tax_filer", "==", "1"), + (3, "state_fips", "==", "06"), + (3, "agi_bracket", "==", "50k_to_75k"), + ], + ) + conn.executemany( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + [ + ( + 1, + 2, + "tax_exempt_interest_returns", + 2023, + 10.0, + "COUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ( + 2, + 2, + "tax_exempt_interest_amount", + 2023, + 100.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ( + 3, + 2, + "adjusted_gross_income", + 2022, + 1_000.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ( + 4, + 2, + "adjusted_gross_income", + 2023, + 1_100.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ( + 5, + 1, + "labor_force_count", + 2023, + 100.0, + "COUNT", + None, + "BLS", + "BLS", + None, + None, + ), + (6, 1, "labor_force", 2024, 110.0, "COUNT", None, "CBO", "CBO", None, None), + ( + 7, + 3, + "tax_unit_count", + 2023, + 20.0, + "COUNT", + "STATE", + "IRS_SOI", + "SOI", + None, + None, + ), + ], + ) + conn.commit() + conn.close() + + +def _insert_multi_domain_soi_targets(path: Path) -> None: + conn = sqlite3.connect(path) + conn.execute( + """ + INSERT INTO strata (id, name, jurisdiction, definition_hash) + VALUES (?, ?, ?, ?) + """, + (4, "US Filers AGI $50k-$75k", "US", "national_50k_75k"), + ) + conn.executemany( + """ + INSERT INTO stratum_constraints ( + stratum_id, + variable, + operator, + value + ) VALUES (?, ?, ?, ?) + """, + [ + (4, "is_tax_filer", "==", "1"), + (4, "agi_bracket", "==", "50k_to_75k"), + ], + ) + conn.executemany( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + [ + ( + 8, + 4, + "tax_exempt_interest_returns", + 2023, + 5.0, + "COUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ( + 9, + 4, + "adjusted_gross_income", + 2023, + 500.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ], + ) + conn.commit() + conn.close() + + +def _insert_w2_tip_income_target(path: Path) -> None: + conn = sqlite3.connect(path) + conn.execute( + """ + INSERT INTO strata (id, name, jurisdiction, definition_hash) + VALUES (?, ?, ?, ?) + """, + (5, "US taxpayers with Form W-2 social security tips", "US", "w2_tips"), + ) + conn.execute( + """ + INSERT INTO stratum_constraints ( + stratum_id, + variable, + operator, + value + ) VALUES (?, ?, ?, ?) + """, + (5, "tip_income", ">", "0"), + ) + conn.executemany( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + [ + ( + 10, + 5, + "tip_income", + 2020, + 80.0, + "AMOUNT", + "NATIONAL", + "IRS_SOI", + "W-2", + None, + None, + ), + ( + 11, + 2, + "adjusted_gross_income", + 2020, + 800.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ], + ) + conn.commit() + conn.close() + + +def _insert_irs_soi_itemized_deduction_targets(path: Path) -> None: + conn = sqlite3.connect(path) + conn.executemany( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + [ + ( + 12, + 2, + "medical_amount", + 2023, + 100.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI Individual Returns - Itemized Deductions", + None, + None, + ), + ( + 13, + 2, + "real_estate_taxes_amount", + 2023, + 200.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI Individual Returns - Itemized Deductions", + None, + None, + ), + ( + 14, + 2, + "salt_amount", + 2023, + 300.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI Individual Returns - Itemized Deductions", + None, + None, + ), + ( + 15, + 2, + "medical_claims", + 2023, + 10.0, + "COUNT", + None, + "IRS_SOI", + "SOI Individual Returns - Itemized Deductions", + None, + None, + ), + ( + 16, + 2, + "real_estate_taxes_claims", + 2023, + 20.0, + "COUNT", + None, + "IRS_SOI", + "SOI Individual Returns - Itemized Deductions", + None, + None, + ), + ( + 17, + 2, + "salt_claims", + 2023, + 30.0, + "COUNT", + None, + "IRS_SOI", + "SOI Individual Returns - Itemized Deductions", + None, + None, + ), + ], + ) + conn.commit() + conn.close() + + +def _insert_complete_state_rollup_targets(path: Path) -> None: + conn = sqlite3.connect(path) + state_fips_values = sorted( + state_fips for state_fips in US_STATE_ABBR_BY_FIPS if state_fips != "72" + ) + ctc_strata = [ + (1_000 + index, f"{state_fips} CTC filers", "US", f"ctc_{state_fips}") + for index, state_fips in enumerate(state_fips_values) + ] + aca_strata = [ + (2_000 + index, f"{state_fips} ACA marketplace", "US", f"aca_{state_fips}") + for index, state_fips in enumerate(state_fips_values) + ] + conn.executemany( + """ + INSERT INTO strata (id, name, jurisdiction, definition_hash) + VALUES (?, ?, ?, ?) + """, + [*ctc_strata, *aca_strata], + ) + conn.executemany( + """ + INSERT INTO stratum_constraints ( + stratum_id, + variable, + operator, + value + ) VALUES (?, ?, ?, ?) 
+ """, + [ + *((stratum_id, "is_tax_filer", "==", "1") for stratum_id, *_ in ctc_strata), + *( + (stratum_id, "state_fips", "==", state_fips) + for stratum_id, _, _, definition_hash in ctc_strata + for state_fips in (definition_hash.removeprefix("ctc_"),) + ), + *( + (stratum_id, "state_fips", "==", state_fips) + for stratum_id, _, _, definition_hash in aca_strata + for state_fips in (definition_hash.removeprefix("aca_"),) + ), + ], + ) + ctc_targets = [ + ( + 10_000 + index * 2, + stratum_id, + "ctc_amount", + 2024, + 1_000.0 + index, + "AMOUNT", + None, + "IRS_SOI", + "State Data FY", + None, + None, + ) + for index, (stratum_id, *_rest) in enumerate(ctc_strata) + ] + ctc_count_targets = [ + ( + 10_001 + index * 2, + stratum_id, + "ctc_claims", + 2024, + 100.0 + index, + "COUNT", + None, + "IRS_SOI", + "State Data FY", + None, + None, + ) + for index, (stratum_id, *_rest) in enumerate(ctc_strata) + ] + aca_targets = [ + ( + 20_000 + index, + stratum_id, + "aca_aptc_amount", + 2024, + 10_000.0 + index, + "AMOUNT", + "STATE", + "CMS_ACA", + "2024 OEP State-Level Public Use File", + None, + None, + ) + for index, (stratum_id, *_rest) in enumerate(aca_strata) + ] + conn.executemany( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + [*ctc_targets, *ctc_count_targets, *aca_targets], + ) + conn.commit() + conn.close() + + +def test_arch_provider_ages_soi_and_maps_return_counts_to_positive_amounts(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "variables": ["tax_exempt_interest_income"], + "entity_overrides": { + "tax_exempt_interest_income": EntityType.PERSON, + }, + }, + ) + ) + + assert len(target_set.targets) == 2 + count_target = next( + target + for target in target_set.targets + if target.aggregation is TargetAggregation.COUNT + ) + amount_target = next( + target + for target in target_set.targets + if target.aggregation is TargetAggregation.SUM + ) + assert count_target.entity is EntityType.TAX_UNIT + assert count_target.name == "arch_target_1" + assert count_target.description == ( + "Tax-exempt interest returns for US All Filers (IRS SOI, 2024)" + ) + assert count_target.metadata["display_label"] == count_target.description + assert count_target.metadata["target_semantic"] == "count" + assert count_target.metadata["model_variable_role"] == "preserved_input" + assert count_target.measure is None + assert count_target.value == pytest.approx(11.0) + assert { + (target_filter.feature, target_filter.operator.value, target_filter.value) + for target_filter in count_target.filters + } == { + ("tax_unit_is_filer", "==", "1"), + ("tax_exempt_interest_income", ">", 0), + } + + assert amount_target.entity is EntityType.PERSON + assert amount_target.description == ( + "Tax-exempt interest amount for US All Filers (IRS SOI, 2024)" + ) + assert amount_target.metadata["display_label"] == amount_target.description + assert amount_target.metadata["target_semantic"] == "amount" + assert amount_target.metadata["model_variable_role"] == "preserved_input" + assert amount_target.measure == 
"tax_exempt_interest_income" + assert amount_target.value == pytest.approx(110.0) + assert amount_target.metadata["arch_source_period"] == 2023 + assert amount_target.metadata["arch_aging_amount_method"] == ( + "soi_total_agi_last_growth_extrapolation" + ) + + +def test_arch_provider_maps_agi_bracket_constraints_to_agi_ranges(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "variables": ["tax_unit_count"], + }, + ) + ) + + assert {target.metadata["target_id"] for target in target_set.targets} == {7} + bracket_target = next( + target for target in target_set.targets if target.metadata["target_id"] == 7 + ) + assert bracket_target.name == "arch_target_7" + assert bracket_target.description == ( + "Tax unit count for CA Filers AGI $50k-$75k (IRS SOI, 2024)" + ) + assert bracket_target.metadata["display_label"] == bracket_target.description + assert bracket_target.value == pytest.approx(22.0) + assert { + (target_filter.feature, target_filter.operator.value, target_filter.value) + for target_filter in bracket_target.filters + } == { + ("tax_unit_is_filer", "==", "1"), + ("state_fips", "==", "06"), + ("adjusted_gross_income", ">=", 50_000), + ("adjusted_gross_income", "<", 75_000), + } + + +def test_arch_provider_includes_parent_stratum_constraints(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.execute( + """ + INSERT INTO strata (id, name, jurisdiction, parent_id, definition_hash) + VALUES (?, ?, ?, ?, ?) + """, + (4, "US Filers AGI $75k-$100k", "US", 2, "national_75k_100k"), + ) + conn.execute( + """ + INSERT INTO stratum_constraints ( + stratum_id, + variable, + operator, + value + ) VALUES (?, ?, ?, ?) 
+ """, + (4, "agi_bracket", "==", "75k_to_100k"), + ) + conn.execute( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + 8, + 4, + "adjusted_gross_income", + 2023, + 750.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "variables": ["adjusted_gross_income"], + }, + ) + ) + + child_target = next( + target for target in target_set.targets if target.metadata["target_id"] == 8 + ) + assert { + (target_filter.feature, target_filter.operator.value, target_filter.value) + for target_filter in child_target.filters + } == { + ("tax_unit_is_filer", "==", "1"), + ("adjusted_gross_income", ">=", 75_000), + ("adjusted_gross_income", "<", 100_000), + } + + +def test_arch_provider_honors_policyengine_target_cells(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "tax_unit_count", + "geo_level": "national", + "domain_variable": "tax_exempt_interest_income", + } + ], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [1] + target = target_set.targets[0] + assert target.aggregation is TargetAggregation.COUNT + assert target.measure is None + + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "tax_exempt_interest_income", + "geo_level": "national", + "domain_variable": "tax_exempt_interest_income", + } + ], + }, + ) + ) 
+ + assert [target.metadata["target_id"] for target in target_set.targets] == [2] + + +def test_arch_provider_target_cell_domain_match_is_exact(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + _insert_multi_domain_soi_targets(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "tax_unit_count", + "geo_level": "national", + "domain_variable": "tax_exempt_interest_income", + } + ], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [1] + + +def test_arch_provider_matches_aliased_amount_self_domain_target_cells(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.execute( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + 8, + 2, + "income_tax_liability", + 2023, + 80.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "income_tax", + "geo_level": "national", + "domain_variable": "income_tax", + } + ], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [8] + + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "income_tax_positive", + "geo_level": "national", + "domain_variable": None, + } + ], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [8] + assert target_set.targets[0].measure == "income_tax" + assert target_set.targets[0].metadata["arch_variable"] == "income_tax_liability" + + +def test_arch_provider_matches_current_profile_aliases(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.executemany( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + [ + ( + 8, + 2, + "alimony_received_amount", + 2023, + 20.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ( + 9, + 2, + "schedule_c_income_amount", + 2023, + 30.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ( + 10, + 1, + "medicaid_total_enrollment", + 2024, + 40.0, + "COUNT", + None, + "CMS_MEDICAID", + "CMS", + None, + None, + ), + ( + 11, + 2, + "wages_salaries_amount", + 2023, + 50.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ( + 12, + 2, + "wages_salaries_returns", + 2023, + 60.0, + "COUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ( + 13, + 2, + "schedule_c_income_returns", + 2023, + 70.0, + "COUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ], + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "target_cells": [ + { + "variable": "alimony_income", + "geo_level": "national", + "domain_variable": None, + }, + { + "variable": "self_employment_income", + "geo_level": "national", + "domain_variable": "self_employment_income", + }, + { + "variable": "total_self_employment_income", + "geo_level": "national", + "domain_variable": "total_self_employment_income", + }, + { + "variable": "tax_unit_count", + "geo_level": "national", + "domain_variable": "total_self_employment_income", + }, + { + "variable": "person_count", + "geo_level": "national", + "domain_variable": "medicaid", + }, + { + "variable": "employment_income", + "geo_level": "national", + "domain_variable": "employment_income", + }, + { + "variable": "tax_unit_count", + "geo_level": "national", + "domain_variable": "employment_income", + }, + ], + }, + ) + ) + + assert {target.metadata["target_id"] for target in target_set.targets} == { + 8, + 9, + 10, + 11, + 12, + 13, + } + variables_by_id = { + target.metadata["target_id"]: target.metadata["variable"] + for target in 
target_set.targets + } + assert variables_by_id == { + 8: "alimony_income", + 9: "self_employment_income", + 10: "person_count", + 11: "employment_income", + 12: "employment_income", + 13: "self_employment_income", + } + + +def test_arch_target_rejects_broad_proprietors_income_as_self_employment(): + record = ArchTargetRecord( + target_id=1, + stratum_id=1, + variable="schedule_c_income_amount", + period=2024, + value=2_023_080_000_000, + target_type="AMOUNT", + geographic_level=None, + geography_id=None, + source="BEA", + source_table="NIPA annual personal income components", + source_url=None, + notes=None, + stratum_name="US", + jurisdiction="US", + constraints=(), + concept=( + "bea_nipa.proprietors_income_with_inventory_valuation_and_capital_" + "consumption_adjustments" + ), + source_concept=( + "bea_nipa.a041rc_proprietors_income_with_inventory_valuation_and_" + "capital_consumption_adjustments" + ), + ) + + with pytest.raises( + ValueError, + match="cannot be exposed as plain self_employment_income", + ): + arch_target_record_to_canonical_spec(record) + + +def test_arch_provider_maps_eitc_child_count_constraints(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.execute( + """ + INSERT INTO strata (id, name, jurisdiction, definition_hash) + VALUES (?, ?, ?, ?) + """, + (4, "US EITC 3+ Children", "US", "eitc_3plus_children"), + ) + conn.executemany( + """ + INSERT INTO stratum_constraints ( + stratum_id, + variable, + operator, + value + ) VALUES (?, ?, ?, ?) + """, + [ + (4, "is_tax_filer", "==", "1"), + (4, "eitc_qualifying_children", ">=", "3"), + ], + ) + conn.execute( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + 8, + 4, + "eitc_amount", + 2023, + 15.0, + "AMOUNT", + None, + "IRS_SOI", + "EITC", + None, + None, + ), + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "eitc", + "geo_level": "national", + "domain_variable": "eitc_child_count", + } + ], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [8] + target = target_set.targets[0] + assert target.measure == "eitc" + assert { + (target_filter.feature, target_filter.operator.value, target_filter.value) + for target_filter in target.filters + } == { + ("tax_unit_is_filer", "==", "1"), + ("eitc_child_count", ">=", "3"), + } + + +def test_arch_provider_matches_eitc_count_and_multi_domain_cells(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.executemany( + """ + INSERT INTO strata (id, name, jurisdiction, definition_hash) + VALUES (?, ?, ?, ?) + """, + [ + (4, "US EITC 3+ Children", "US", "eitc_3plus_children"), + (5, "US AGI 1_to_10k EITC 1 Child", "US", "eitc_1_child_agi"), + ], + ) + conn.executemany( + """ + INSERT INTO stratum_constraints ( + stratum_id, + variable, + operator, + value + ) VALUES (?, ?, ?, ?) + """, + [ + (4, "is_tax_filer", "==", "1"), + (4, "eitc_qualifying_children", ">=", "3"), + (5, "is_tax_filer", "==", "1"), + (5, "agi_bracket", "==", "1_to_10k"), + (5, "eitc_qualifying_children", "==", "1"), + ], + ) + conn.executemany( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + [ + ( + 8, + 4, + "eitc_claims", + 2023, + 10.0, + "COUNT", + None, + "IRS_SOI", + "EITC", + None, + None, + ), + ( + 9, + 5, + "eitc_amount", + 2023, + 20.0, + "AMOUNT", + None, + "IRS_SOI", + "EITC", + None, + None, + ), + ( + 10, + 5, + "eitc_claims", + 2023, + 30.0, + "COUNT", + None, + "IRS_SOI", + "EITC", + None, + None, + ), + ], + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "tax_unit_count", + "geo_level": "national", + "domain_variable": "eitc_child_count", + }, + { + "variable": "eitc", + "geo_level": "national", + "domain_variable": ( + "adjusted_gross_income,eitc,eitc_child_count" + ), + }, + { + "variable": "tax_unit_count", + "geo_level": "national", + "domain_variable": ( + "adjusted_gross_income,eitc,eitc_child_count" + ), + }, + ], + }, + ) + ) + + assert {target.metadata["target_id"] for target in target_set.targets} == { + 8, + 9, + 10, + } + + +def test_arch_provider_maps_census_stc_state_income_tax(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.execute( + """ + INSERT INTO strata (id, name, jurisdiction, definition_hash) + VALUES (?, ?, ?, ?) + """, + (4, "CA state government", "US", "ca_state_government"), + ) + conn.execute( + """ + INSERT INTO stratum_constraints ( + stratum_id, + variable, + operator, + value + ) VALUES (?, ?, ?, ?) + """, + (4, "state_fips", "==", "06"), + ) + conn.execute( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + 8, + 4, + "state_individual_income_tax_collections", + 2024, + 123.0, + "AMOUNT", + "STATE", + "CENSUS_STC", + "STC T40", + None, + None, + ), + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "target_cells": [ + { + "variable": "state_income_tax", + "geo_level": "state", + "domain_variable": None, + } + ], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [8] + target = target_set.targets[0] + assert target.measure == "state_income_tax" + assert target.entity is EntityType.TAX_UNIT + assert target.aggregation is TargetAggregation.SUM + assert target.metadata["source"] == "CENSUS_STC" + assert { + (target_filter.feature, target_filter.operator.value, target_filter.value) + for target_filter in target.filters + } == {("state_fips", "==", "06")} + + +def test_arch_provider_maps_soi_itemized_deduction_targets(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.executemany( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + [ + ( + 8, + 2, + "limited_state_local_taxes_amount", + 2024, + 122.0, + "AMOUNT", + None, + "IRS_SOI", + "Historic Table 2", + None, + None, + ), + ( + 9, + 2, + "interest_paid_deduction_amount", + 2024, + 169.0, + "AMOUNT", + None, + "IRS_SOI", + "Historic Table 2", + None, + "Composed from Schedule A interest lines.", + ), + ], + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "salt_deduction", + "geo_level": "national", + "domain_variable": None, + }, + { + "variable": "interest_deduction", + "geo_level": "national", + "domain_variable": None, + }, + ], + }, + ) + ) + + targets_by_measure = {target.measure: target for target in target_set.targets} + assert set(targets_by_measure) == {"interest_deduction", "salt_deduction"} + + salt_target = targets_by_measure["salt_deduction"] + assert salt_target.metadata["target_id"] == 8 + assert salt_target.entity is EntityType.TAX_UNIT + assert salt_target.aggregation is TargetAggregation.SUM + + interest_target = targets_by_measure["interest_deduction"] + assert interest_target.metadata["target_id"] == 9 + assert interest_target.entity is EntityType.TAX_UNIT + assert interest_target.metadata["notes"] == ( + "Composed from Schedule A interest lines." + ) + + +def test_arch_provider_infers_geo_level_from_constraints(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.execute( + """ + INSERT INTO strata (id, name, jurisdiction, definition_hash) + VALUES (?, ?, ?, ?) + """, + (4, "CA Filers", "US", "ca_filers"), + ) + conn.executemany( + """ + INSERT INTO stratum_constraints ( + stratum_id, + variable, + operator, + value + ) VALUES (?, ?, ?, ?) 
+ """, + [ + (4, "is_tax_filer", "==", "1"), + (4, "state_fips", "==", "06"), + ], + ) + conn.execute( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + 8, + 4, + "adjusted_gross_income", + 2023, + 500.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "adjusted_gross_income", + "geo_level": "state", + "geographic_id": "6", + "domain_variable": None, + } + ], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [8] + assert target_set.targets[0].metadata["geo_level"] == "state" + + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "variables": ["adjusted_gross_income"], + "geo_levels": ["national"], + }, + ) + ) + + assert 8 not in {target.metadata["target_id"] for target in target_set.targets} + + +def test_arch_provider_maps_program_indicator_constraints_to_support_filters(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.execute( + """ + INSERT INTO strata (id, name, jurisdiction, definition_hash) + VALUES (?, ?, ?, ?) + """, + (4, "SNAP households", "US", "snap_households"), + ) + conn.execute( + """ + INSERT INTO stratum_constraints ( + stratum_id, + variable, + operator, + value + ) VALUES (?, ?, ?, ?) 
+ """, + (4, "snap", "==", "1"), + ) + conn.executemany( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + [ + ( + 8, + 4, + "snap_household_count", + 2024, + 10.0, + "COUNT", + None, + "USDA_SNAP", + "USDA", + None, + None, + ), + ( + 9, + 4, + "snap_benefits", + 2024, + 500.0, + "AMOUNT", + None, + "USDA_SNAP", + "USDA", + None, + None, + ), + ], + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["USDA_SNAP"], + "target_cells": [ + { + "variable": "snap", + "geo_level": "national", + "domain_variable": "snap", + }, + { + "variable": "household_count", + "geo_level": "national", + "domain_variable": "snap", + }, + ], + }, + ) + ) + + assert {target.metadata["target_id"] for target in target_set.targets} == {8, 9} + for target in target_set.targets: + assert [ + (target_filter.feature, target_filter.operator.value, target_filter.value) + for target_filter in target.filters + ] == [("snap", ">", 0)] + + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["USDA_SNAP"], + "target_cells": [ + { + "variable": "snap", + "geo_level": "national", + "domain_variable": None, + }, + ], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [9] + + +def test_arch_provider_normalizes_congressional_district_constraints(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.execute( + """ + INSERT INTO strata (id, name, jurisdiction, definition_hash) + VALUES (?, ?, ?, ?) 
+ """, + (4, "CA-01 Filers", "US", "ca_01_filers"), + ) + conn.executemany( + """ + INSERT INTO stratum_constraints ( + stratum_id, + variable, + operator, + value + ) VALUES (?, ?, ?, ?) + """, + [ + (4, "is_tax_filer", "=", "1"), + (4, "state_fips", "=", "06"), + (4, "congressional_district", "=", "01"), + ], + ) + conn.execute( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + 8, + 4, + "adjusted_gross_income", + 2023, + 500.0, + "AMOUNT", + "CONGRESSIONAL_DISTRICT", + "IRS_SOI", + "SOI", + None, + None, + ), + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "adjusted_gross_income", + "geo_level": "district", + "geographic_id": "0601", + "domain_variable": None, + } + ], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [8] + target = target_set.targets[0] + assert target.metadata["geo_level"] == "district" + assert ( + "congressional_district_geoid", + "==", + "0601", + ) in { + (target_filter.feature, target_filter.operator.value, target_filter.value) + for target_filter in target.filters + } + assert ( + "tax_unit_is_filer", + "==", + "1", + ) in { + (target_filter.feature, target_filter.operator.value, target_filter.value) + for target_filter in target.filters + } + + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "variables": ["adjusted_gross_income"], + "geo_levels": ["state"], + }, + ) + ) + + assert 8 not in {target.metadata["target_id"] for target in target_set.targets} + + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], 
+ "variables": ["adjusted_gross_income"], + "geo_levels": ["congressional_district"], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [8] + + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "variables": ["adjusted_gross_income"], + "geo_levels": ["congressional-district"], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [8] + + +def test_arch_provider_no_domain_target_cell_excludes_domain_strata(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + _insert_multi_domain_soi_targets(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "adjusted_gross_income", + "geo_level": "national", + "domain_variable": None, + } + ], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [4] + + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "adjusted_gross_income", + "geo_level": "national", + "domain_variable": "", + } + ], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [4] + + +def test_arch_provider_current_year_partial_soi_falls_back_to_latest_soi(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.executemany( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + [ + ( + 8, + 2, + "tax_unit_count", + 2024, + 25.0, + "COUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ( + 9, + 2, + "adjusted_gross_income", + 2024, + 1_200.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ( + 10, + 2, + "income_tax_liability", + 2024, + 80.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ], + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "variables": ["tax_exempt_interest_income"], + }, + ) + ) + + assert {target.metadata["target_id"] for target in target_set.targets} == {1, 2} + assert {target.metadata["arch_source_period"] for target in target_set.targets} == { + 2023 + } + + +def test_arch_provider_uses_latest_soi_record_per_composition(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.execute( + """ + INSERT INTO strata (id, name, jurisdiction, definition_hash) + VALUES (?, ?, ?, ?) + """, + (4, "CA Filers", "US", "ca_filers"), + ) + conn.executemany( + """ + INSERT INTO stratum_constraints ( + stratum_id, + variable, + operator, + value + ) VALUES (?, ?, ?, ?) + """, + [ + (4, "is_tax_filer", "==", "1"), + (4, "state_fips", "==", "06"), + ], + ) + conn.executemany( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + [ + ( + 8, + 1, + "labor_force_count", + 2022, + 100.0, + "COUNT", + None, + "BLS", + "BLS", + None, + None, + ), + ( + 9, + 2, + "tax_unit_count", + 2024, + 25.0, + "COUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ( + 10, + 2, + "adjusted_gross_income", + 2024, + 1_200.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ( + 11, + 2, + "income_tax_liability", + 2024, + 80.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ( + 12, + 4, + "wages_salaries_amount", + 2022, + 90.0, + "AMOUNT", + "STATE", + "IRS_SOI", + "SOI", + None, + None, + ), + ], + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "employment_income", + "geo_level": "state", + "domain_variable": "employment_income", + } + ], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [12] + assert target_set.targets[0].period == 2024 + assert target_set.targets[0].metadata["arch_source_period"] == 2022 + + +def test_arch_provider_maps_income_tax_before_credits_targets(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.executemany( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + [ + ( + 8, + 2, + "income_tax_before_credits_returns", + 2023, + 50.0, + "COUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ( + 9, + 2, + "income_tax_before_credits_amount", + 2023, + 500.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ], + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "tax_unit_count", + "geo_level": "national", + "domain_variable": "income_tax_before_credits", + }, + { + "variable": "income_tax_before_credits", + "geo_level": "national", + "domain_variable": "income_tax_before_credits", + }, + ], + }, + ) + ) + + assert {target.metadata["target_id"] for target in target_set.targets} == {8, 9} + count_target = next( + target + for target in target_set.targets + if target.aggregation is TargetAggregation.COUNT + ) + assert { + (target_filter.feature, target_filter.operator.value, target_filter.value) + for target_filter in count_target.filters + } == { + ("tax_unit_is_filer", "==", "1"), + ("income_tax_before_credits", ">", 0), + } + + +def test_arch_provider_maps_real_estate_tax_targets(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.executemany( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + [ + ( + 8, + 2, + "real_estate_taxes_claims", + 2023, + 12.0, + "COUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ( + 9, + 2, + "real_estate_taxes_amount", + 2023, + 120.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI", + None, + None, + ), + ], + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "tax_unit_count", + "geo_level": "national", + "domain_variable": "real_estate_taxes", + }, + { + "variable": "real_estate_taxes", + "geo_level": "national", + "domain_variable": "real_estate_taxes", + }, + ], + }, + ) + ) + + assert {target.metadata["target_id"] for target in target_set.targets} == {8, 9} + assert {target.metadata["variable"] for target in target_set.targets} == { + "real_estate_taxes" + } + + +def test_arch_provider_maps_aca_aptc_amount_targets(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.execute( + """ + INSERT INTO strata (id, name, jurisdiction, definition_hash) + VALUES (?, ?, ?, ?) + """, + (4, "CA ACA Marketplace", "US", "ca_aca"), + ) + conn.execute( + """ + INSERT INTO stratum_constraints ( + stratum_id, + variable, + operator, + value + ) VALUES (?, ?, ?, ?) + """, + (4, "state_fips", "==", "06"), + ) + conn.executemany( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + [ + ( + 8, + 4, + "aca_aptc_amount", + 2024, + 100.0, + "AMOUNT", + "STATE", + "CMS_ACA", + "CMS OEP", + None, + None, + ), + ( + 9, + 4, + "aca_marketplace_enrollment", + 2024, + 200.0, + "COUNT", + "STATE", + "CMS_ACA", + "CMS OEP", + None, + None, + ), + ], + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["CMS_ACA"], + "target_cells": [ + { + "variable": "aca_ptc", + "geo_level": "state", + "domain_variable": None, + } + ], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [8] + assert target_set.targets[0].measure == "aca_ptc" + + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["CMS_ACA"], + "target_cells": [ + { + "variable": "person_count", + "geo_level": "state", + "domain_variable": "aca_ptc,is_aca_ptc_eligible", + } + ], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [9] + + +def test_arch_provider_maps_soi_aca_ptc_return_counts(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.execute( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + 8, + 2, + "aca_ptc_returns", + 2023, + 7_841_370.0, + "COUNT", + None, + "IRS_SOI", + "Historic Table 2", + None, + None, + ), + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "tax_unit_count", + "geo_level": "national", + "domain_variable": "used_aca_ptc", + } + ], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [8] + target = target_set.targets[0] + assert target.aggregation is TargetAggregation.COUNT + assert target.entity is EntityType.TAX_UNIT + assert { + (target_filter.feature, target_filter.operator.value, target_filter.value) + for target_filter in target.filters + } == { + ("tax_unit_is_filer", "==", "1"), + ("aca_ptc", ">", 0), + } + + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "tax_unit_count", + "geo_level": "national", + "domain_variable": ( + "selected_marketplace_plan_benchmark_ratio,used_aca_ptc" + ), + } + ], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [8] + + +def test_arch_provider_maps_soi_tax_filer_individual_counts(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.execute( + """ + INSERT INTO strata (id, name, jurisdiction, definition_hash) + VALUES (?, ?, ?, ?) + """, + (4, "CA AGI 1_to_10k", "US", "ca_agi_1_to_10k"), + ) + conn.executemany( + """ + INSERT INTO stratum_constraints ( + stratum_id, + variable, + operator, + value + ) VALUES (?, ?, ?, ?) 
+ """, + [ + (4, "is_tax_filer", "==", "1"), + (4, "state_fips", "==", "06"), + (4, "agi_bracket", "==", "1_to_10k"), + ], + ) + conn.execute( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + 8, + 4, + "tax_filer_individual_count", + 2023, + 1_930_150.0, + "COUNT", + "STATE", + "IRS_SOI", + "Historic Table 2", + None, + "SOI number of individuals does not represent full population.", + ), + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "person_count", + "geo_level": "state", + "domain_variable": "adjusted_gross_income", + } + ], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [8] + target = target_set.targets[0] + assert target.aggregation is TargetAggregation.COUNT + assert target.entity is EntityType.PERSON + assert { + (target_filter.feature, target_filter.operator.value, target_filter.value) + for target_filter in target.filters + } == { + ("tax_unit_is_filer", "==", "1"), + ("state_fips", "==", "06"), + ("adjusted_gross_income", ">=", 1), + ("adjusted_gross_income", "<", 10_000), + } + + +def test_arch_provider_maps_medicaid_benefit_targets(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.execute( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + 8, + 1, + "medicaid_benefits", + 2024, + 931_692_000_000.0, + "AMOUNT", + "NATIONAL", + "CMS_MEDICAID", + "CMS NHE", + None, + None, + ), + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["CMS_MEDICAID"], + "target_cells": [ + { + "variable": "medicaid", + "geo_level": "national", + "domain_variable": None, + } + ], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [8] + target = target_set.targets[0] + assert target.measure == "medicaid" + assert target.entity is EntityType.PERSON + + +def test_arch_consumer_fact_provider_maps_wealth_and_part_b_targets(tmp_path): + jsonl_path = tmp_path / "consumer_facts.jsonl" + rows = [ + { + "schema_version": "arch.consumer_fact.v1", + "aggregate_fact_key": "arch.aggregate_fact.v2:net_worth", + "semantic_fact_key": "arch.semantic_fact.v2:net_worth", + "concept_alignment": { + "canonical_concept": ( + "federal_reserve.z1.households_nonprofits_net_worth" + ), + "source_concept": "federal_reserve.z1.fl152090005", + "relation": "source_label", + "authority": "arch-us", + }, + "geography": { + "id": "0100000US", + "level": "country", + }, + "label": "United States household net worth", + "observed_measure": { + "source_concept": "federal_reserve.z1.fl152090005", + "source_measure_id": "amount_outstanding", + "source_name": "federal_reserve", + "source_table": ( + "Z.1 B.101 Households and nonprofit organizations" + ), + "unit": "usd", + }, + "period": {"type": "calendar_year", "value": 2024}, + "source": { + "source_name": "federal_reserve", + "source_table": ( + "Z.1 B.101 Households and nonprofit organizations" + ), + "url": "https://www.federalreserve.gov/releases/z1/", + }, + "universe_constraints": {"domain": "household_balance_sheet"}, + "value": 169_619_200_000_000, + }, + { + "schema_version": "arch.consumer_fact.v1", + 
"aggregate_fact_key": "arch.aggregate_fact.v2:part_b", + "semantic_fact_key": "arch.semantic_fact.v2:part_b", + "concept_alignment": { + "canonical_concept": "cms_medicare.part_b_premium_income", + "source_concept": "cms_medicare.part_b_premium_income", + }, + "geography": { + "id": "0100000US", + "level": "country", + }, + "label": "United States Medicare Part B premium income", + "observed_measure": { + "source_concept": "cms_medicare.part_b_premium_income", + "source_measure_id": "actual_amount", + "source_name": "cms_medicare", + "source_table": "2025 Medicare Trustees Report Table III.C3", + "unit": "usd", + }, + "period": {"type": "calendar_year", "value": 2024}, + "source": { + "source_name": "cms_medicare", + "source_table": "2025 Medicare Trustees Report Table III.C3", + "url": "https://www.cms.gov/oact/tr/2025", + }, + "universe_constraints": { + "domain": "medicare_financing", + "constraints": [ + { + "operator": "==", + "role": "filter", + "value": "actual", + "variable": "amount_basis", + }, + { + "operator": "==", + "role": "filter", + "value": "part_b", + "variable": "medicare.part", + }, + { + "operator": "==", + "role": "filter", + "value": "premiums_from_enrollees", + "variable": "medicare.financing_component", + }, + ], + }, + "value": 139_837_000_000, + }, + ] + jsonl_path.write_text( + "".join(f"{json.dumps(row, sort_keys=True)}\n" for row in rows) + ) + + provider = ArchConsumerFactJSONLTargetProvider(jsonl_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "target_cells": [ + { + "variable": "net_worth", + "geo_level": "national", + "domain_variable": None, + }, + { + "variable": "medicare_part_b_premiums", + "geo_level": "national", + "domain_variable": None, + }, + ], + }, + ) + ) + + targets_by_measure = {target.measure: target for target in target_set.targets} + assert set(targets_by_measure) == {"medicare_part_b_premiums", "net_worth"} + + net_worth = targets_by_measure["net_worth"] + 
assert net_worth.entity is EntityType.HOUSEHOLD + assert net_worth.aggregation is TargetAggregation.SUM + assert net_worth.value == pytest.approx(169_619_200_000_000) + assert net_worth.filters == () + assert net_worth.metadata["source"] == "FEDERAL_RESERVE" + assert net_worth.metadata["arch_source_concept"] == ( + "federal_reserve.z1.fl152090005" + ) + + part_b = targets_by_measure["medicare_part_b_premiums"] + assert part_b.entity is EntityType.PERSON + assert part_b.aggregation is TargetAggregation.SUM + assert part_b.value == pytest.approx(139_837_000_000) + assert part_b.filters == () + assert part_b.metadata["source"] == "CMS_MEDICARE" + assert part_b.metadata["arch_concept"] == "cms_medicare.part_b_premium_income" + + +def test_arch_provider_maps_ssa_benefit_targets(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.executemany( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + [ + ( + 8, + 1, + "social_security_benefits", + 2024, + 1_471_195_000_000.0, + "AMOUNT", + "NATIONAL", + "SSA", + "SSA Supplement", + None, + None, + ), + ( + 9, + 1, + "social_security_retirement_benefits", + 2024, + 1_111_728_000_000.0, + "AMOUNT", + "NATIONAL", + "SSA", + "SSA Supplement", + None, + None, + ), + ( + 10, + 1, + "ssi_payments", + 2024, + 63_079_493_000.0, + "AMOUNT", + "NATIONAL", + "SSA", + "SSA Supplement", + None, + None, + ), + ], + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["SSA"], + "target_cells": [ + { + "variable": "social_security", + "geo_level": "national", + "domain_variable": None, + }, + { + "variable": "social_security_retirement", + "geo_level": "national", + "domain_variable": None, + }, + { + "variable": "ssi", + "geo_level": "national", + "domain_variable": None, + }, + ], + }, + ) + ) + + assert {target.metadata["target_id"] for target in target_set.targets} == { + 8, + 9, + 10, + } + assert {target.measure for target in target_set.targets} == { + "social_security", + "social_security_retirement", + "ssi", + } + assert {target.entity for target in target_set.targets} == {EntityType.PERSON} + + +def test_arch_provider_maps_tanf_cash_assistance_target(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.executemany( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + [ + ( + 8, + 1, + "tanf_cash_assistance", + 2024, + 7_788_317_474.55, + "AMOUNT", + "NATIONAL", + "HHS_ACF_TANF", + "ACF TANF Financial Data", + None, + None, + ), + ( + 9, + 1, + "tanf_family_count", + 2024, + 841_208.67, + "COUNT", + "NATIONAL", + "HHS_ACF_TANF", + "ACF TANF Caseload Data", + None, + None, + ), + ], + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["HHS_ACF_TANF"], + "target_cells": [ + { + "variable": "tanf", + "geo_level": "national", + "domain_variable": None, + }, + { + "variable": "spm_unit_count", + "geo_level": "national", + "domain_variable": "tanf", + }, + ], + }, + ) + ) + + assert [target.metadata["target_id"] for target in target_set.targets] == [8, 9] + targets_by_id = { + target.metadata["target_id"]: target for target in target_set.targets + } + assert targets_by_id[8].measure == "tanf" + assert targets_by_id[8].entity is EntityType.SPM_UNIT + assert targets_by_id[9].measure is None + assert targets_by_id[9].metadata["variable"] == "spm_unit_count" + assert targets_by_id[9].entity is EntityType.SPM_UNIT + + +def test_arch_provider_maps_w2_tip_income_without_source_year_labor_force(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + _insert_w2_tip_income_target(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "variables": ["tip_income"], + }, + ) + ) + + assert len(target_set.targets) == 1 + target = target_set.targets[0] + assert target.entity is EntityType.PERSON + assert target.measure == "tip_income" + assert target.aggregation is TargetAggregation.SUM + assert target.value == pytest.approx(121.0) + assert { + (target_filter.feature, target_filter.operator.value, target_filter.value) + for target_filter in 
target.filters + } == {("tip_income", ">", "0")} + assert target.metadata["arch_source_period"] == 2020 + assert target.metadata["arch_aging_count_method"] == "not_required" + assert target.metadata["arch_aging_amount_method"] == ( + "soi_total_agi_last_growth_extrapolation" + ) + + +def test_arch_provider_maps_ira_contribution_targets(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.executemany( + """ + INSERT INTO strata (id, name, jurisdiction, definition_hash) + VALUES (?, ?, ?, ?) + """, + [ + (12, "US taxpayers with traditional IRA contributions", "US", "trad_ira"), + (13, "US taxpayers with Roth IRA contributions", "US", "roth_ira"), + ], + ) + conn.executemany( + """ + INSERT INTO stratum_constraints ( + stratum_id, + variable, + operator, + value + ) VALUES (?, ?, ?, ?) + """, + [ + (12, "traditional_ira_contributions", ">", "0"), + (13, "roth_ira_contributions", ">", "0"), + ], + ) + conn.executemany( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + [ + ( + 12, + 12, + "traditional_ira_contributions", + 2022, + 50.0, + "AMOUNT", + "NATIONAL", + "IRS_SOI", + "IRA", + None, + None, + ), + ( + 13, + 13, + "roth_ira_contributions", + 2022, + 75.0, + "AMOUNT", + "NATIONAL", + "IRS_SOI", + "IRA", + None, + None, + ), + ], + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "traditional_ira_contributions", + "geo_level": "national", + "domain_variable": None, + }, + { + "variable": "roth_ira_contributions", + "geo_level": "national", + "domain_variable": None, + }, + ], + }, + ) + ) + + assert {target.measure for target in target_set.targets} == { + "traditional_ira_contributions", + "roth_ira_contributions", + } + assert {target.entity for target in target_set.targets} == {EntityType.PERSON} + assert { + target.metadata["arch_aging_count_method"] for target in target_set.targets + } == {"not_required"} + + +def test_us_pipeline_can_select_arch_target_provider(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + pipeline = USMicroplexPipeline( + USMicroplexBuildConfig( + arch_targets_db=str(db_path), + calibration_target_source="arch", + ) + ) + + provider, source = pipeline._resolve_calibration_target_provider() + + assert source == "arch" + assert isinstance(provider, ArchSQLiteTargetProvider) + + +def test_arch_target_profile_coverage_summarizes_custom_cells(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + report = summarize_arch_target_profile_coverage( + provider, + period=2024, + profile_name="custom", + target_cells=( + PolicyEngineUSTargetCell( + "adjusted_gross_income", + geo_level="national", + domain_variable=None, + ), + PolicyEngineUSTargetCell( + "tax_unit_count", + 
geo_level="national", + domain_variable="tax_exempt_interest_income", + ), + PolicyEngineUSTargetCell( + "employment_income", + geo_level="national", + domain_variable="employment_income", + ), + ), + ) + + assert report.target_cell_count == 3 + assert report.covered_cell_count == 2 + assert report.uncovered_cell_count == 1 + assert report.coverage_rate == pytest.approx(2 / 3) + assert report.by_geo_level == { + "national": { + "target_cell_count": 3, + "covered_cell_count": 2, + "uncovered_cell_count": 1, + } + } + assert report.by_variable["adjusted_gross_income"]["covered_cell_count"] == 1 + assert report.by_variable["employment_income"]["uncovered_cell_count"] == 1 + + payload = report.to_dict() + assert payload["profile_name"] == "custom" + assert payload["cells"][0]["target_ids"] == [4] + assert payload["cells"][1]["target_ids"] == [1] + assert payload["cells"][2]["covered"] is False + + +def test_arch_target_profile_coverage_accepts_soi_itemized_domain(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + _insert_irs_soi_itemized_deduction_targets(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + report = summarize_arch_target_profile_coverage( + provider, + period=2024, + profile_name="custom", + target_cells=( + PolicyEngineUSTargetCell( + "medical_expense_deduction", + geo_level="national", + domain_variable="medical_expense_deduction", + ), + PolicyEngineUSTargetCell( + "medical_expense_deduction", + geo_level="national", + domain_variable=("medical_expense_deduction,tax_unit_itemizes"), + ), + PolicyEngineUSTargetCell( + "real_estate_taxes", + geo_level="national", + domain_variable="real_estate_taxes,tax_unit_itemizes", + ), + PolicyEngineUSTargetCell( + "salt", + geo_level="national", + domain_variable="salt,tax_unit_itemizes", + ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable=("medical_expense_deduction,tax_unit_itemizes"), + ), + PolicyEngineUSTargetCell( + 
"tax_unit_count", + geo_level="national", + domain_variable="real_estate_taxes,tax_unit_itemizes", + ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable="salt,tax_unit_itemizes", + ), + ), + ) + + assert report.target_cell_count == 7 + assert report.covered_cell_count == 7 + assert { + (cell.cell["variable"], cell.cell["domain_variable"]): cell.target_ids + for cell in report.cells + } == { + ( + "medical_expense_deduction", + "medical_expense_deduction", + ): (12,), + ( + "medical_expense_deduction", + "medical_expense_deduction,tax_unit_itemizes", + ): (12,), + ( + "real_estate_taxes", + "real_estate_taxes,tax_unit_itemizes", + ): (13,), + ("salt", "salt,tax_unit_itemizes"): (14,), + ( + "tax_unit_count", + "medical_expense_deduction,tax_unit_itemizes", + ): (15,), + ( + "tax_unit_count", + "real_estate_taxes,tax_unit_itemizes", + ): (16,), + ("tax_unit_count", "salt,tax_unit_itemizes"): (17,), + } + + +def test_arch_target_profile_coverage_accepts_soi_medical_dental_domain( + tmp_path, +): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + conn = sqlite3.connect(db_path) + conn.execute( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + 800, + 2, + "medical_dental_expense_amount", + 2023, + 100.0, + "AMOUNT", + None, + "IRS_SOI", + "SOI Historic Table 2 state broad totals", + None, + None, + ), + ) + conn.commit() + conn.close() + + provider = ArchSQLiteTargetProvider(db_path) + report = summarize_arch_target_profile_coverage( + provider, + period=2024, + profile_name="custom", + target_cells=( + PolicyEngineUSTargetCell( + "medical_expense_deduction", + geo_level="national", + domain_variable=( + "medical_expense_deduction,tax_unit_itemizes" + ), + ), + ), + ) + + assert report.covered_cell_count == 1 + assert report.cells[0].target_ids == (800,) + + +def test_arch_target_profile_coverage_rolls_complete_state_targets_to_national( + tmp_path, +): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + _insert_complete_state_rollup_targets(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + report = summarize_arch_target_profile_coverage( + provider, + period=2024, + profile_name="custom", + target_cells=( + PolicyEngineUSTargetCell( + "non_refundable_ctc", + geo_level="national", + domain_variable="non_refundable_ctc", + ), + PolicyEngineUSTargetCell( + "non_refundable_ctc", + geo_level="national", + domain_variable="adjusted_gross_income,non_refundable_ctc", + ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable="non_refundable_ctc", + ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable="adjusted_gross_income,non_refundable_ctc", + ), + PolicyEngineUSTargetCell( + "aca_ptc", + geo_level="national", + domain_variable="aca_ptc", + ), + ), + ) + + assert report.target_cell_count == 5 + assert report.covered_cell_count == 5 + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "target_cells": [cell.cell for cell in report.cells], + }, + ) + ) + rollup_targets = { + (target.measure or target.metadata["variable"], target.aggregation): target 
+ for target in target_set.targets + if target.metadata["geo_level"] == "national" + and str(target.metadata["target_id"]).startswith("-") + } + assert rollup_targets[ + ("non_refundable_ctc", TargetAggregation.SUM) + ].value == pytest.approx(sum(1_000.0 + index for index in range(51))) + assert rollup_targets[ + ("non_refundable_ctc", TargetAggregation.COUNT) + ].value == pytest.approx(sum(100.0 + index for index in range(51))) + assert rollup_targets[("aca_ptc", TargetAggregation.SUM)].value == pytest.approx( + sum(10_000.0 + index for index in range(51)) + ) + + +def test_arch_target_profile_coverage_reports_current_pe_profile(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + report = summarize_arch_target_profile_coverage( + provider, + period=2024, + profile_name="pe_native_broad", + ) + + assert report.target_cell_count == 189 + assert report.covered_cell_count == 4 + assert report.uncovered_cell_count == 185 + assert report.by_geo_level["national"]["covered_cell_count"] == 3 + assert report.by_geo_level["state"]["covered_cell_count"] == 1 + + covered_cells = { + ( + cell.cell["variable"], + cell.cell["geo_level"], + cell.cell["domain_variable"], + ): cell.target_ids + for cell in report.cells + if cell.covered + } + assert covered_cells == { + ("adjusted_gross_income", "national", None): (4,), + ( + "tax_exempt_interest_income", + "national", + "tax_exempt_interest_income", + ): (2,), + ("tax_unit_count", "national", "tax_exempt_interest_income"): (1,), + ("tax_unit_count", "state", "adjusted_gross_income"): (7,), + } + + +def test_arch_target_gap_queue_describes_missing_loader_rows(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + report = summarize_arch_target_gap_queue( + provider, + period=2024, + profile_name="custom", + target_cells=( + PolicyEngineUSTargetCell( + 
"employment_income", + geo_level="national", + domain_variable="employment_income", + ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable="employment_income", + ), + ), + ) + + assert report.row_count == 2 + assert report.covered_row_count == 0 + assert report.uncovered_row_count == 2 + assert report.by_loader_status == {"missing_arch_target_record": 2} + assert report.by_gap_category == {"ready_primary_loader": 2} + + rows_by_variable = {row.variable: row for row in report.rows} + amount_row = rows_by_variable["employment_income"] + assert amount_row.priority == 1 + assert amount_row.expected_source == "IRS_SOI" + assert amount_row.expected_source_table == "IRS SOI Publication 1304 Table 1.4" + assert amount_row.expected_arch_variable == "wages_salaries_amount" + assert amount_row.expected_target_type == "AMOUNT" + assert amount_row.expected_entity == "person" + assert amount_row.expected_aggregation == "sum" + assert amount_row.gap_category == "ready_primary_loader" + assert amount_row.expected_filters == ( + { + "kind": "domain", + "feature": "employment_income", + "operator": ">", + "value": 0, + }, + ) + assert amount_row.agent_task_kind == "add_arch_source_loader_or_target_record" + + count_row = rows_by_variable["tax_unit_count"] + assert count_row.expected_arch_variable == "wages_salaries_returns" + assert count_row.expected_target_type == "COUNT" + assert count_row.expected_entity == "tax_unit" + assert count_row.expected_aggregation == "count" + assert count_row.gap_category == "ready_primary_loader" + + +def test_arch_target_gap_queue_points_full_population_amounts_to_bea(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + report = summarize_arch_target_gap_queue( + provider, + period=2024, + profile_name="custom", + target_cells=( + PolicyEngineUSTargetCell( + "employment_income", + geo_level="national", + domain_variable=None, 
+ ), + PolicyEngineUSTargetCell( + "self_employment_income", + geo_level="national", + domain_variable=None, + ), + PolicyEngineUSTargetCell( + "dividend_income", + geo_level="national", + domain_variable=None, + ), + PolicyEngineUSTargetCell( + "self_employment_income", + geo_level="state", + domain_variable=None, + ), + ), + ) + + rows_by_cell = {(row.variable, row.geo_level): row for row in report.rows} + assert rows_by_cell[("employment_income", "national")].expected_source == "BEA" + assert rows_by_cell[("employment_income", "national")].expected_arch_variable == ( + "wages_salaries_amount" + ) + assert rows_by_cell[("employment_income", "national")].expected_source_table == ( + "BEA NIPA annual total wages and salaries" + ) + assert ( + rows_by_cell[("self_employment_income", "national")].expected_source + == "IRS_SOI" + ) + assert rows_by_cell[ + ("self_employment_income", "national") + ].expected_arch_variable == ("schedule_c_income_amount") + assert rows_by_cell[ + ("self_employment_income", "national") + ].expected_source_table == ("IRS SOI Publication 1304") + assert rows_by_cell[("dividend_income", "national")].expected_source == "BEA" + assert rows_by_cell[("dividend_income", "national")].expected_arch_variable == ( + "personal_dividend_income_amount" + ) + assert ( + rows_by_cell[("self_employment_income", "state")].expected_source == "IRS_SOI" + ) + assert rows_by_cell[("self_employment_income", "state")].expected_source_table == ( + "IRS SOI Publication 1304" + ) + + +def test_arch_target_gap_queue_marks_multi_domain_rows_for_review(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + report = summarize_arch_target_gap_queue( + provider, + period=2024, + profile_name="custom", + target_cells=( + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable="adjusted_gross_income,medical_expense_deduction", + ), + ), + ) + + row = 
report.rows[0] + assert row.expected_source == "IRS_SOI" + assert row.expected_arch_variable is None + assert row.loader_status == "needs_source_mapping_review" + assert row.gap_category == "source_mapping_review" + assert row.agent_task_kind == "review_source_mapping" + assert "multi-domain cells" in row.notes + + +def test_arch_target_gap_queue_points_eitc_child_rows_to_soi_table_2(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + report = summarize_arch_target_gap_queue( + provider, + period=2024, + profile_name="custom", + target_cells=( + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable="eitc_child_count", + ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable="adjusted_gross_income,eitc,eitc_child_count", + ), + ), + ) + + assert {row.expected_arch_variable for row in report.rows} == {"eitc_claims"} + assert {row.expected_source_table for row in report.rows} == { + "IRS SOI Historic Table 2" + } + assert {row.expected_target_type for row in report.rows} == {"COUNT"} + assert {row.loader_status for row in report.rows} == {"missing_arch_target_record"} + assert {row.gap_category for row in report.rows} == {"ready_primary_loader"} + + +def test_arch_target_gap_queue_points_aca_ptc_counts_to_soi_table_2(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + report = summarize_arch_target_gap_queue( + provider, + period=2024, + profile_name="custom", + target_cells=( + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable="used_aca_ptc", + ), + ), + ) + + row = report.rows[0] + assert row.expected_source == "IRS_SOI" + assert row.expected_source_table == "IRS SOI Historic Table 2" + assert row.expected_arch_variable == "aca_ptc_returns" + assert row.expected_target_type == "COUNT" + 
assert row.expected_entity == "tax_unit" + assert row.loader_status == "missing_arch_target_record" + assert row.gap_category == "ready_primary_loader" + + +def test_arch_target_gap_queue_points_income_tax_return_counts_to_soi_table_2( + tmp_path, +): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + report = summarize_arch_target_gap_queue( + provider, + period=2024, + profile_name="custom", + target_cells=( + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable=( + "adjusted_gross_income,income_tax_before_credits" + ), + ), + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable=( + "adjusted_gross_income,filing_status," + "income_tax_before_credits" + ), + ), + ), + ) + + assert {row.expected_source for row in report.rows} == {"IRS_SOI"} + assert {row.expected_source_table for row in report.rows} == { + "IRS SOI Historic Table 2" + } + assert {row.expected_arch_variable for row in report.rows} == { + "income_tax_before_credits_returns" + } + assert {row.expected_target_type for row in report.rows} == {"COUNT"} + assert {row.expected_entity for row in report.rows} == {"tax_unit"} + + +def test_arch_target_gap_queue_points_energy_subsidy_households_to_liheap( + tmp_path, +): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + report = summarize_arch_target_gap_queue( + provider, + period=2024, + profile_name="custom", + target_cells=( + PolicyEngineUSTargetCell( + "household_count", + geo_level="national", + domain_variable="spm_unit_energy_subsidy_reported", + ), + ), + ) + + row = report.rows[0] + assert row.expected_source == "HHS_ACF_LIHEAP" + assert row.expected_source_table == "HHS ACF LIHEAP National Profile" + assert row.expected_arch_variable == "liheap_household_count" + assert row.expected_target_type == "COUNT" + assert 
row.expected_entity == "household" + assert row.loader_status == "missing_arch_target_record" + assert row.gap_category == "ready_primary_loader" + + +def test_arch_target_gap_queue_points_retirement_contributions_to_soi( + tmp_path, +): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + report = summarize_arch_target_gap_queue( + provider, + period=2024, + profile_name="custom", + target_cells=( + PolicyEngineUSTargetCell( + "traditional_401k_contributions", + geo_level="national", + ), + PolicyEngineUSTargetCell("roth_401k_contributions", geo_level="national"), + PolicyEngineUSTargetCell( + "self_employed_pension_contribution_ald", + geo_level="national", + ), + ), + ) + + rows_by_variable = {row.variable: row for row in report.rows} + traditional = rows_by_variable["traditional_401k_contributions"] + roth = rows_by_variable["roth_401k_contributions"] + self_employed = rows_by_variable["self_employed_pension_contribution_ald"] + + assert {row.expected_source for row in report.rows} == {"IRS_SOI"} + assert traditional.expected_source_table == "IRS SOI Form W-2 Statistics Table 4.B" + assert traditional.expected_arch_variable == "traditional_401k_contributions" + assert traditional.expected_entity == "person" + assert roth.expected_source_table == "IRS SOI Form W-2 Statistics Table 4.B" + assert roth.expected_arch_variable == "roth_401k_contributions" + assert roth.expected_entity == "person" + assert self_employed.expected_source_table == ( + "IRS SOI Publication 1304 Table 1.4" + ) + assert self_employed.expected_arch_variable == ( + "self_employed_pension_contribution_ald" + ) + assert self_employed.expected_entity == "tax_unit" + assert {row.gap_category for row in report.rows} == {"ready_primary_loader"} + + +def test_arch_target_gap_queue_points_agi_person_counts_to_soi_table_2(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + provider = 
ArchSQLiteTargetProvider(db_path) + report = summarize_arch_target_gap_queue( + provider, + period=2024, + profile_name="custom", + target_cells=( + PolicyEngineUSTargetCell( + "person_count", + geo_level="state", + domain_variable="adjusted_gross_income", + ), + ), + ) + + row = report.rows[0] + assert row.expected_source == "IRS_SOI" + assert row.expected_source_table == "IRS SOI Historic Table 2" + assert row.expected_arch_variable == "tax_filer_individual_count" + assert row.expected_target_type == "COUNT" + assert row.expected_entity == "person" + assert row.loader_status == "missing_arch_target_record" + assert row.gap_category == "ready_primary_loader" + + +def test_arch_target_gap_queue_points_state_income_tax_to_census_stc(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + report = summarize_arch_target_gap_queue( + provider, + period=2024, + profile_name="custom", + target_cells=(PolicyEngineUSTargetCell("state_income_tax", geo_level="state"),), + ) + + row = report.rows[0] + assert row.expected_source == "CENSUS_STC" + assert row.expected_source_table == "Census State Tax Collections item T40" + assert row.expected_arch_variable == "state_individual_income_tax_collections" + assert row.expected_target_type == "AMOUNT" + assert row.expected_entity == "tax_unit" + assert row.loader_status == "missing_arch_target_record" + assert row.gap_category == "ready_primary_loader" + + +def test_arch_target_gap_queue_points_itemized_deductions_to_soi_table_2(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + report = summarize_arch_target_gap_queue( + provider, + period=2024, + profile_name="custom", + target_cells=( + PolicyEngineUSTargetCell("salt_deduction", geo_level="national"), + PolicyEngineUSTargetCell("interest_deduction", geo_level="national"), + PolicyEngineUSTargetCell( + 
"tax_unit_count", + geo_level="national", + domain_variable="salt,tax_unit_itemizes", + ), + ), + ) + + rows_by_variable = {row.variable: row for row in report.rows} + rows_by_cell = {(row.variable, row.domain_variable): row for row in report.rows} + salt_row = rows_by_variable["salt_deduction"] + assert salt_row.expected_source == "IRS_SOI" + assert salt_row.expected_source_table == "IRS SOI Historic Table 2" + assert salt_row.expected_arch_variable == "limited_state_local_taxes_amount" + assert salt_row.expected_target_type == "AMOUNT" + assert salt_row.expected_entity == "tax_unit" + assert salt_row.loader_status == "missing_arch_target_record" + assert salt_row.gap_category == "ready_primary_loader" + + interest_row = rows_by_variable["interest_deduction"] + assert interest_row.expected_source == "IRS_SOI" + assert interest_row.expected_source_table == "IRS SOI Historic Table 2" + assert interest_row.expected_arch_variable == "interest_paid_deduction_amount" + assert interest_row.expected_target_type == "AMOUNT" + assert interest_row.expected_entity == "tax_unit" + assert interest_row.loader_status == "missing_arch_target_record" + assert interest_row.gap_category == "ready_primary_loader" + + salt_count_row = rows_by_cell[("tax_unit_count", "salt,tax_unit_itemizes")] + assert salt_count_row.expected_source == "IRS_SOI" + assert salt_count_row.expected_source_table == ( + "IRS SOI itemized deduction or credit tables" + ) + assert salt_count_row.expected_arch_variable == "salt_claims" + assert salt_count_row.expected_target_type == "COUNT" + assert salt_count_row.expected_entity == "tax_unit" + assert salt_count_row.loader_status == "missing_arch_target_record" + assert salt_count_row.gap_category == "ready_primary_loader" + + +def test_arch_target_gap_queue_points_income_tax_positive_to_soi_liability(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + report = 
summarize_arch_target_gap_queue( + provider, + period=2024, + profile_name="custom", + target_cells=( + PolicyEngineUSTargetCell("income_tax_positive", geo_level="national"), + ), + ) + + row = report.rows[0] + assert row.expected_source == "IRS_SOI" + assert row.expected_source_table == ( + "IRS SOI Publication 1304 Table 1.1 or Historic Table 2" + ) + assert row.expected_arch_variable == "income_tax_liability" + assert row.expected_target_type == "AMOUNT" + assert row.expected_entity == "tax_unit" + assert row.loader_status == "missing_arch_target_record" + assert row.gap_category == "ready_primary_loader" + + +def test_arch_target_gap_queue_deprioritizes_survey_or_model_inputs(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + report = summarize_arch_target_gap_queue( + provider, + period=2024, + profile_name="custom", + target_cells=( + PolicyEngineUSTargetCell("rent", geo_level="national"), + PolicyEngineUSTargetCell( + "person_count", + geo_level="national", + domain_variable="ssn_card_type", + ), + ), + ) + + assert report.by_gap_category == {"survey_or_model_input_deprioritized": 2} + assert {row.gap_category for row in report.rows} == { + "survey_or_model_input_deprioritized" + } + assert {row.agent_task_kind for row in report.rows} == { + "defer_or_review_non_primary_source" + } + assert all( + "survey/model-input proxy deprioritized" in row.notes for row in report.rows + ) + + +def test_arch_target_gap_queue_classifies_loaded_wrong_geography(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + report = summarize_arch_target_gap_queue( + provider, + period=2024, + profile_name="custom", + target_cells=( + PolicyEngineUSTargetCell( + "tax_unit_count", + geo_level="national", + domain_variable="adjusted_gross_income", + ), + ), + ) + + row = report.rows[0] + assert row.expected_source == 
"IRS_SOI" + assert row.expected_arch_variable == "tax_unit_count" + assert row.loader_status == "loaded_arch_variable_missing_geography" + assert row.gap_category == "ready_rollup_or_geography" + assert row.agent_task_kind == "add_arch_rollup_or_geography_records" + + +def test_arch_target_gap_queue_can_include_covered_rows(tmp_path): + db_path = tmp_path / "arch_targets.db" + _create_arch_targets_db(db_path) + + provider = ArchSQLiteTargetProvider(db_path) + report = summarize_arch_target_gap_queue( + provider, + period=2024, + profile_name="custom", + include_covered=True, + target_cells=( + PolicyEngineUSTargetCell( + "adjusted_gross_income", + geo_level="national", + domain_variable=None, + ), + ), + ) + + assert report.row_count == 1 + row = report.rows[0] + assert row.covered is True + assert row.target_ids == (4,) + assert row.expected_filters == () + assert row.loader_status == "covered" + assert row.gap_category == "covered" + assert row.agent_task_kind == "none" + + +def test_arch_target_gap_queue_cli_writes_csv(tmp_path): + db_path = tmp_path / "arch_targets.db" + output_path = tmp_path / "gaps.csv" + _create_arch_targets_db(db_path) + + exit_code = main_gaps( + [ + "--arch-targets-db", + str(db_path), + "--period", + "2024", + "--profile", + "pe_native_broad", + "--format", + "csv", + "--output", + str(output_path), + ] + ) + + assert exit_code == 0 + text = output_path.read_text() + assert text.startswith("priority,profile_name,period,variable") + assert "gap_category" in text + assert "employment_income" in text + assert "missing_arch_target_record" in text + + +def test_arch_target_refresh_cli_discovers_artifact_and_writes_snapshot(tmp_path): + artifact_root = tmp_path / "artifacts" + artifact_root.mkdir() + db_path = artifact_root / "arch_targets_fixture.db" + output_dir = tmp_path / "snapshot" + _create_arch_targets_db(db_path) + + exit_code = main_refresh( + [ + "--artifact-root", + str(artifact_root), + "--period", + "2024", + "--profile", + 
"pe_native_broad", + "--output-dir", + str(output_dir), + ] + ) + + assert exit_code == 0 + + coverage_path = output_dir / "pe_native_broad_2024_coverage.json" + gaps_json_path = output_dir / "pe_native_broad_2024_gaps.json" + gaps_csv_path = output_dir / "pe_native_broad_2024_gaps.csv" + summary_path = output_dir / "pe_native_broad_2024_summary.md" + + coverage = json.loads(coverage_path.read_text()) + gaps = json.loads(gaps_json_path.read_text()) + gaps_csv = gaps_csv_path.read_text() + summary = summary_path.read_text() + + assert coverage["target_cell_count"] == 189 + assert coverage["covered_cell_count"] == 4 + assert gaps["uncovered_row_count"] == 185 + assert gaps_csv.startswith("priority,profile_name,period,variable") + assert "Coverage rate" in summary + assert str(db_path.resolve()) in summary diff --git a/tests/targets/test_arch_facts.py b/tests/targets/test_arch_facts.py new file mode 100644 index 0000000..5792d51 --- /dev/null +++ b/tests/targets/test_arch_facts.py @@ -0,0 +1,2969 @@ +from __future__ import annotations + +import json +import sqlite3 +from pathlib import Path +from typing import Any + +import pytest +from microplex.targets import TargetQuery + +from microplex_us.pipelines.us import USMicroplexBuildConfig, USMicroplexPipeline +from microplex_us.targets import ( + ArchCompositeSQLiteTargetProvider, + ArchConsumerFactJSONLTargetProvider, + ArchFactSQLiteTargetProvider, + ArchSQLiteTargetProvider, + resolve_arch_sqlite_target_provider, + summarize_arch_target_gap_queue, + summarize_arch_target_profile_coverage, +) +from microplex_us.targets.arch import main_parity, main_smoke + + +def _create_value_constraint_target_db(path: Path) -> None: + conn = sqlite3.connect(path) + conn.executescript( + """ + CREATE TABLE strata ( + id INTEGER PRIMARY KEY, + name TEXT, + jurisdiction TEXT, + definition_hash TEXT + ); + + CREATE TABLE stratum_constraints ( + id INTEGER PRIMARY KEY, + stratum_id INTEGER NOT NULL, + variable TEXT NOT NULL, + operator 
TEXT NOT NULL, + value TEXT NOT NULL + ); + + CREATE TABLE targets ( + id INTEGER PRIMARY KEY, + stratum_id INTEGER NOT NULL, + variable TEXT NOT NULL, + period INTEGER NOT NULL, + value REAL NOT NULL, + target_type TEXT NOT NULL, + geographic_level TEXT, + source TEXT NOT NULL, + source_table TEXT, + source_url TEXT, + notes TEXT + ); + """ + ) + conn.executemany( + """ + INSERT INTO strata (id, name, jurisdiction, definition_hash) + VALUES (?, ?, ?, ?) + """, + [ + (1, "US All Filers", "US", "all"), + (2, "US Filers AGI 1_to_5k", "US", "1_to_5k"), + ], + ) + conn.executemany( + """ + INSERT INTO stratum_constraints (stratum_id, variable, operator, value) + VALUES (?, ?, ?, ?) + """, + [ + (1, "is_tax_filer", "==", "1"), + (2, "is_tax_filer", "==", "1"), + (2, "adjusted_gross_income", ">=", "1"), + (2, "adjusted_gross_income", "<", "5000"), + ], + ) + conn.executemany( + """ + INSERT INTO targets ( + id, + stratum_id, + variable, + period, + value, + target_type, + geographic_level, + source, + source_table, + source_url, + notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + [ + ( + 1, + 1, + "tax_unit_count", + 2023, + 160_602_107, + "COUNT", + "NATIONAL", + "IRS_SOI", + "Publication 1304 Table 1.1", + "https://www.irs.gov/pub/irs-soi/23in11si.xls", + None, + ), + ( + 2, + 1, + "adjusted_gross_income", + 2023, + 15_286_017_359_000, + "AMOUNT", + "NATIONAL", + "IRS_SOI", + "Publication 1304 Table 1.1", + "https://www.irs.gov/pub/irs-soi/23in11si.xls", + None, + ), + ( + 3, + 1, + "income_tax_liability", + 2023, + 2_147_909_818_000, + "AMOUNT", + "NATIONAL", + "IRS_SOI", + "Publication 1304 Table 1.1", + "https://www.irs.gov/pub/irs-soi/23in11si.xls", + None, + ), + ( + 4, + 2, + "tax_unit_count", + 2023, + 7_357_751, + "COUNT", + "NATIONAL", + "IRS_SOI", + "Publication 1304 Table 1.1", + "https://www.irs.gov/pub/irs-soi/23in11si.xls", + None, + ), + ( + 5, + 2, + "adjusted_gross_income", + 2023, + 20_372_694_000, + "AMOUNT", + "NATIONAL", + "IRS_SOI", + "Publication 1304 Table 1.1", + "https://www.irs.gov/pub/irs-soi/23in11si.xls", + None, + ), + ], + ) + conn.commit() + conn.close() + + +def _create_arch_fact_db(path: Path) -> None: + conn = sqlite3.connect(path) + conn.executescript( + """ + CREATE TABLE aggregate_facts ( + fact_key TEXT PRIMARY KEY, + source_record_id TEXT, + value_numeric REAL, + value_text TEXT, + value_json TEXT NOT NULL, + period_value TEXT NOT NULL, + geography_level TEXT NOT NULL, + geography_id TEXT NOT NULL, + geography_name TEXT, + measure_concept TEXT NOT NULL, + measure_source_concept TEXT, + measure_concept_relation TEXT, + measure_concept_authority TEXT, + measure_concept_evidence_url TEXT, + measure_concept_evidence_notes TEXT, + measure_legal_vintage TEXT, + measure_unit TEXT NOT NULL, + aggregation_method TEXT NOT NULL, + domain TEXT NOT NULL, + filters_json TEXT NOT NULL, + label TEXT, + source_name TEXT, + source_table TEXT, + source_url TEXT, + source_method_notes TEXT + ); + + CREATE TABLE aggregate_constraints ( + fact_key TEXT NOT NULL, + ordinal INTEGER NOT NULL, + variable TEXT NOT 
NULL, + operator TEXT NOT NULL, + value_text TEXT, + value_numeric REAL, + value_json TEXT NOT NULL, + unit TEXT, + role TEXT NOT NULL, + label TEXT, + PRIMARY KEY (fact_key, ordinal) + ); + + CREATE TABLE fact_source_cells ( + fact_key TEXT NOT NULL, + source_cell_key TEXT NOT NULL, + ordinal INTEGER NOT NULL, + PRIMARY KEY (fact_key, source_cell_key) + ); + + CREATE TABLE fact_source_rows ( + fact_key TEXT NOT NULL, + source_row_key TEXT NOT NULL, + ordinal INTEGER NOT NULL, + PRIMARY KEY (fact_key, source_row_key) + ); + """ + ) + + def fact( + key: str, + *, + concept: str, + value: float, + aggregation: str, + income_range: str, + unit: str, + source_concept: str | None = None, + ) -> tuple[Any, ...]: + return ( + key, + f"irs_soi.ty2023.table_1_1.{income_range}.{concept.rsplit('.', 1)[-1]}", + value, + str(int(value)) if float(value).is_integer() else str(value), + json.dumps(value), + "2023", + "country", + "0100000US", + "United States", + concept, + source_concept, + "exact" if source_concept else None, + "arch-us" if source_concept else None, + "https://uscode.house.gov/view.xhtml?req=(title:26%20section:62%20edition:prelim)" + if source_concept + else None, + "IRS SOI Table 1.1 reports adjusted gross income.", + "tax_year_2023" if source_concept else None, + unit, + aggregation, + "all_individual_income_tax_returns", + json.dumps({"filing_status": "all", "income_range": income_range}), + f"{income_range} {concept}", + "irs_soi", + "Publication 1304 Table 1.1", + "https://www.irs.gov/pub/irs-soi/23in11si.xls", + "Source-package aggregate fact fixture.", + ) + + conn.executemany( + """ + INSERT INTO aggregate_facts ( + fact_key, + source_record_id, + value_numeric, + value_text, + value_json, + period_value, + geography_level, + geography_id, + geography_name, + measure_concept, + measure_source_concept, + measure_concept_relation, + measure_concept_authority, + measure_concept_evidence_url, + measure_concept_evidence_notes, + measure_legal_vintage, + 
measure_unit, + aggregation_method, + domain, + filters_json, + label, + source_name, + source_table, + source_url, + source_method_notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + [ + fact( + "arch.fact.v1:all-count", + concept="irs_soi.individual_income_tax_returns", + value=160_602_107, + aggregation="count", + income_range="all", + unit="count", + ), + fact( + "arch.fact.v1:all-agi", + concept="us:statutes/26/62#adjusted_gross_income", + source_concept="irs_soi.adjusted_gross_income", + value=15_286_017_359_000, + aggregation="sum", + income_range="all", + unit="usd", + ), + fact( + "arch.fact.v1:all-tax", + concept="irs_soi.total_income_tax", + value=2_147_909_818_000, + aggregation="sum", + income_range="all", + unit="usd", + ), + fact( + "arch.fact.v1:1-to-5k-count", + concept="irs_soi.individual_income_tax_returns", + value=7_357_751, + aggregation="count", + income_range="1_to_5k", + unit="count", + ), + fact( + "arch.fact.v1:1-to-5k-agi", + concept="us:statutes/26/62#adjusted_gross_income", + source_concept="irs_soi.adjusted_gross_income", + value=20_372_694_000, + aggregation="sum", + income_range="1_to_5k", + unit="usd", + ), + ], + ) + conn.executemany( + """ + INSERT INTO aggregate_constraints ( + fact_key, + ordinal, + variable, + operator, + value_text, + value_numeric, + value_json, + unit, + role, + label + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + [ + ( + key, + ordinal, + "us:statutes/26/62#adjusted_gross_income", + operator, + str(value), + float(value), + json.dumps(value), + "usd", + "filter", + "Adjusted gross income bound", + ) + for key in ("arch.fact.v1:1-to-5k-count", "arch.fact.v1:1-to-5k-agi") + for ordinal, operator, value in ((0, ">=", 1), (1, "<", 5000)) + ], + ) + conn.executemany( + """ + INSERT INTO fact_source_cells (fact_key, source_cell_key, ordinal) + VALUES (?, ?, ?) 
+ """, + [ + ("arch.fact.v1:all-agi", "arch.source_cell.v1:agi", 0), + ("arch.fact.v1:all-count", "arch.source_cell.v1:count", 0), + ], + ) + conn.execute( + """ + INSERT INTO fact_source_rows (fact_key, source_row_key, ordinal) + VALUES (?, ?, ?) + """, + ("arch.fact.v1:all-agi", "arch.source_row.v1:all", 0), + ) + conn.commit() + conn.close() + + +def _insert_arch_table_1_1_reference_totals( + path: Path, + *, + year: int, + return_count: float, + adjusted_gross_income: float, +) -> None: + conn = sqlite3.connect(path) + + def fact( + key: str, + *, + concept: str, + value: float, + aggregation: str, + unit: str, + source_concept: str | None = None, + ) -> tuple[Any, ...]: + return ( + key, + f"irs_soi.ty{year}.table_1_1.all.{concept.rsplit('.', 1)[-1]}", + value, + str(int(value)) if float(value).is_integer() else str(value), + json.dumps(value), + str(year), + "country", + "0100000US", + "United States", + concept, + source_concept, + "exact" if source_concept else None, + "arch-us" if source_concept else None, + "https://uscode.house.gov/view.xhtml?req=(title:26%20section:62%20edition:prelim)" + if source_concept + else None, + "IRS SOI Table 1.1 reports adjusted gross income.", + f"tax_year_{year}" if source_concept else None, + unit, + aggregation, + "all_individual_income_tax_returns", + json.dumps({"filing_status": "all", "income_range": "all"}), + f"{year} all {concept}", + "irs_soi", + "Publication 1304 Table 1.1", + f"https://www.irs.gov/pub/irs-soi/{str(year)[-2:]}in11si.xls", + "Source-package aggregate fact aging reference fixture.", + ) + + conn.executemany( + """ + INSERT INTO aggregate_facts ( + fact_key, + source_record_id, + value_numeric, + value_text, + value_json, + period_value, + geography_level, + geography_id, + geography_name, + measure_concept, + measure_source_concept, + measure_concept_relation, + measure_concept_authority, + measure_concept_evidence_url, + measure_concept_evidence_notes, + measure_legal_vintage, + measure_unit, + 
aggregation_method, + domain, + filters_json, + label, + source_name, + source_table, + source_url, + source_method_notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + [ + fact( + f"arch.fact.v1:{year}-all-count", + concept="irs_soi.individual_income_tax_returns", + value=return_count, + aggregation="count", + unit="count", + ), + fact( + f"arch.fact.v1:{year}-all-agi", + concept="us:statutes/26/62#adjusted_gross_income", + source_concept="irs_soi.adjusted_gross_income", + value=adjusted_gross_income, + aggregation="sum", + unit="usd", + ), + ], + ) + conn.commit() + conn.close() + + +def _insert_arch_table_1_4_facts(path: Path) -> None: + conn = sqlite3.connect(path) + + def fact( + key: str, + *, + concept: str, + value: float, + aggregation: str, + income_range: str, + unit: str, + source_concept: str | None = None, + concept_relation: str | None = None, + ) -> tuple[Any, ...]: + slug = concept.split("#")[-1].rsplit(".", 1)[-1].replace(":", "_") + return ( + key, + f"irs_soi.ty2023.table_1_4.{income_range}.{slug}", + value, + str(int(value)) if float(value).is_integer() else str(value), + json.dumps(value), + "2023", + "country", + "0100000US", + "United States", + concept, + source_concept, + concept_relation, + "arch-us" if concept_relation else None, + "https://www.irs.gov/statistics/soi-tax-stats-individual-income-tax-returns-complete-report-publication-1304-basic-tables-part-1" + if concept_relation + else None, + "SOI Table 1.4 source concept alignment fixture." 
+ if concept_relation + else None, + "tax_year_2023" if concept_relation else None, + unit, + aggregation, + "all_individual_income_tax_returns", + json.dumps({"filing_status": "all", "income_range": income_range}), + f"{income_range} {concept}", + "irs_soi", + "Publication 1304 Table 1.4", + "https://www.irs.gov/pub/irs-soi/23in14ar.xls", + "Source-package aggregate fact fixture.", + ) + + conn.executemany( + """ + INSERT INTO aggregate_facts ( + fact_key, + source_record_id, + value_numeric, + value_text, + value_json, + period_value, + geography_level, + geography_id, + geography_name, + measure_concept, + measure_source_concept, + measure_concept_relation, + measure_concept_authority, + measure_concept_evidence_url, + measure_concept_evidence_notes, + measure_legal_vintage, + measure_unit, + aggregation_method, + domain, + filters_json, + label, + source_name, + source_table, + source_url, + source_method_notes + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + [ + fact( + "arch.fact.v1:t14-all-wages-returns", + concept="irs_soi.returns_with_total_wages", + value=132_000_000, + aggregation="count", + income_range="all", + unit="count", + ), + fact( + "arch.fact.v1:t14-all-wages-amount", + concept="us:statutes/26/62#input.wages", + source_concept="irs_soi.total_wages", + concept_relation="broad_match", + value=10_500_000_000_000, + aggregation="sum", + income_range="all", + unit="usd", + ), + fact( + "arch.fact.v1:t14-all-capital-gains-returns", + concept="irs_soi.returns_with_taxable_net_capital_gains", + value=27_000_000, + aggregation="count", + income_range="all", + unit="count", + ), + fact( + "arch.fact.v1:t14-all-capital-gains-amount", + concept="irs_soi.taxable_net_capital_gains", + value=1_100_000_000_000, + aggregation="sum", + income_range="all", + unit="usd", + ), + fact( + "arch.fact.v1:t14-all-ira-returns", + concept="irs_soi.returns_with_taxable_ira_distributions", + value=18_000_000, + aggregation="count", + income_range="all", + unit="count", + ), + fact( + "arch.fact.v1:t14-all-ira-amount", + concept="irs_soi.taxable_ira_distributions", + value=420_000_000_000, + aggregation="sum", + income_range="all", + unit="usd", + ), + fact( + "arch.fact.v1:t14-all-pension-returns", + concept="irs_soi.returns_with_taxable_pension_income", + value=30_000_000, + aggregation="count", + income_range="all", + unit="count", + ), + fact( + "arch.fact.v1:t14-all-pension-amount", + concept="irs_soi.taxable_pension_income", + value=740_000_000_000, + aggregation="sum", + income_range="all", + unit="usd", + ), + fact( + "arch.fact.v1:t14-all-uc-returns", + concept="irs_soi.returns_with_unemployment_compensation", + value=7_000_000, + aggregation="count", + income_range="all", + unit="count", + ), + fact( + "arch.fact.v1:t14-all-uc-amount", + concept="irs_soi.unemployment_compensation", + value=62_000_000_000, + aggregation="sum", + income_range="all", + unit="usd", + ), + fact( + 
"arch.fact.v1:t14-all-taxable-ss-returns", + concept="irs_soi.returns_with_taxable_social_security_benefits", + value=29_000_000, + aggregation="count", + income_range="all", + unit="count", + ), + fact( + "arch.fact.v1:t14-all-taxable-ss-amount", + concept="irs_soi.taxable_social_security_benefits", + value=510_000_000_000, + aggregation="sum", + income_range="all", + unit="usd", + ), + fact( + "arch.fact.v1:t14-1-to-5k-wages-amount", + concept="us:statutes/26/62#input.wages", + source_concept="irs_soi.total_wages", + concept_relation="broad_match", + value=4_200_000_000, + aggregation="sum", + income_range="1_to_5k", + unit="usd", + ), + ], + ) + conn.executemany( + """ + INSERT INTO aggregate_constraints ( + fact_key, + ordinal, + variable, + operator, + value_text, + value_numeric, + value_json, + unit, + role, + label + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + [ + ( + "arch.fact.v1:t14-1-to-5k-wages-amount", + ordinal, + "us:statutes/26/62#adjusted_gross_income", + operator, + str(value), + float(value), + json.dumps(value), + "usd", + "filter", + "Adjusted gross income bound", + ) + for ordinal, operator, value in ((0, ">=", 1), (1, "<", 5000)) + ], + ) + conn.executemany( + """ + INSERT INTO fact_source_cells (fact_key, source_cell_key, ordinal) + VALUES (?, ?, ?) + """, + [ + ( + "arch.fact.v1:t14-all-wages-amount", + "arch.source_cell.v1:t14-wages-amount", + 0, + ), + ( + "arch.fact.v1:t14-all-wages-returns", + "arch.source_cell.v1:t14-wages-returns", + 0, + ), + ], + ) + conn.execute( + """ + INSERT INTO fact_source_rows (fact_key, source_row_key, ordinal) + VALUES (?, ?, ?) 
+ """, + ( + "arch.fact.v1:t14-all-wages-amount", + "arch.source_row.v1:t14-all", + 0, + ), + ) + conn.commit() + conn.close() + + +def _write_consumer_fact_jsonl(path: Path) -> None: + def row( + key: str, + *, + semantic_key: str, + concept: str, + value: float, + aggregation: str, + income_range: str, + unit: str, + source_concept: str | None = None, + ) -> dict[str, Any]: + observed_concept = source_concept or concept + constraints = [] + if income_range == "1_to_5k": + constraints = [ + { + "variable": "us:statutes/26/62#adjusted_gross_income", + "operator": ">=", + "value": 1, + "unit": "usd", + "role": "filter", + }, + { + "variable": "us:statutes/26/62#adjusted_gross_income", + "operator": "<", + "value": 5000, + "unit": "usd", + "role": "filter", + }, + ] + payload: dict[str, Any] = { + "schema_version": "arch.consumer_fact.v1", + "aggregate_fact_key": key, + "semantic_fact_key": semantic_key, + "legacy_fact_key": key.replace("aggregate_fact.v2", "fact.v1"), + "value": value, + "value_type": "integer", + "period": {"type": "tax_year", "value": 2023}, + "geography": { + "id": "0100000US", + "level": "country", + "vintage": "2020_census", + }, + "entity": {"name": "tax_unit", "role": "filing_unit"}, + "aggregation": {"method": aggregation}, + "observed_measure": { + "source_concept": observed_concept, + "source_measure_id": observed_concept.rsplit(".", 1)[-1], + "source_name": "irs_soi", + "source_table": "Publication 1304 Table 1.1", + "unit": unit, + }, + "dimensions": {"filing_status": "all", "income_range": income_range}, + "universe_constraints": { + "domain": "all_individual_income_tax_returns", + "constraints": constraints, + }, + "source": { + "source_name": "irs_soi", + "source_table": "Publication 1304 Table 1.1", + "url": "https://www.irs.gov/pub/irs-soi/23in11si.xls", + "method_notes": "Consumer-contract fact fixture.", + }, + "lineage": { + "source_record_id": ( + f"irs_soi.ty2023.table_1_1.{income_range}." 
+ f"{observed_concept.rsplit('.', 1)[-1]}" + ), + "source_cell_keys": [f"arch.source_cell.v1:{key.rsplit(':', 1)[-1]}"], + "source_row_keys": [f"arch.source_row.v1:{income_range}"], + }, + "label": f"{income_range} {concept}", + } + if source_concept is not None: + payload["concept_alignment"] = { + "concept_alignment_key": "arch.concept_alignment.v2:agi", + "source_concept": source_concept, + "canonical_concept": concept, + "relation": "exact", + "authority": "arch-us", + "evidence_url": ( + "https://uscode.house.gov/view.xhtml?" + "req=(title:26%20section:62%20edition:prelim)" + ), + "evidence_notes": "IRS SOI Table 1.1 reports adjusted gross income.", + "legal_vintage": "tax_year_2023", + } + return payload + + rows = [ + row( + "arch.aggregate_fact.v2:all-count", + semantic_key="arch.semantic_fact.v2:all-count", + concept="irs_soi.individual_income_tax_returns", + value=160_602_107, + aggregation="count", + income_range="all", + unit="count", + ), + row( + "arch.aggregate_fact.v2:all-agi", + semantic_key="arch.semantic_fact.v2:all-agi", + concept="us:statutes/26/62#adjusted_gross_income", + source_concept="irs_soi.adjusted_gross_income", + value=15_286_017_359_000, + aggregation="sum", + income_range="all", + unit="usd", + ), + row( + "arch.aggregate_fact.v2:all-tax", + semantic_key="arch.semantic_fact.v2:all-tax", + concept="irs_soi.total_income_tax", + value=2_147_909_818_000, + aggregation="sum", + income_range="all", + unit="usd", + ), + row( + "arch.aggregate_fact.v2:1-to-5k-count", + semantic_key="arch.semantic_fact.v2:1-to-5k-count", + concept="irs_soi.individual_income_tax_returns", + value=7_357_751, + aggregation="count", + income_range="1_to_5k", + unit="count", + ), + row( + "arch.aggregate_fact.v2:1-to-5k-agi", + semantic_key="arch.semantic_fact.v2:1-to-5k-agi", + concept="us:statutes/26/62#adjusted_gross_income", + source_concept="irs_soi.adjusted_gross_income", + value=20_372_694_000, + aggregation="sum", + income_range="1_to_5k", + unit="usd", + 
), + ] + path.write_text("\n".join(json.dumps(item, sort_keys=True) for item in rows) + "\n") + + +def _consumer_fact( + key: str, + *, + concept: str, + domain: str, + source_name: str, + source_table: str, + value: float, + period: dict[str, Any] | None = None, + geography: dict[str, Any] | None = None, + constraints: tuple[dict[str, Any], ...] = (), + unit: str = "count", +) -> dict[str, Any]: + return { + "schema_version": "arch.consumer_fact.v1", + "aggregate_fact_key": f"arch.aggregate_fact.v2:{key}", + "semantic_fact_key": f"arch.semantic_fact.v2:{key}", + "value": value, + "period": period or {"type": "calendar_year", "value": 2024}, + "geography": geography + or {"level": "country", "id": "0100000US", "name": "United States"}, + "observed_measure": { + "source_concept": concept, + "source_measure_id": concept.rsplit(".", 1)[-1], + "source_name": source_name, + "source_table": source_table, + "unit": unit, + }, + "universe_constraints": { + "domain": domain, + "constraints": list(constraints), + }, + "source": { + "source_name": source_name, + "source_table": source_table, + "url": f"https://example.test/{key}", + "method_notes": "US admin source-family fixture.", + }, + "lineage": { + "source_record_id": f"{source_name}.{key}", + "source_cell_keys": [f"arch.source_cell.v1:{key}"], + "source_row_keys": [f"arch.source_row.v1:{key}"], + }, + "label": key, + } + + +def _target_filter_tuples(target: Any) -> set[tuple[str, str, str]]: + return { + ( + str(target_filter.feature), + str(getattr(target_filter.operator, "value", target_filter.operator)), + str(target_filter.value), + ) + for target_filter in target.filters + } + + +def _normalize_target_behavior(target_set) -> list[tuple[Any, ...]]: + rows = [] + for target in target_set.targets: + filters = tuple( + sorted( + ( + str(target_filter.feature), + str( + getattr(target_filter.operator, "value", target_filter.operator) + ), + str(target_filter.value), + ) + for target_filter in target.filters + ) + ) + 
rows.append( + ( + str(target.entity.value), + str(getattr(target.aggregation, "value", target.aggregation)), + target.measure, + round(float(target.value), 6), + int(target.period), + str(target.source), + target.metadata["variable"], + target.metadata["geo_level"], + filters, + ) + ) + return sorted(rows) + + +def test_arch_fact_provider_matches_value_constraint_soi_targets( + tmp_path: Path, +) -> None: + value_db = tmp_path / "value_targets.db" + fact_db = tmp_path / "arch_facts.db" + _create_value_constraint_target_db(value_db) + _create_arch_fact_db(fact_db) + + query = TargetQuery(period=2023) + value_targets = ArchSQLiteTargetProvider(value_db).load_target_set(query) + fact_targets = ArchFactSQLiteTargetProvider(fact_db).load_target_set(query) + + assert _normalize_target_behavior(fact_targets) == _normalize_target_behavior( + value_targets + ) + + +def test_arch_consumer_fact_jsonl_provider_matches_value_constraint_soi_targets( + tmp_path: Path, +) -> None: + value_db = tmp_path / "value_targets.db" + consumer_jsonl = tmp_path / "consumer_facts.jsonl" + _create_value_constraint_target_db(value_db) + _write_consumer_fact_jsonl(consumer_jsonl) + + query = TargetQuery(period=2023) + value_targets = ArchSQLiteTargetProvider(value_db).load_target_set(query) + consumer_targets = ArchConsumerFactJSONLTargetProvider( + consumer_jsonl + ).load_target_set(query) + + assert _normalize_target_behavior(consumer_targets) == _normalize_target_behavior( + value_targets + ) + + +def test_arch_fact_provider_preserves_fact_provenance(tmp_path: Path) -> None: + fact_db = tmp_path / "arch_facts.db" + _create_arch_fact_db(fact_db) + + target_set = ArchFactSQLiteTargetProvider(fact_db).load_target_set( + TargetQuery( + period=2023, + provider_filters={ + "target_cells": [ + { + "variable": "adjusted_gross_income", + "geo_level": "national", + "domain_variable": None, + } + ] + }, + ) + ) + + all_agi = next( + target + for target in target_set.targets + if 
target.metadata["arch_aggregate_fact_key"] == "arch.fact.v1:all-agi" + ) + assert all_agi.metadata["arch_semantic_fact_key"].startswith( + "arch.semantic_fact.v1|us:statutes/26/62#adjusted_gross_income" + ) + assert all_agi.metadata["arch_source_record_id"].startswith( + "irs_soi.ty2023.table_1_1.all" + ) + assert all_agi.metadata["arch_source_cell_keys"] == ["arch.source_cell.v1:agi"] + assert all_agi.metadata["arch_source_row_keys"] == ["arch.source_row.v1:all"] + assert all_agi.metadata["arch_source_concept"] == "irs_soi.adjusted_gross_income" + assert all_agi.metadata["arch_concept_relation"] == "exact" + assert all_agi.metadata["unit"] == "usd" + + +def test_arch_consumer_fact_jsonl_provider_preserves_contract_keys( + tmp_path: Path, +) -> None: + consumer_jsonl = tmp_path / "consumer_facts.jsonl" + _write_consumer_fact_jsonl(consumer_jsonl) + + target_set = ArchConsumerFactJSONLTargetProvider(consumer_jsonl).load_target_set( + TargetQuery(period=2023) + ) + + all_agi = next( + target + for target in target_set.targets + if target.metadata["arch_aggregate_fact_key"] + == "arch.aggregate_fact.v2:all-agi" + ) + assert all_agi.metadata["arch_semantic_fact_key"] == "arch.semantic_fact.v2:all-agi" + assert all_agi.metadata["arch_source_record_id"].startswith( + "irs_soi.ty2023.table_1_1.all" + ) + assert all_agi.metadata["arch_source_cell_keys"] == ["arch.source_cell.v1:all-agi"] + assert all_agi.metadata["arch_source_row_keys"] == ["arch.source_row.v1:all"] + assert all_agi.metadata["arch_source_concept"] == "irs_soi.adjusted_gross_income" + assert all_agi.metadata["arch_concept_relation"] == "exact" + assert all_agi.metadata["unit"] == "usd" + + +def test_arch_consumer_fact_jsonl_provider_maps_income_tax_after_credits_returns( + tmp_path: Path, +) -> None: + consumer_jsonl = tmp_path / "consumer_facts.jsonl" + _write_consumer_fact_jsonl(consumer_jsonl) + rows = [json.loads(line) for line in consumer_jsonl.read_text().splitlines()] + row = 
json.loads(json.dumps(rows[0])) + row["aggregate_fact_key"] = "arch.aggregate_fact.v2:all-income-tax-returns" + row["semantic_fact_key"] = "arch.semantic_fact.v2:all-income-tax-returns" + row["legacy_fact_key"] = "arch.fact.v1:all-income-tax-returns" + row["value"] = 111_545_061 + row["observed_measure"] = { + **row["observed_measure"], + "source_concept": "irs_soi.returns_with_income_tax_after_credits", + "source_measure_id": "income_tax_after_credits_returns", + "unit": "count", + } + row["lineage"]["source_record_id"] = ( + "irs_soi.ty2023.table_1_1.all.income_tax_after_credits_returns" + ) + consumer_jsonl.write_text(json.dumps(row, sort_keys=True) + "\n") + + target_set = ArchConsumerFactJSONLTargetProvider(consumer_jsonl).load_target_set( + TargetQuery(period=2023) + ) + target = target_set.targets[0] + filters = { + ( + target_filter.feature, + str(getattr(target_filter.operator, "value", target_filter.operator)), + str(target_filter.value), + ) + for target_filter in target.filters + } + + assert target.metadata["arch_variable"] == "income_tax_liability_returns" + assert target.metadata["variable"] == "tax_unit_count" + assert target.aggregation.value == "count" + assert filters == { + ("income_tax", ">", "0"), + ("tax_unit_is_filer", "==", "1"), + } + + +def test_arch_consumer_fact_jsonl_provider_maps_historic_table_2_concepts( + tmp_path: Path, +) -> None: + consumer_jsonl = tmp_path / "consumer_facts.jsonl" + _write_consumer_fact_jsonl(consumer_jsonl) + template = json.loads(consumer_jsonl.read_text().splitlines()[0]) + rows = [] + for index, (concept, measure_id, value) in enumerate( + ( + ( + "irs_soi.returns_with_premium_tax_credit", + "premium_tax_credit_returns", + 7_841_370, + ), + ("irs_soi.earned_income_credit", "eitc_amount", 59_204_588_000), + ( + "irs_soi.tax_filer_individuals", + "tax_filer_individual_count", + 293_617_150, + ), + ), + start=1, + ): + row = json.loads(json.dumps(template)) + row["aggregate_fact_key"] = 
f"arch.aggregate_fact.v2:historic-table-2-{index}" + row["semantic_fact_key"] = f"arch.semantic_fact.v2:historic-table-2-{index}" + row["legacy_fact_key"] = f"arch.fact.v1:historic-table-2-{index}" + row["period"] = {"type": "tax_year", "value": 2022} + row["value"] = value + row["source"] = {**row["source"], "source_table": "Historic Table 2"} + row["observed_measure"] = { + **row["observed_measure"], + "source_concept": concept, + "source_measure_id": measure_id, + "source_table": "Historic Table 2", + "unit": "usd" if concept == "irs_soi.earned_income_credit" else "count", + } + row["aggregation"] = { + "method": "sum" if concept == "irs_soi.earned_income_credit" else "count" + } + row["lineage"]["source_record_id"] = ( + f"irs_soi.ty2022.historic_table_2.us.all.{measure_id}" + ) + rows.append(row) + consumer_jsonl.write_text( + "\n".join(json.dumps(row, sort_keys=True) for row in rows) + "\n" + ) + + target_set = ArchConsumerFactJSONLTargetProvider(consumer_jsonl).load_target_set( + TargetQuery(period=2022) + ) + targets_by_arch_variable = { + target.metadata["arch_variable"]: target for target in target_set.targets + } + + premium_tax_credit = targets_by_arch_variable["aca_ptc_returns"] + assert premium_tax_credit.metadata["variable"] == "tax_unit_count" + assert premium_tax_credit.aggregation.value == "count" + assert { + ( + target_filter.feature, + str(getattr(target_filter.operator, "value", target_filter.operator)), + str(target_filter.value), + ) + for target_filter in premium_tax_credit.filters + } == { + ("aca_ptc", ">", "0"), + ("tax_unit_is_filer", "==", "1"), + } + + eitc = targets_by_arch_variable["eitc_amount"] + assert eitc.metadata["variable"] == "eitc" + assert eitc.measure == "eitc" + assert eitc.aggregation.value == "sum" + + tax_filer_individuals = targets_by_arch_variable["tax_filer_individual_count"] + assert tax_filer_individuals.metadata["variable"] == "person_count" + assert tax_filer_individuals.aggregation.value == "count" + + +def 
test_arch_consumer_fact_jsonl_provider_maps_state_soi_rows( + tmp_path: Path, +) -> None: + consumer_jsonl = tmp_path / "consumer_facts.jsonl" + rows = [ + _consumer_fact( + "state-ca-agi-50k-75k", + concept="irs_soi.adjusted_gross_income", + domain="all_individual_income_tax_returns", + source_name="irs_soi", + source_table="Historic Table 2 state AGI facts", + period={"type": "tax_year", "value": 2022}, + geography={"level": "state", "id": "0400000US06", "name": "California"}, + value=123_456_000_000, + unit="usd", + constraints=( + { + "variable": "us:statutes/26/62#adjusted_gross_income", + "operator": ">=", + "value": 50_000, + "unit": "usd", + "role": "filter", + }, + { + "variable": "us:statutes/26/62#adjusted_gross_income", + "operator": "<", + "value": 75_000, + "unit": "usd", + "role": "filter", + }, + ), + ), + _consumer_fact( + "state-ca-eitc-amount", + concept="irs_soi.earned_income_credit", + domain="individual_income_tax_returns", + source_name="irs_soi", + source_table="Historic Table 2 state EITC totals", + period={"type": "tax_year", "value": 2022}, + geography={"level": "state", "id": "0400000US06", "name": "California"}, + value=5_770_703_000, + unit="usd", + ), + ] + consumer_jsonl.write_text( + "\n".join(json.dumps(row, sort_keys=True) for row in rows) + "\n" + ) + + target_set = ArchConsumerFactJSONLTargetProvider(consumer_jsonl).load_target_set( + TargetQuery(period=2022) + ) + targets_by_arch_variable = { + target.metadata["arch_variable"]: target for target in target_set.targets + } + + agi = targets_by_arch_variable["adjusted_gross_income"] + assert agi.metadata["variable"] == "adjusted_gross_income" + assert agi.metadata["geo_level"] == "state" + assert agi.metadata["geography_id"] == "0400000US06" + assert agi.measure == "adjusted_gross_income" + assert agi.aggregation.value == "sum" + assert _target_filter_tuples(agi) == { + ("tax_unit_is_filer", "==", "1"), + ("adjusted_gross_income", ">=", "50000"), + ("adjusted_gross_income", "<", 
"75000"), + ("state_fips", "==", "06"), + } + + eitc = targets_by_arch_variable["eitc_amount"] + assert eitc.metadata["variable"] == "eitc" + assert eitc.metadata["geo_level"] == "state" + assert eitc.measure == "eitc" + assert eitc.aggregation.value == "sum" + assert _target_filter_tuples(eitc) == { + ("tax_unit_is_filer", "==", "1"), + ("state_fips", "==", "06"), + } + + +def test_arch_consumer_fact_jsonl_provider_maps_state_broad_soi_concepts( + tmp_path: Path, +) -> None: + consumer_jsonl = tmp_path / "consumer_facts.jsonl" + geography = {"level": "state", "id": "0400000US06", "name": "California"} + rows = [ + _consumer_fact( + "state-ca-qualified-dividends", + concept="irs_soi.qualified_dividends", + domain="all_individual_income_tax_returns", + source_name="irs_soi", + source_table="Historic Table 2 state broad totals", + period={"type": "tax_year", "value": 2022}, + geography=geography, + value=93_000_000_000, + unit="usd", + ), + _consumer_fact( + "state-ca-schedule-c-returns", + concept="irs_soi.returns_with_schedule_c_income", + domain="all_individual_income_tax_returns", + source_name="irs_soi", + source_table="Historic Table 2 state broad totals", + period={"type": "tax_year", "value": 2022}, + geography=geography, + value=3_617_080, + ), + _consumer_fact( + "state-ca-partnership-scorp", + concept="irs_soi.partnership_scorp_income", + domain="all_individual_income_tax_returns", + source_name="irs_soi", + source_table="Historic Table 2 state broad totals", + period={"type": "tax_year", "value": 2022}, + geography=geography, + value=125_930_370_000, + unit="usd", + ), + _consumer_fact( + "state-ca-medical-dental", + concept="irs_soi.medical_dental_expense_deduction", + domain="all_individual_income_tax_returns", + source_name="irs_soi", + source_table="Historic Table 2 state broad totals", + period={"type": "tax_year", "value": 2022}, + geography=geography, + value=11_456_144_000, + unit="usd", + ), + _consumer_fact( + "state-ca-qbi-returns", + 
concept="irs_soi.returns_with_qualified_business_income_deduction", + domain="all_individual_income_tax_returns", + source_name="irs_soi", + source_table="Historic Table 2 state broad totals", + period={"type": "tax_year", "value": 2022}, + geography=geography, + value=499_080, + ), + _consumer_fact( + "state-ca-qbi", + concept="irs_soi.qualified_business_income_deduction", + domain="all_individual_income_tax_returns", + source_name="irs_soi", + source_table="Historic Table 2 state broad totals", + period={"type": "tax_year", "value": 2022}, + geography=geography, + value=4_400_400_000, + unit="usd", + ), + _consumer_fact( + "state-ca-rental-returns", + concept="irs_soi.returns_with_rental_royalty_income", + domain="all_individual_income_tax_returns", + source_name="irs_soi", + source_table="Historic Table 2 state broad totals", + period={"type": "tax_year", "value": 2022}, + geography=geography, + value=1_315_410, + ), + _consumer_fact( + "state-ca-rental", + concept="irs_soi.rental_royalty_income", + domain="all_individual_income_tax_returns", + source_name="irs_soi", + source_table="Historic Table 2 state broad totals", + period={"type": "tax_year", "value": 2022}, + geography=geography, + value=14_331_993_000, + unit="usd", + ), + _consumer_fact( + "state-ca-ctc-returns", + concept="irs_soi.returns_with_child_tax_credit", + domain="all_individual_income_tax_returns", + source_name="irs_soi", + source_table="Historic Table 2 state broad totals", + period={"type": "tax_year", "value": 2022}, + geography=geography, + value=4_626_510, + ), + _consumer_fact( + "state-ca-ctc", + concept="irs_soi.child_tax_credit", + domain="all_individual_income_tax_returns", + source_name="irs_soi", + source_table="Historic Table 2 state broad totals", + period={"type": "tax_year", "value": 2022}, + geography=geography, + value=9_724_583_000, + unit="usd", + ), + _consumer_fact( + "state-ca-actc-returns", + concept="irs_soi.returns_with_additional_child_tax_credit", + 
domain="all_individual_income_tax_returns", + source_name="irs_soi", + source_table="Historic Table 2 state broad totals", + period={"type": "tax_year", "value": 2022}, + geography=geography, + value=1_933_500, + ), + _consumer_fact( + "state-ca-actc", + concept="irs_soi.additional_child_tax_credit", + domain="all_individual_income_tax_returns", + source_name="irs_soi", + source_table="Historic Table 2 state broad totals", + period={"type": "tax_year", "value": 2022}, + geography=geography, + value=3_605_628_000, + unit="usd", + ), + ] + consumer_jsonl.write_text( + "\n".join(json.dumps(row, sort_keys=True) for row in rows) + "\n" + ) + + target_set = ArchConsumerFactJSONLTargetProvider(consumer_jsonl).load_target_set( + TargetQuery(period=2022) + ) + targets_by_arch_variable = { + target.metadata["arch_variable"]: target for target in target_set.targets + } + + qualified_dividends = targets_by_arch_variable["qualified_dividends_amount"] + assert qualified_dividends.metadata["variable"] == "qualified_dividend_income" + assert qualified_dividends.measure == "qualified_dividend_income" + assert _target_filter_tuples(qualified_dividends) == { + ("tax_unit_is_filer", "==", "1"), + ("state_fips", "==", "06"), + } + + schedule_c_returns = targets_by_arch_variable["schedule_c_income_returns"] + assert schedule_c_returns.metadata["variable"] == "self_employment_income" + assert schedule_c_returns.aggregation.value == "count" + assert ("self_employment_income", ">", "0") in _target_filter_tuples( + schedule_c_returns + ) + + partnership = targets_by_arch_variable["partnership_scorp_income_amount"] + assert ( + partnership.metadata["variable"] == "tax_unit_partnership_s_corp_income" + ) + assert partnership.measure == "tax_unit_partnership_s_corp_income" + + medical = targets_by_arch_variable["medical_dental_expense_amount"] + assert medical.metadata["variable"] == "medical_expense_deduction" + assert medical.measure == "medical_expense_deduction" + + qbi = 
targets_by_arch_variable["qbi_amount"] + assert qbi.metadata["variable"] == "qualified_business_income_deduction" + assert qbi.measure == "qualified_business_income_deduction" + + qbi_claims = targets_by_arch_variable["qbi_claims"] + assert ( + qbi_claims.metadata["variable"] + == "qualified_business_income_deduction" + ) + assert qbi_claims.aggregation.value == "count" + assert ( + "qualified_business_income_deduction", + ">", + "0", + ) in _target_filter_tuples(qbi_claims) + + rental = targets_by_arch_variable["rental_royalty_income_amount"] + assert rental.metadata["variable"] == "rental_income" + assert rental.measure == "rental_income" + + rental_returns = targets_by_arch_variable["rental_royalty_income_returns"] + assert rental_returns.metadata["variable"] == "rental_income" + assert rental_returns.aggregation.value == "count" + assert ("rental_income", ">", "0") in _target_filter_tuples( + rental_returns + ) + + ctc = targets_by_arch_variable["ctc_amount"] + assert ctc.metadata["variable"] == "non_refundable_ctc" + assert ctc.measure == "non_refundable_ctc" + + ctc_claims = targets_by_arch_variable["ctc_claims"] + assert ctc_claims.metadata["variable"] == "non_refundable_ctc" + assert ctc_claims.aggregation.value == "count" + assert ("non_refundable_ctc", ">", "0") in _target_filter_tuples( + ctc_claims + ) + + actc = targets_by_arch_variable["actc_amount"] + assert actc.metadata["variable"] == "refundable_ctc" + assert actc.measure == "refundable_ctc" + + actc_claims = targets_by_arch_variable["actc_claims"] + assert actc_claims.metadata["variable"] == "refundable_ctc" + assert actc_claims.aggregation.value == "count" + assert ("refundable_ctc", ">", "0") in _target_filter_tuples( + actc_claims + ) + + +def test_arch_consumer_fact_jsonl_provider_maps_eitc_by_agi_and_children( + tmp_path: Path, +) -> None: + consumer_jsonl = tmp_path / "consumer_facts.jsonl" + row = _consumer_fact( + "eitc-three-child-50k-75k-returns", + 
concept="irs_soi.returns_with_total_earned_income_credit", + domain="individual_income_tax_returns_with_earned_income_credit", + source_name="irs_soi", + source_table="Publication 1304 Table 2.5 EITC by AGI and qualifying children", + period={"type": "tax_year", "value": 2022}, + value=97_411, + constraints=( + { + "variable": "us:statutes/26/62#adjusted_gross_income", + "operator": ">=", + "value": 50_000, + "unit": "usd", + "role": "filter", + }, + { + "variable": "us:statutes/26/62#adjusted_gross_income", + "operator": "<", + "value": 75_000, + "unit": "usd", + "role": "filter", + }, + { + "variable": "us.tax.earned_income_credit_qualifying_children", + "operator": "==", + "value": 3, + "unit": "count", + "role": "filter", + }, + ), + ) + consumer_jsonl.write_text(json.dumps(row, sort_keys=True) + "\n") + + target_set = ArchConsumerFactJSONLTargetProvider(consumer_jsonl).load_target_set( + TargetQuery(period=2022) + ) + target = target_set.targets[0] + + assert target.metadata["arch_variable"] == "eitc_claims" + assert target.metadata["variable"] == "eitc" + assert target.aggregation.value == "count" + assert _target_filter_tuples(target) == { + ("eitc", ">", "0"), + ("adjusted_gross_income", ">=", "50000"), + ("adjusted_gross_income", "<", "75000"), + ("eitc_child_count", "==", "3"), + } + + +def test_arch_consumer_fact_coverage_accepts_eitc_child_count_totals( + tmp_path: Path, +) -> None: + consumer_jsonl = tmp_path / "consumer_facts.jsonl" + rows = [ + _consumer_fact( + "eitc-one-child-total-returns", + concept="irs_soi.returns_with_total_earned_income_credit", + domain="individual_income_tax_returns_with_earned_income_credit", + source_name="irs_soi", + source_table=( + "Publication 1304 Table 2.5 EITC by AGI and qualifying children" + ), + period={"type": "tax_year", "value": 2022}, + value=8_490_417, + constraints=( + { + "variable": "us.tax.earned_income_credit_qualifying_children", + "operator": "==", + "value": 1, + "unit": "count", + "role": "filter", 
+ }, + ), + ), + _consumer_fact( + "eitc-one-child-total-amount", + concept="irs_soi.total_earned_income_credit", + domain="individual_income_tax_returns_with_earned_income_credit", + source_name="irs_soi", + source_table=( + "Publication 1304 Table 2.5 EITC by AGI and qualifying children" + ), + period={"type": "tax_year", "value": 2022}, + value=21_182_747_000, + unit="usd", + constraints=( + { + "variable": "us.tax.earned_income_credit_qualifying_children", + "operator": "==", + "value": 1, + "unit": "count", + "role": "filter", + }, + ), + ), + ] + consumer_jsonl.write_text( + "\n".join(json.dumps(row, sort_keys=True) for row in rows) + "\n" + ) + provider = ArchConsumerFactJSONLTargetProvider(consumer_jsonl) + + report = summarize_arch_target_profile_coverage( + provider, + period=2022, + profile_name="custom", + target_cells=( + { + "variable": "eitc", + "geo_level": "national", + "domain_variable": "eitc_child_count", + }, + { + "variable": "tax_unit_count", + "geo_level": "national", + "domain_variable": "eitc_child_count", + }, + ), + ) + + assert report.covered_cell_count == 2 + + +def test_arch_consumer_fact_jsonl_provider_maps_us_admin_source_families( + tmp_path: Path, +) -> None: + consumer_jsonl = tmp_path / "consumer_facts.jsonl" + rows = [ + _consumer_fact( + "kff-aca-effectuated", + concept="cms_aca.marketplace_effectuated_enrollment", + domain="aca_marketplace_effectuated_enrollment", + source_name="kff", + source_table="Marketplace Effectuated Enrollment", + geography={"level": "state", "id": "0400000US06", "name": "California"}, + value=1_795_695, + ), + _consumer_fact( + "cms-medicaid-monthly", + concept="cms_medicaid.total_medicaid_enrollment", + domain="medicaid_chip_enrollment", + source_name="cms_medicaid", + source_table="Monthly Medicaid and CHIP Enrollment", + geography={"level": "state", "id": "0400000US06", "name": "California"}, + period={"type": "month", "value": "2024-12"}, + value=13_500_000, + ), + _consumer_fact( + 
"cms-nhe-medicaid", + concept="cms_nhe.medicaid_title_xix_expenditures", + domain="national_health_expenditures", + source_name="cms_nhe", + source_table="National Health Expenditures", + value=931_692_000_000, + unit="usd", + ), + _consumer_fact( + "snap-benefits", + concept="usda_snap.total_benefits", + domain="supplemental_nutrition_assistance_program", + source_name="usda_snap", + source_table="SNAP fiscal year benefits", + value=100_000_000_000, + unit="usd", + ), + _consumer_fact( + "snap-households", + concept="usda_snap.average_monthly_households", + domain="supplemental_nutrition_assistance_program", + source_name="usda_snap", + source_table="SNAP fiscal year participation", + geography={"level": "state", "id": "0400000US06", "name": "California"}, + value=2_100_000, + ), + _consumer_fact( + "tanf-cash", + concept="hhs_acf_tanf.cash_assistance_expenditures", + domain="tanf_cash_assistance", + source_name="hhs_acf_tanf", + source_table="TANF Financial Data", + period={"type": "fiscal_year", "value": 2024}, + value=7_788_317_475, + unit="usd", + ), + _consumer_fact( + "tanf-total-families", + concept="hhs_acf_tanf.average_monthly_tanf_total_families", + domain="tanf_caseload", + source_name="hhs_acf_tanf", + source_table="TANF Caseload Data 2024", + period={"type": "fiscal_year", "value": 2024}, + value=841_209, + ), + _consumer_fact( + "liheap-households", + concept="hhs_acf_liheap.households_served_by_state_programs", + domain="liheap_state_programs", + source_name="hhs_acf_liheap", + source_table="LIHEAP FY2024 National Profile (All States)", + period={"type": "fiscal_year", "value": 2024}, + value=5_876_646, + constraints=( + {"variable": "program", "operator": "==", "value": "liheap"}, + { + "variable": "administering_entity", + "operator": "==", + "value": "state_programs", + }, + ), + ), + _consumer_fact( + "stc-income-tax", + concept="census_stc.individual_income_tax_collections", + domain="state_government_tax_collections", + 
source_name="census_stc", + source_table="FY2024 STC Flat File item T40", + geography={"level": "state", "id": "0400000US06", "name": "California"}, + period={"type": "fiscal_year", "value": 2024}, + value=123_101_651_000, + unit="usd", + ), + _consumer_fact( + "ssa-retirement", + concept="ssa.annual_oasdi_or_ssi_payment_amount", + domain="social_security_and_ssi_payments", + source_name="ssa", + source_table="Annual Statistical Supplement", + value=1_111_728_000_000, + unit="usd", + constraints=( + { + "variable": "us_social_security_and_ssi.program_payment_type", + "operator": "==", + "value": "social_security_retirement_benefits", + }, + ), + ), + _consumer_fact( + "ssa-ssi", + concept="ssa.annual_oasdi_or_ssi_payment_amount", + domain="social_security_and_ssi_payments", + source_name="ssa", + source_table="Annual Statistical Supplement", + value=63_079_493_000, + unit="usd", + constraints=( + { + "variable": "us_social_security_and_ssi.program_payment_type", + "operator": "==", + "value": "ssi_payments", + }, + ), + ), + _consumer_fact( + "pep-age", + concept="census_pep.resident_population", + domain="resident_population", + source_name="census_pep", + source_table="Annual Estimates by Age and Sex", + value=18_599_314, + constraints=( + {"variable": "age", "operator": ">=", "value": 0, "unit": "years"}, + {"variable": "age", "operator": "<", "value": 5, "unit": "years"}, + ), + ), + _consumer_fact( + "aca-oep-average-aptc", + concept="cms_aca.average_monthly_aptc", + domain="aca_marketplace_qhp_selections", + source_name="cms_aca", + source_table="OEP State-Level Public Use File", + geography={"level": "state", "id": "0400000US06", "name": "California"}, + value=526, + unit="usd", + ), + _consumer_fact( + "w2-traditional-401k", + concept="irs_soi.form_w2_401k_elective_deferrals", + domain="form_w2_items", + source_name="irs_soi", + source_table="Form W-2 Statistics Table 4.B", + period={"type": "tax_year", "value": 2024}, + value=277_859_181_000, + unit="usd", 
+ ), + _consumer_fact( + "w2-roth-401k", + concept="irs_soi.form_w2_designated_roth_401k_contributions", + domain="form_w2_items", + source_name="irs_soi", + source_table="Form W-2 Statistics Table 4.B", + period={"type": "tax_year", "value": 2024}, + value=32_302_509_000, + unit="usd", + ), + _consumer_fact( + "soi-keogh", + concept="irs_soi.payments_to_keogh_plan", + domain="all_individual_income_tax_returns", + source_name="irs_soi", + source_table="Publication 1304 Table 1.4", + period={"type": "tax_year", "value": 2024}, + value=30_130_848_000, + unit="usd", + ), + ] + consumer_jsonl.write_text( + "\n".join(json.dumps(row, sort_keys=True) for row in rows) + "\n" + ) + provider = ArchConsumerFactJSONLTargetProvider(consumer_jsonl) + + report = summarize_arch_target_profile_coverage( + provider, + period=2024, + profile_name="custom", + target_cells=( + { + "variable": "person_count", + "geo_level": "state", + "domain_variable": "aca_ptc", + }, + { + "variable": "person_count", + "geo_level": "state", + "domain_variable": "medicaid_enrolled", + }, + {"variable": "medicaid", "geo_level": "national", "domain_variable": None}, + {"variable": "snap", "geo_level": "national", "domain_variable": None}, + { + "variable": "household_count", + "geo_level": "state", + "domain_variable": "snap", + }, + {"variable": "tanf", "geo_level": "national", "domain_variable": None}, + { + "variable": "spm_unit_count", + "geo_level": "national", + "domain_variable": "tanf", + }, + { + "variable": "household_count", + "geo_level": "national", + "domain_variable": "spm_unit_energy_subsidy_reported", + }, + { + "variable": "state_income_tax", + "geo_level": "state", + "domain_variable": None, + }, + { + "variable": "social_security_retirement", + "geo_level": "national", + "domain_variable": None, + }, + {"variable": "ssi", "geo_level": "national", "domain_variable": None}, + { + "variable": "person_count", + "geo_level": "national", + "domain_variable": "age", + }, + { + "variable": 
"traditional_401k_contributions", + "geo_level": "national", + "domain_variable": None, + }, + { + "variable": "roth_401k_contributions", + "geo_level": "national", + "domain_variable": None, + }, + { + "variable": "self_employed_pension_contribution_ald", + "geo_level": "national", + "domain_variable": None, + }, + ), + ) + + assert report.target_cell_count == 15 + assert report.covered_cell_count == 15 + + target_set = provider.load_target_set(TargetQuery(period=2024)) + targets_by_arch_variable = { + target.metadata["arch_variable"]: target for target in target_set.targets + } + assert ( + targets_by_arch_variable["aca_marketplace_enrollment"].metadata["variable"] + == "person_count" + ) + assert ( + targets_by_arch_variable["medicaid_total_enrollment"].metadata["variable"] + == "person_count" + ) + assert targets_by_arch_variable["medicaid_benefits"].measure == "medicaid" + assert targets_by_arch_variable["snap_benefits"].measure == "snap" + assert ( + targets_by_arch_variable["snap_household_count"].metadata["variable"] + == "household_count" + ) + assert targets_by_arch_variable["tanf_cash_assistance"].measure == "tanf" + assert ( + targets_by_arch_variable["tanf_family_count"].metadata["variable"] + == "spm_unit_count" + ) + liheap_target = targets_by_arch_variable["liheap_household_count"] + assert liheap_target.metadata["variable"] == "household_count" + assert { + (target_filter.feature, target_filter.operator.value, target_filter.value) + for target_filter in liheap_target.filters + } == {("spm_unit_energy_subsidy_reported", ">", 0)} + assert ( + targets_by_arch_variable["state_individual_income_tax_collections"].measure + == "state_income_tax" + ) + assert ( + targets_by_arch_variable["social_security_retirement_benefits"].measure + == "social_security_retirement" + ) + assert targets_by_arch_variable["ssi_payments"].measure == "ssi" + traditional_401k = targets_by_arch_variable["traditional_401k_contributions"] + assert traditional_401k.measure == 
"traditional_401k_contributions" + assert traditional_401k.entity.value == "person" + roth_401k = targets_by_arch_variable["roth_401k_contributions"] + assert roth_401k.measure == "roth_401k_contributions" + assert roth_401k.entity.value == "person" + self_employed_pension = targets_by_arch_variable[ + "self_employed_pension_contribution_ald" + ] + assert self_employed_pension.measure == "self_employed_pension_contribution_ald" + assert self_employed_pension.entity.value == "tax_unit" + assert "aca_average_monthly_aptc" not in targets_by_arch_variable + + +def test_arch_consumer_fact_jsonl_provider_maps_decennial_sld_facts( + tmp_path: Path, +) -> None: + consumer_jsonl = tmp_path / "consumer_facts.jsonl" + rows = [ + _consumer_fact( + "census-cd119-sldu-population", + concept="census_decennial.resident_population", + domain="resident_population", + source_name="census_decennial", + source_table="2020 Census CD119 California SLD P1", + geography={ + "level": "state_legislative_district_upper", + "id": "610U900US06001", + "name": "State Senate District 1", + }, + value=943_108, + ), + _consumer_fact( + "census-cd119-sldl-households", + concept="census_decennial.occupied_housing_units", + domain="households", + source_name="census_decennial", + source_table="2020 Census CD119 California SLD H3", + geography={ + "level": "state_legislative_district_lower", + "id": "620L900US06080", + "name": "Assembly District 80", + }, + value=154_291, + ), + ] + consumer_jsonl.write_text( + "\n".join(json.dumps(row, sort_keys=True) for row in rows) + "\n" + ) + + provider = ArchConsumerFactJSONLTargetProvider(consumer_jsonl) + report = summarize_arch_target_profile_coverage( + provider, + period=2024, + profile_name="custom", + target_cells=( + { + "variable": "person_count", + "geo_level": "sldu", + "geographic_id": "CA-SLDU-001", + "domain_variable": None, + }, + { + "variable": "household_count", + "geo_level": "sldl", + "geographic_id": "CA-SLDL-080", + "domain_variable": None, 
+ }, + ), + ) + + assert report.covered_cell_count == 2 + target_set = provider.load_target_set(TargetQuery(period=2024)) + targets_by_arch_variable = { + target.metadata["arch_variable"]: target for target in target_set.targets + } + population = targets_by_arch_variable["population"] + households = targets_by_arch_variable["household_count"] + + assert population.value == 943_108 + assert population.metadata["source"] == "CENSUS_DECENNIAL" + assert population.metadata["geo_level"] == "sldu" + assert { + ( + target_filter.feature, + str(getattr(target_filter.operator, "value", target_filter.operator)), + str(target_filter.value), + ) + for target_filter in population.filters + } == {("sldu_id", "==", "CA-SLDU-001")} + assert households.value == 154_291 + assert households.metadata["geo_level"] == "sldl" + assert { + ( + target_filter.feature, + str(getattr(target_filter.operator, "value", target_filter.operator)), + str(target_filter.value), + ) + for target_filter in households.filters + } == {("sldl_id", "==", "CA-SLDL-080")} + + +def test_arch_consumer_fact_jsonl_provider_normalizes_legacy_sld_ids( + tmp_path: Path, +) -> None: + consumer_jsonl = tmp_path / "consumer_facts.jsonl" + rows = [ + _consumer_fact( + "legacy-sldu-population", + concept="census_decennial.resident_population", + domain="resident_population", + source_name="census_decennial", + source_table="Legacy SLD fixture", + geography={ + "level": "state_senate_district", + "id": "CA-SD-1", + "name": "State Senate District 1", + }, + value=943_108, + ), + _consumer_fact( + "legacy-sldl-households", + concept="census_decennial.occupied_housing_units", + domain="households", + source_name="census_decennial", + source_table="Legacy SLD fixture", + geography={ + "level": "state_house_district", + "id": "NY-AD-65", + "name": "Assembly District 65", + }, + value=154_291, + ), + ] + consumer_jsonl.write_text( + "\n".join(json.dumps(row, sort_keys=True) for row in rows) + "\n" + ) + + provider = 
ArchConsumerFactJSONLTargetProvider(consumer_jsonl) + report = summarize_arch_target_profile_coverage( + provider, + period=2024, + profile_name="custom", + target_cells=( + { + "variable": "person_count", + "geo_level": "sldu", + "geographic_id": "06001", + "domain_variable": None, + }, + { + "variable": "household_count", + "geo_level": "sldl", + "geographic_id": "36065", + "domain_variable": None, + }, + ), + ) + + assert report.covered_cell_count == 2 + target_set = provider.load_target_set(TargetQuery(period=2024)) + targets_by_arch_variable = { + target.metadata["arch_variable"]: target for target in target_set.targets + } + + assert { + (target_filter.feature, str(target_filter.value)) + for target_filter in targets_by_arch_variable["population"].filters + } == {("sldu_id", "CA-SLDU-001")} + assert { + (target_filter.feature, str(target_filter.value)) + for target_filter in targets_by_arch_variable["household_count"].filters + } == {("sldl_id", "NY-SLDL-065")} + + +def test_arch_consumer_fact_jsonl_provider_maps_bea_full_population_amounts( + tmp_path: Path, +) -> None: + consumer_jsonl = tmp_path / "consumer_facts.jsonl" + rows = [ + _consumer_fact( + "bea-nipa-wages", + concept="bea_nipa.wages_and_salaries", + domain="personal_income", + source_name="bea", + source_table="NIPA annual total wages and salaries", + value=11_000_000_000_000, + unit="usd", + ), + _consumer_fact( + "bea-nipa-proprietors", + concept=( + "bea_nipa.proprietors_income_with_inventory_valuation_and_capital_consumption_adjustments" + ), + domain="personal_income", + source_name="bea", + source_table="NIPA annual personal income components", + value=2_000_000_000_000, + unit="usd", + ), + _consumer_fact( + "bea-regional-us-wages", + concept="bea_regional.wages_and_salaries", + domain="personal_income", + source_name="bea", + source_table="SAINC5N", + value=12_300_000_000_000, + unit="usd", + ), + _consumer_fact( + "bea-regional-us-proprietors", + 
concept="bea_regional.proprietors_income", + domain="personal_income", + source_name="bea", + source_table="SAINC5N", + value=2_020_000_000_000, + unit="usd", + ), + _consumer_fact( + "bea-regional-ca-wages", + concept="bea_regional.wages_and_salaries", + domain="personal_income", + source_name="bea", + source_table="SAINC5N", + geography={"level": "state", "id": "0400000US06", "name": "California"}, + value=1_500_000_000_000, + unit="usd", + ), + _consumer_fact( + "bea-regional-ca-proprietors", + concept="bea_regional.proprietors_income", + domain="personal_income", + source_name="bea", + source_table="SAINC5N", + geography={"level": "state", "id": "0400000US06", "name": "California"}, + value=180_000_000_000, + unit="usd", + ), + _consumer_fact( + "bea-nipa-dividends", + concept="bea_nipa.personal_dividend_income", + domain="personal_income", + source_name="bea", + source_table="NIPA annual personal income components", + value=2_100_000_000_000, + unit="usd", + ), + _consumer_fact( + "bea-nipa-rental", + concept=( + "bea_nipa.rental_income_of_persons_with_capital_consumption_adjustment" + ), + domain="personal_income", + source_name="bea", + source_table="NIPA annual personal income components", + value=1_000_000_000_000, + unit="usd", + ), + _consumer_fact( + "bea-nipa-social-security", + concept="bea_nipa.social_security_benefits", + domain="personal_current_transfer_receipts", + source_name="bea", + source_table="NIPA annual personal income components", + value=1_500_000_000_000, + unit="usd", + ), + _consumer_fact( + "bea-nipa-medicaid", + concept="bea_nipa.medicaid_benefits", + domain="personal_current_transfer_receipts", + source_name="bea", + source_table="NIPA annual personal income components", + value=900_000_000_000, + unit="usd", + ), + _consumer_fact( + "bea-nipa-ui", + concept="bea_nipa.unemployment_insurance_benefits", + domain="personal_current_transfer_receipts", + source_name="bea", + source_table="NIPA annual personal income components", + 
value=30_000_000_000, + unit="usd", + ), + _consumer_fact( + "bea-nipa-saving-rate", + concept="bea_nipa.personal_saving_rate", + domain="personal_income", + source_name="bea", + source_table="NIPA annual personal income disposition", + value=3.8, + unit="percent", + ), + ] + consumer_jsonl.write_text( + "\n".join(json.dumps(row, sort_keys=True) for row in rows) + "\n" + ) + provider = ArchConsumerFactJSONLTargetProvider(consumer_jsonl) + + report = summarize_arch_target_profile_coverage( + provider, + period=2024, + profile_name="custom", + target_cells=( + { + "variable": "employment_income", + "geo_level": "national", + "domain_variable": None, + }, + { + "variable": "employment_income", + "geo_level": "state", + "domain_variable": None, + }, + { + "variable": "self_employment_income", + "geo_level": "national", + "domain_variable": None, + }, + { + "variable": "self_employment_income", + "geo_level": "state", + "domain_variable": None, + }, + { + "variable": "dividend_income", + "geo_level": "national", + "domain_variable": None, + }, + { + "variable": "rental_income", + "geo_level": "national", + "domain_variable": None, + }, + { + "variable": "social_security", + "geo_level": "national", + "domain_variable": None, + }, + {"variable": "medicaid", "geo_level": "national", "domain_variable": None}, + { + "variable": "unemployment_compensation", + "geo_level": "national", + "domain_variable": None, + }, + ), + ) + + assert report.target_cell_count == 9 + assert report.covered_cell_count == 7 + + target_set = provider.load_target_set(TargetQuery(period=2024)) + targets_by_source_record = { + target.metadata["arch_source_record_id"]: target + for target in target_set.targets + } + assert set(targets_by_source_record) == { + "bea.bea-nipa-wages", + "bea.bea-nipa-proprietors", + "bea.bea-regional-ca-wages", + "bea.bea-regional-ca-proprietors", + "bea.bea-nipa-dividends", + "bea.bea-nipa-rental", + "bea.bea-nipa-social-security", + "bea.bea-nipa-medicaid", + 
"bea.bea-nipa-ui", + } + assert targets_by_source_record["bea.bea-nipa-wages"].measure == ( + "employment_income" + ) + assert ( + targets_by_source_record["bea.bea-nipa-wages"].metadata["arch_variable"] + == "wages_salaries_amount" + ) + assert targets_by_source_record["bea.bea-nipa-wages"].filters == () + assert targets_by_source_record["bea.bea-nipa-proprietors"].measure == ( + "proprietors_income_amount" + ) + assert targets_by_source_record["bea.bea-nipa-proprietors"].metadata[ + "arch_concept" + ] == ( + "bea_nipa.proprietors_income_with_inventory_valuation_and_capital_consumption_adjustments" + ) + assert targets_by_source_record["bea.bea-nipa-proprietors"].filters == () + assert targets_by_source_record["bea.bea-regional-ca-wages"].measure == ( + "employment_income" + ) + assert ( + targets_by_source_record["bea.bea-regional-ca-wages"].metadata["arch_variable"] + == "wages_salaries_amount" + ) + assert targets_by_source_record["bea.bea-regional-ca-proprietors"].measure == ( + "proprietors_income_amount" + ) + assert { + ( + target_filter.feature, + str(getattr(target_filter.operator, "value", target_filter.operator)), + str(target_filter.value), + ) + for target_filter in targets_by_source_record[ + "bea.bea-regional-ca-wages" + ].filters + } == {("state_fips", "==", "06")} + assert targets_by_source_record["bea.bea-nipa-dividends"].source == "BEA" + assert "bea.bea-regional-us-wages" not in targets_by_source_record + assert "bea.bea-regional-us-proprietors" not in targets_by_source_record + assert not provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={"variables": ("self_employment_income",)}, + ) + ).targets + assert all( + target.metadata["arch_variable"] != "personal_saving_rate" + for target in target_set.targets + ) + + +def test_arch_target_smoke_cli_reports_consumer_fact_jsonl_counts( + tmp_path: Path, + capsys: pytest.CaptureFixture[str], +) -> None: + consumer_jsonl = tmp_path / "consumer_facts.jsonl" + 
_write_consumer_fact_jsonl(consumer_jsonl) + + exit_code = main_smoke( + [ + "--arch-targets-db", + str(consumer_jsonl), + "--period", + "2023", + "--expected-target-count", + "5", + "--no-compose-model-year-targets", + ] + ) + payload = json.loads(capsys.readouterr().out) + + assert exit_code == 0 + assert payload["valid"] + assert payload["target_count"] == 5 + assert payload["by_source"] == {"IRS_SOI": 5} + assert payload["by_variable"] == { + "adjusted_gross_income": 2, + "income_tax": 1, + "tax_unit_count": 2, + } + assert payload["errors"] == [] + assert payload["sample_targets"][0]["metadata"]["arch_aggregate_fact_key"] + + +def test_arch_target_smoke_cli_rejects_unexpected_target_count( + tmp_path: Path, + capsys: pytest.CaptureFixture[str], +) -> None: + consumer_jsonl = tmp_path / "consumer_facts.jsonl" + _write_consumer_fact_jsonl(consumer_jsonl) + + exit_code = main_smoke( + [ + "--arch-targets-db", + str(consumer_jsonl), + "--period", + "2023", + "--expected-target-count", + "6", + "--no-compose-model-year-targets", + ] + ) + payload = json.loads(capsys.readouterr().out) + + assert exit_code == 1 + assert not payload["valid"] + assert payload["target_count"] == 5 + assert payload["errors"] == [ + { + "code": "unexpected_target_count", + "message": "Expected 6 targets, loaded 5.", + } + ] + + +def test_arch_target_parity_cli_accepts_matching_consumer_fact_jsonl( + tmp_path: Path, + capsys: pytest.CaptureFixture[str], +) -> None: + value_db = tmp_path / "value_targets.db" + consumer_jsonl = tmp_path / "consumer_facts.jsonl" + _create_value_constraint_target_db(value_db) + _write_consumer_fact_jsonl(consumer_jsonl) + + exit_code = main_parity( + [ + "--incumbent-arch-targets-db", + str(value_db), + "--candidate-arch-targets-db", + str(consumer_jsonl), + "--period", + "2023", + "--no-compose-model-year-targets", + ] + ) + payload = json.loads(capsys.readouterr().out) + + assert exit_code == 0 + assert payload["valid"] + assert payload["counts"] == { + 
"candidate_only_count": 0, + "candidate_target_count": 5, + "duplicate_identity_count": 0, + "incumbent_only_count": 0, + "incumbent_target_count": 5, + "matched_count": 5, + "value_mismatch_count": 0, + } + assert payload["errors"] == [] + assert payload["rows"][0]["status"] == "matched" + + +def test_arch_target_parity_cli_rejects_value_mismatch( + tmp_path: Path, + capsys: pytest.CaptureFixture[str], +) -> None: + value_db = tmp_path / "value_targets.db" + consumer_jsonl = tmp_path / "consumer_facts.jsonl" + _create_value_constraint_target_db(value_db) + _write_consumer_fact_jsonl(consumer_jsonl) + rows = [json.loads(line) for line in consumer_jsonl.read_text().splitlines()] + rows[1]["value"] += 1_000 + consumer_jsonl.write_text( + "\n".join(json.dumps(row, sort_keys=True) for row in rows) + "\n" + ) + + exit_code = main_parity( + [ + "--incumbent-arch-targets-db", + str(value_db), + "--candidate-arch-targets-db", + str(consumer_jsonl), + "--period", + "2023", + "--no-compose-model-year-targets", + ] + ) + payload = json.loads(capsys.readouterr().out) + + assert exit_code == 1 + assert not payload["valid"] + assert payload["counts"]["matched_count"] == 4 + assert payload["counts"]["value_mismatch_count"] == 1 + assert payload["errors"][0]["code"] == "value_mismatch" + assert payload["errors"][0]["absolute_delta"] == 1_000 + + +def test_arch_target_parity_cli_rejects_duplicate_candidate_identity( + tmp_path: Path, + capsys: pytest.CaptureFixture[str], +) -> None: + value_db = tmp_path / "value_targets.db" + consumer_jsonl = tmp_path / "consumer_facts.jsonl" + _create_value_constraint_target_db(value_db) + _write_consumer_fact_jsonl(consumer_jsonl) + lines = consumer_jsonl.read_text().splitlines() + consumer_jsonl.write_text("\n".join([*lines, lines[0]]) + "\n") + + exit_code = main_parity( + [ + "--incumbent-arch-targets-db", + str(value_db), + "--candidate-arch-targets-db", + str(consumer_jsonl), + "--period", + "2023", + "--no-compose-model-year-targets", + ] 
+ ) + payload = json.loads(capsys.readouterr().out) + + assert exit_code == 1 + assert not payload["valid"] + assert payload["counts"]["duplicate_identity_count"] == 1 + assert payload["errors"][0]["code"] == "duplicate_identity" + assert payload["errors"][0]["candidate_target_count"] == 2 + + +def test_arch_fact_provider_composes_latest_source_facts_to_model_year( + tmp_path: Path, +) -> None: + fact_db = tmp_path / "arch_facts.db" + _create_arch_fact_db(fact_db) + + target_set = ArchFactSQLiteTargetProvider(fact_db).load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "adjusted_gross_income", + "geo_level": "national", + "domain_variable": None, + } + ], + }, + ) + ) + + all_agi = next( + target + for target in target_set.targets + if target.metadata["arch_aggregate_fact_key"] == "arch.fact.v1:all-agi" + ) + assert all_agi.period == 2024 + assert all_agi.value == 15_286_017_359_000 + assert all_agi.metadata["arch_source_period"] == 2023 + assert all_agi.metadata["arch_model_period"] == 2024 + assert all_agi.metadata["arch_aging_amount_factor"] == 1 + assert all_agi.metadata["arch_aging_amount_method"] == ( + "source_fact_carry_forward_no_amount_reference" + ) + + +def test_arch_composite_source_facts_age_across_artifacts( + tmp_path: Path, +) -> None: + table_1_1_db = tmp_path / "arch_table_1_1.db" + table_1_4_db = tmp_path / "arch_table_1_4.db" + _create_arch_fact_db(table_1_1_db) + _insert_arch_table_1_1_reference_totals( + table_1_1_db, + year=2022, + return_count=160_602_107 / 1.1, + adjusted_gross_income=15_286_017_359_000 / 1.1, + ) + _create_arch_fact_db(table_1_4_db) + _insert_arch_table_1_4_facts(table_1_4_db) + provider = resolve_arch_sqlite_target_provider((table_1_1_db, table_1_4_db)) + + target_set = provider.load_target_set( + TargetQuery( + period=2024, + provider_filters={ + "sources": ["IRS_SOI"], + "target_cells": [ + { + "variable": "employment_income", + 
"geo_level": "national", + "domain_variable": "employment_income", + } + ], + }, + ) + ) + + wages = next( + target + for target in target_set.targets + if target.metadata["arch_aggregate_fact_key"] + == "arch.fact.v1:t14-all-wages-amount" + ) + assert wages.period == 2024 + assert wages.value == 10_500_000_000_000 * 1.1 + assert wages.metadata["arch_source_period"] == 2023 + assert wages.metadata["arch_aging_amount_factor"] == 1.1 + assert wages.metadata["arch_aging_amount_method"] == ( + "soi_total_agi_last_growth_extrapolation" + ) + assert wages.metadata["arch_source_db_path"] == str(table_1_4_db) + + +def test_arch_provider_resolver_detects_source_fact_schema(tmp_path: Path) -> None: + fact_db = tmp_path / "arch_facts.db" + _create_arch_fact_db(fact_db) + + provider = resolve_arch_sqlite_target_provider(fact_db) + + assert isinstance(provider, ArchFactSQLiteTargetProvider) + + +def test_arch_provider_resolver_detects_consumer_fact_jsonl(tmp_path: Path) -> None: + consumer_jsonl = tmp_path / "consumer_facts.jsonl" + _write_consumer_fact_jsonl(consumer_jsonl) + + provider = resolve_arch_sqlite_target_provider(consumer_jsonl) + + assert isinstance(provider, ArchConsumerFactJSONLTargetProvider) + + +def test_arch_provider_resolver_combines_multiple_source_fact_dbs( + tmp_path: Path, +) -> None: + table_1_1_db = tmp_path / "arch_table_1_1.db" + table_1_4_db = tmp_path / "arch_table_1_4.db" + _create_arch_fact_db(table_1_1_db) + _create_arch_fact_db(table_1_4_db) + _insert_arch_table_1_4_facts(table_1_4_db) + + provider = resolve_arch_sqlite_target_provider( + (str(table_1_1_db), str(table_1_4_db)) + ) + target_set = provider.load_target_set(TargetQuery(period=2023)) + + assert isinstance(provider, ArchCompositeSQLiteTargetProvider) + assert len(target_set.targets) == 18 + assert len({target.name for target in target_set.targets}) == 18 + assert {target.metadata["target_id"] for target in target_set.targets} == set( + range(1, 19) + ) + assert all( + 
"arch_source_db_path" in target.metadata for target in target_set.targets + ) + + +def test_us_pipeline_arch_target_provider_accepts_source_fact_db( + tmp_path: Path, +) -> None: + fact_db = tmp_path / "arch_facts.db" + _create_arch_fact_db(fact_db) + pipeline = USMicroplexPipeline( + USMicroplexBuildConfig( + arch_targets_db=str(fact_db), + calibration_target_source="arch", + ) + ) + + provider, source = pipeline._resolve_calibration_target_provider() + + assert source == "arch" + assert isinstance(provider, ArchFactSQLiteTargetProvider) + + +def test_us_pipeline_arch_target_provider_accepts_consumer_fact_jsonl( + tmp_path: Path, +) -> None: + consumer_jsonl = tmp_path / "consumer_facts.jsonl" + _write_consumer_fact_jsonl(consumer_jsonl) + pipeline = USMicroplexPipeline( + USMicroplexBuildConfig( + arch_targets_db=str(consumer_jsonl), + calibration_target_source="arch", + ) + ) + + provider, source = pipeline._resolve_calibration_target_provider() + + assert source == "arch" + assert isinstance(provider, ArchConsumerFactJSONLTargetProvider) + + +def test_us_pipeline_arch_target_provider_accepts_multiple_source_fact_dbs( + tmp_path: Path, +) -> None: + table_1_1_db = tmp_path / "arch_table_1_1.db" + table_1_4_db = tmp_path / "arch_table_1_4.db" + _create_arch_fact_db(table_1_1_db) + _create_arch_fact_db(table_1_4_db) + _insert_arch_table_1_4_facts(table_1_4_db) + pipeline = USMicroplexPipeline( + USMicroplexBuildConfig( + arch_targets_db=(str(table_1_1_db), str(table_1_4_db)), + calibration_target_source="arch", + ) + ) + + provider, source = pipeline._resolve_calibration_target_provider() + target_set = provider.load_target_set(TargetQuery(period=2023)) + + assert source == "arch" + assert isinstance(provider, ArchCompositeSQLiteTargetProvider) + assert len(target_set.targets) == 18 + + +def test_arch_fact_provider_maps_soi_table_1_4_income_source_facts( + tmp_path: Path, +) -> None: + fact_db = tmp_path / "arch_facts.db" + _create_arch_fact_db(fact_db) + 
_insert_arch_table_1_4_facts(fact_db) + + target_set = ArchFactSQLiteTargetProvider(fact_db).load_target_set( + TargetQuery(period=2023) + ) + table_1_4_targets = [ + target + for target in target_set.targets + if target.metadata["source_table"] == "Publication 1304 Table 1.4" + ] + + arch_variables = {target.metadata["arch_variable"] for target in table_1_4_targets} + assert arch_variables >= { + "wages_salaries_returns", + "wages_salaries_amount", + "net_capital_gains_returns", + "net_capital_gains_amount", + "taxable_ira_distributions_returns", + "taxable_ira_distributions_amount", + "taxable_pension_income_returns", + "taxable_pension_income_amount", + "unemployment_compensation_returns", + "unemployment_compensation_amount", + "taxable_social_security_returns", + "taxable_social_security_amount", + } + + wages_amount = next( + target + for target in table_1_4_targets + if target.metadata["arch_aggregate_fact_key"] + == "arch.fact.v1:t14-all-wages-amount" + ) + assert wages_amount.measure == "employment_income" + assert getattr(wages_amount.aggregation, "value", wages_amount.aggregation) == "sum" + assert getattr(wages_amount.entity, "value", wages_amount.entity) == "person" + assert wages_amount.metadata["variable"] == "employment_income" + assert wages_amount.metadata["arch_source_concept"] == "irs_soi.total_wages" + assert wages_amount.metadata["arch_concept_relation"] == "broad_match" + assert wages_amount.metadata["arch_source_cell_keys"] == [ + "arch.source_cell.v1:t14-wages-amount" + ] + + wages_returns = next( + target + for target in table_1_4_targets + if target.metadata["arch_aggregate_fact_key"] + == "arch.fact.v1:t14-all-wages-returns" + ) + assert wages_returns.measure is None + assert ( + getattr(wages_returns.aggregation, "value", wages_returns.aggregation) + == "count" + ) + assert getattr(wages_returns.entity, "value", wages_returns.entity) == "tax_unit" + assert wages_returns.metadata["variable"] == "employment_income" + assert ( + 
"employment_income", + ">", + "0", + ) in { + ( + str(target_filter.feature), + str(getattr(target_filter.operator, "value", target_filter.operator)), + str(target_filter.value), + ) + for target_filter in wages_returns.filters + } + + capital_gains_amount = next( + target + for target in table_1_4_targets + if target.metadata["arch_aggregate_fact_key"] + == "arch.fact.v1:t14-all-capital-gains-amount" + ) + assert ( + "net_capital_gains", + ">", + "0", + ) in { + ( + str(target_filter.feature), + str(getattr(target_filter.operator, "value", target_filter.operator)), + str(target_filter.value), + ) + for target_filter in capital_gains_amount.filters + } + + bracket_wages = next( + target + for target in table_1_4_targets + if target.metadata["arch_aggregate_fact_key"] + == "arch.fact.v1:t14-1-to-5k-wages-amount" + ) + assert { + ( + str(target_filter.feature), + str(getattr(target_filter.operator, "value", target_filter.operator)), + str(target_filter.value), + ) + for target_filter in bracket_wages.filters + } >= { + ("adjusted_gross_income", ">=", "1"), + ("adjusted_gross_income", "<", "5000"), + } + + +def test_arch_fact_profile_coverage_accepts_soi_table_1_4_facts( + tmp_path: Path, +) -> None: + fact_db = tmp_path / "arch_facts.db" + _create_arch_fact_db(fact_db) + _insert_arch_table_1_4_facts(fact_db) + provider = ArchFactSQLiteTargetProvider(fact_db) + + report = summarize_arch_target_profile_coverage( + provider, + period=2023, + profile_name="custom", + target_cells=( + { + "variable": "employment_income", + "geo_level": "national", + "domain_variable": "employment_income", + }, + { + "variable": "tax_unit_count", + "geo_level": "national", + "domain_variable": "employment_income", + }, + { + "variable": "taxable_social_security", + "geo_level": "national", + "domain_variable": "taxable_social_security", + }, + { + "variable": "tax_unit_count", + "geo_level": "national", + "domain_variable": "taxable_social_security", + }, + ), + ) + + assert 
report.target_cell_count == 4 + assert report.covered_cell_count == 4 + assert report.coverage_rate == 1 + + +def test_arch_composite_profile_coverage_combines_table_1_1_and_1_4( + tmp_path: Path, +) -> None: + table_1_1_db = tmp_path / "arch_table_1_1.db" + table_1_4_db = tmp_path / "arch_table_1_4.db" + _create_arch_fact_db(table_1_1_db) + _create_arch_fact_db(table_1_4_db) + _insert_arch_table_1_4_facts(table_1_4_db) + provider = resolve_arch_sqlite_target_provider((table_1_1_db, table_1_4_db)) + + report = summarize_arch_target_profile_coverage( + provider, + period=2023, + profile_name="custom", + target_cells=( + { + "variable": "adjusted_gross_income", + "geo_level": "national", + "domain_variable": None, + }, + { + "variable": "income_tax", + "geo_level": "national", + "domain_variable": None, + }, + { + "variable": "employment_income", + "geo_level": "national", + "domain_variable": "employment_income", + }, + { + "variable": "tax_unit_count", + "geo_level": "national", + "domain_variable": "employment_income", + }, + ), + ) + + assert report.target_cell_count == 4 + assert report.covered_cell_count == 4 + assert report.coverage_rate == 1 + + +def test_arch_fact_gap_queue_uses_source_fact_loaded_catalog( + tmp_path: Path, +) -> None: + fact_db = tmp_path / "arch_facts.db" + _create_arch_fact_db(fact_db) + _insert_arch_table_1_4_facts(fact_db) + provider = ArchFactSQLiteTargetProvider(fact_db) + + report = summarize_arch_target_gap_queue( + provider, + period=2023, + profile_name="custom", + target_cells=( + { + "variable": "employment_income", + "geo_level": "state", + "domain_variable": "employment_income", + }, + ), + ) + + assert report.row_count == 1 + assert report.rows[0].expected_arch_variable == "wages_salaries_amount" + assert report.rows[0].loader_status == "loaded_arch_variable_missing_geography" + + +def test_arch_fact_gap_queue_expected_filters_normalize_geography_ids( + tmp_path: Path, +) -> None: + fact_db = tmp_path / "arch_facts.db" + 
_create_arch_fact_db(fact_db) + provider = ArchFactSQLiteTargetProvider(fact_db) + + report = summarize_arch_target_gap_queue( + provider, + period=2023, + profile_name="custom", + target_cells=( + { + "variable": "person_count", + "geo_level": "state", + "geographic_id": "06", + "domain_variable": None, + }, + { + "variable": "person_count", + "geo_level": "sldu", + "geographic_id": "06001", + "domain_variable": None, + }, + { + "variable": "household_count", + "geo_level": "sldl", + "geographic_id": "36065", + "domain_variable": None, + }, + ), + ) + + filters_by_level = { + row.geo_level: { + item["feature"]: item["value"] + for item in row.expected_filters + if item["kind"] == "geography" + } + for row in report.rows + } + + assert filters_by_level == { + "state": {"state_fips": "06"}, + "sldu": {"sldu_id": "CA-SLDU-001"}, + "sldl": {"sldl_id": "NY-SLDL-065"}, + } diff --git a/tests/targets/test_census_blocks.py b/tests/targets/test_census_blocks.py new file mode 100644 index 0000000..7cee059 --- /dev/null +++ b/tests/targets/test_census_blocks.py @@ -0,0 +1,153 @@ +"""Tests for Census block-derived target providers.""" + +from __future__ import annotations + +import pandas as pd +from microplex.core import EntityType +from microplex.targets import TargetAggregation, TargetFilter, TargetQuery + +from microplex_us.targets.census_blocks import ( + CENSUS_BLOCK_POPULATION_SOURCE, + CensusBlockPopulationTargetProvider, + build_census_block_population_targets, +) + + +def _sample_blocks() -> pd.DataFrame: + return pd.DataFrame( + { + "geoid": [ + "060010201001000", + "060010201001001", + "060030101001000", + "360610001001000", + ], + "state_fips": ["06", "06", "06", "36"], + "county": ["001", "001", "003", "061"], + "tract": ["020100", "020100", "010100", "000100"], + "population": [10, 20, 5, 7], + "cd_id": ["CA-12", "CA-12", "CA-03", "NY-10"], + "sldu_id": ["CA-SD-09", "CA-SD-09", "CA-SD-01", "NY-SD-30"], + "sldl_id": ["CA-HD-18", "CA-HD-18", "CA-HD-05", 
"NY-AD-65"], + "cbsa_code": ["41860", "41860", None, "35620"], + "spm_metro_area": ["41860", "41860", "", "35620"], + } + ) + + +def test_build_census_block_population_targets_rolls_parent_geographies() -> None: + targets = build_census_block_population_targets( + _sample_blocks(), + geo_levels=("national", "state", "county", "tract", "cd", "sldu", "sldl"), + ) + + by_name = {target.name: target for target in targets} + + assert by_name["census_block_population_national"].value == 42 + assert by_name["census_block_population_state_06"].value == 35 + assert by_name["census_block_population_county_06001"].value == 30 + assert by_name["census_block_population_tract_06001020100"].value == 30 + assert by_name["census_block_population_cd_CA_12"].value == 30 + assert by_name["census_block_population_sldu_CA_SLDU_009"].value == 30 + assert by_name["census_block_population_sldl_CA_SLDL_018"].value == 30 + + county = by_name["census_block_population_county_06001"] + assert county.entity is EntityType.PERSON + assert county.aggregation is TargetAggregation.COUNT + assert county.source == CENSUS_BLOCK_POPULATION_SOURCE + assert county.filters == ( + TargetFilter(feature="county_fips", operator="==", value="06001"), + ) + assert county.metadata["variable"] == "person_count" + assert county.metadata["geo_level"] == "county" + assert county.metadata["geographic_id"] == "06001" + assert county.metadata["block_rollup"] is True + + +def test_census_block_provider_filters_by_geo_level_and_id() -> None: + provider = CensusBlockPopulationTargetProvider(block_probabilities=_sample_blocks()) + + target_set = provider.load_target_set( + TargetQuery( + provider_filters={ + "geo_levels": ["county", "cd"], + "geographic_ids": ["06001", "CA-03"], + "variables": ["person_count"], + }, + ) + ) + + targets = sorted(target_set.targets, key=lambda target: target.name) + + assert [target.name for target in targets] == [ + "census_block_population_cd_CA_03", + "census_block_population_county_06001", 
+ ] + assert [target.value for target in targets] == [5, 30] + + +def test_census_block_provider_normalizes_legacy_sld_ids() -> None: + provider = CensusBlockPopulationTargetProvider(block_probabilities=_sample_blocks()) + + target_set = provider.load_target_set( + TargetQuery( + provider_filters={ + "geo_levels": ["sldu", "sldl"], + "geographic_ids": ["CA-SD-09", "NY-AD-65"], + } + ) + ) + by_name = {target.name: target for target in target_set.targets} + + assert by_name["census_block_population_sldu_CA_SLDU_009"].value == 30 + assert by_name["census_block_population_sldl_NY_SLDL_065"].value == 7 + + +def test_census_block_targets_use_geo_level_to_normalize_bare_sld_ids() -> None: + targets = build_census_block_population_targets( + _sample_blocks(), + geo_levels=("sldu", "sldl"), + geographic_ids=("06009", "36065"), + ) + by_name = {target.name: target for target in targets} + + assert by_name["census_block_population_sldu_CA_SLDU_009"].value == 30 + assert by_name["census_block_population_sldl_NY_SLDL_065"].value == 7 + + provider = CensusBlockPopulationTargetProvider(block_probabilities=_sample_blocks()) + target_set = provider.load_target_set( + TargetQuery( + provider_filters={ + "geo_levels": ["sldu", "sldl"], + "geographic_ids": ["06009", "36065"], + } + ) + ) + provider_by_name = {target.name: target for target in target_set.targets} + + assert provider_by_name["census_block_population_sldu_CA_SLDU_009"].value == 30 + assert provider_by_name["census_block_population_sldl_NY_SLDL_065"].value == 7 + + +def test_census_block_targets_resolve_all_before_bare_sld_filter_expansion() -> None: + targets = build_census_block_population_targets( + _sample_blocks(), + geo_levels=("all",), + geographic_ids=("06009",), + ) + + assert { + target.name: target.value + for target in targets + if target.metadata["geo_level"] == "sldu" + } == {"census_block_population_sldu_CA_SLDU_009": 30} + + +def test_census_block_provider_ignores_unrelated_variables() -> None: + provider 
= CensusBlockPopulationTargetProvider(block_probabilities=_sample_blocks()) + + target_set = provider.load_target_set( + TargetQuery(provider_filters={"variables": ["household_count"]}) + ) + + assert target_set.targets == [] diff --git a/tests/test_microdata_roles.py b/tests/test_microdata_roles.py new file mode 100644 index 0000000..842c98a --- /dev/null +++ b/tests/test_microdata_roles.py @@ -0,0 +1,127 @@ +"""Tests for source-specific microdata variable roles.""" + +from microplex_us.microdata_roles import ( + MicrodataVariableRole, + PolicyEngineUSVariableRole, + blocked_policyengine_us_direct_export_variables, + is_model_input_microdata_variable, + is_policyengine_us_direct_export_blocked, + microdata_variable_role, + non_model_input_microdata_variables, + policyengine_us_variable_role, +) + + +def test_puf_tax_credit_lines_are_reported_outputs_not_model_inputs(): + for variable in ( + "foreign_tax_credit", + "savers_credit", + "state_and_local_sales_or_income_tax", + "state_income_tax_paid", + "taxable_social_security", + "taxable_unemployment_compensation", + ): + assert ( + microdata_variable_role("irs_soi_puf_2024", variable) + is MicrodataVariableRole.CALCULATED_TAX_OUTPUT + ) + assert not is_model_input_microdata_variable("irs_soi_puf_2024", variable) + assert is_model_input_microdata_variable( + "irs_soi_puf_2024", + "taxable_interest_income", + ) + + +def test_non_model_input_microdata_variables_is_source_specific(): + assert non_model_input_microdata_variables( + "irs_soi_puf_2024", + ["savers_credit", "taxable_interest_income", "taxable_social_security"], + ) == ("savers_credit", "taxable_social_security") + assert non_model_input_microdata_variables( + "cps_asec_2024", + ["savers_credit"], + ) == () + + +def test_policyengine_us_variable_roles_separate_inputs_from_outputs(): + assert ( + policyengine_us_variable_role("takes_up_snap_if_eligible") + is PolicyEngineUSVariableRole.TAKEUP_INPUT + ) + assert ( + 
policyengine_us_variable_role("takes_up_eitc") + is PolicyEngineUSVariableRole.TAKEUP_INPUT + ) + assert ( + policyengine_us_variable_role( + "would_file_if_eligible_for_refundable_credit" + ) + is PolicyEngineUSVariableRole.TAKEUP_INPUT + ) + assert ( + policyengine_us_variable_role("would_file_taxes_voluntarily") + is PolicyEngineUSVariableRole.TAKEUP_INPUT + ) + assert ( + policyengine_us_variable_role("snap") + is PolicyEngineUSVariableRole.CALCULATED_OUTPUT + ) + assert ( + policyengine_us_variable_role("state_income_tax") + is PolicyEngineUSVariableRole.CALCULATED_OUTPUT + ) + assert ( + policyengine_us_variable_role("filing_status") + is PolicyEngineUSVariableRole.CALCULATED_OUTPUT + ) + assert ( + policyengine_us_variable_role("snap_reported") + is PolicyEngineUSVariableRole.REPORTED_OUTPUT + ) + assert ( + policyengine_us_variable_role("taxable_interest_income") + is PolicyEngineUSVariableRole.PRESERVED_INPUT + ) + assert ( + policyengine_us_variable_role("non_sch_d_capital_gains") + is PolicyEngineUSVariableRole.PRESERVED_INPUT + ) + assert ( + policyengine_us_variable_role("long_term_capital_gains_before_response") + is PolicyEngineUSVariableRole.PRESERVED_INPUT + ) + assert ( + policyengine_us_variable_role("net_capital_gains") + is PolicyEngineUSVariableRole.CALCULATED_OUTPUT + ) + + +def test_policyengine_direct_export_guard_blocks_calculated_and_reported_outputs(): + blocked = blocked_policyengine_us_direct_export_variables( + [ + "takes_up_snap_if_eligible", + "would_file_taxes_voluntarily", + "net_capital_gains", + "non_sch_d_capital_gains", + "filing_status", + "snap", + "snap_reported", + "state_income_tax", + "taxable_interest_income", + ] + ) + + assert blocked == ( + "filing_status", + "net_capital_gains", + "snap", + "snap_reported", + "state_income_tax", + ) + assert is_policyengine_us_direct_export_blocked("filing_status") + assert is_policyengine_us_direct_export_blocked("snap") + assert not 
is_policyengine_us_direct_export_blocked("takes_up_snap_if_eligible") + assert not is_policyengine_us_direct_export_blocked( + "would_file_taxes_voluntarily" + ) + assert not is_policyengine_us_direct_export_blocked("non_sch_d_capital_gains") diff --git a/uv.lock b/uv.lock index bbf679b..b6f57b3 100644 --- a/uv.lock +++ b/uv.lock @@ -63,6 +63,46 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl", hash = "sha256:15a3ebc0f43c2d0a50eeafea25e19046c68398e487b9f1f5b517f7c0f40f976a", size = 27047, upload-time = "2025-11-15T16:43:16.109Z" }, ] +[[package]] +name = "boto3" +version = "1.43.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore" }, + { name = "jmespath" }, + { name = "s3transfer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0a/37/78c630d1308964aa9abf44951d9c4df776546ff37251ec2434944e205c4e/boto3-1.43.6.tar.gz", hash = "sha256:e6315effaf12b890b99956e6f8e2c3000a3f64e4ee91943cec3895ce9a836afb", size = 113153, upload-time = "2026-05-07T20:49:59.694Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/e2/3c2eef44f55eafab256836d1d9479bd6a74f70c26cbfdc0639a0e23e4327/boto3-1.43.6-py3-none-any.whl", hash = "sha256:179601ec2992726a718053bf41e43c223ceba397d31ceab11f64d9c910d9fc3a", size = 140502, upload-time = "2026-05-07T20:49:57.8Z" }, +] + +[[package]] +name = "botocore" +version = "1.43.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jmespath" }, + { name = "python-dateutil" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/79/a7/23d0f5028011455096a1eeac0ddf3cbe147b3e855e127342f8202552194d/botocore-1.43.6.tar.gz", hash = "sha256:b1e395b347356860398da42e61c808cf1e34b6fa7180cf2b9d87d986e1a06ba0", size = 15336070, upload-time = "2026-05-07T20:49:48.14Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/e5/c8/6f47223840e8d8cfa8c9f7c0ec1b77970417f257fc885169ff4f6326ce09/botocore-1.43.6-py3-none-any.whl", hash = "sha256:b6d1fdbc6f65a5fe0b7e947823aa37535d3f39f3ba4d21110fab1f55bbbcc04b", size = 15017094, upload-time = "2026-05-07T20:49:44.964Z" }, +] + +[[package]] +name = "census" +version = "0.8.26" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f7/e0/c1cde674716d836139550542febca6231616d776119ae73705036d741da7/census-0.8.26.tar.gz", hash = "sha256:c7f9944e38952b4ecc137d14d083018a1c2734f64d2fbc4a8946f35fd51888c2", size = 13019, upload-time = "2026-04-08T13:44:19.24Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/80/13/13dcc8a3142c3c73e5228c05e1ce6567378bc5c673d5567c116d4a8162d7/census-0.8.26-py3-none-any.whl", hash = "sha256:c341bbce4bcdd75c0ddecf75f28ab7eda26a47d7fecc95c4690a2d8ee5b6a727", size = 11364, upload-time = "2026-04-08T13:44:18.333Z" }, +] + [[package]] name = "certifi" version = "2026.2.25" @@ -269,6 +309,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e6/ac/f9e4e731635192571f86f52d86234f537c7f8ca4f6917c56b29051c077ef/duckdb-1.5.1-cp314-cp314-win_arm64.whl", hash = "sha256:a3be2072315982e232bfe49c9d3db0a59ba67b2240a537ef42656cc772a887c7", size = 14370790, upload-time = "2026-03-23T12:12:12.497Z" }, ] +[[package]] +name = "et-xmlfile" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d3/38/af70d7ab1ae9d4da450eeec1fa3918940a5fafb9055e934af8d6eb0c2313/et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54", size = 17234, upload-time = "2024-10-25T17:25:40.039Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/8b/5fe2cc11fee489817272089c4203e679c63b570a5aaeb18d852ae3cbba6a/et_xmlfile-2.0.0-py3-none-any.whl", hash = 
"sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa", size = 18059, upload-time = "2024-10-25T17:25:39.051Z" }, +] + [[package]] name = "executing" version = "2.2.1" @@ -305,7 +354,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ac/48/f8b875fa7dea7dd9b33245e37f065af59df6a25af2f9561efa8d822fde51/greenlet-3.3.2-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:aa6ac98bdfd716a749b84d4034486863fd81c3abde9aa3cf8eff9127981a4ae4", size = 279120, upload-time = "2026-02-20T20:19:01.9Z" }, { url = "https://files.pythonhosted.org/packages/49/8d/9771d03e7a8b1ee456511961e1b97a6d77ae1dea4a34a5b98eee706689d3/greenlet-3.3.2-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ab0c7e7901a00bc0a7284907273dc165b32e0d109a6713babd04471327ff7986", size = 603238, upload-time = "2026-02-20T20:47:32.873Z" }, { url = "https://files.pythonhosted.org/packages/59/0e/4223c2bbb63cd5c97f28ffb2a8aee71bdfb30b323c35d409450f51b91e3e/greenlet-3.3.2-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:d248d8c23c67d2291ffd47af766e2a3aa9fa1c6703155c099feb11f526c63a92", size = 614219, upload-time = "2026-02-20T20:55:59.817Z" }, - { url = "https://files.pythonhosted.org/packages/94/2b/4d012a69759ac9d77210b8bfb128bc621125f5b20fc398bce3940d036b1c/greenlet-3.3.2-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ccd21bb86944ca9be6d967cf7691e658e43417782bce90b5d2faeda0ff78a7dd", size = 628268, upload-time = "2026-02-20T21:02:48.024Z" }, { url = "https://files.pythonhosted.org/packages/7a/34/259b28ea7a2a0c904b11cd36c79b8cef8019b26ee5dbe24e73b469dea347/greenlet-3.3.2-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b6997d360a4e6a4e936c0f9625b1c20416b8a0ea18a8e19cabbefc712e7397ab", size = 616774, upload-time = "2026-02-20T20:21:02.454Z" }, { url = 
"https://files.pythonhosted.org/packages/0a/03/996c2d1689d486a6e199cb0f1cf9e4aa940c500e01bdf201299d7d61fa69/greenlet-3.3.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:64970c33a50551c7c50491671265d8954046cb6e8e2999aacdd60e439b70418a", size = 1571277, upload-time = "2026-02-20T20:49:34.795Z" }, { url = "https://files.pythonhosted.org/packages/d9/c4/2570fc07f34a39f2caf0bf9f24b0a1a0a47bc2e8e465b2c2424821389dfc/greenlet-3.3.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1a9172f5bf6bd88e6ba5a84e0a68afeac9dc7b6b412b245dd64f52d83c81e55b", size = 1640455, upload-time = "2026-02-20T20:21:10.261Z" }, @@ -314,7 +362,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/ae/8bffcbd373b57a5992cd077cbe8858fff39110480a9d50697091faea6f39/greenlet-3.3.2-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:8d1658d7291f9859beed69a776c10822a0a799bc4bfe1bd4272bb60e62507dab", size = 279650, upload-time = "2026-02-20T20:18:00.783Z" }, { url = "https://files.pythonhosted.org/packages/d1/c0/45f93f348fa49abf32ac8439938726c480bd96b2a3c6f4d949ec0124b69f/greenlet-3.3.2-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:18cb1b7337bca281915b3c5d5ae19f4e76d35e1df80f4ad3c1a7be91fadf1082", size = 650295, upload-time = "2026-02-20T20:47:34.036Z" }, { url = "https://files.pythonhosted.org/packages/b3/de/dd7589b3f2b8372069ab3e4763ea5329940fc7ad9dcd3e272a37516d7c9b/greenlet-3.3.2-cp314-cp314-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c2e47408e8ce1c6f1ceea0dffcdf6ebb85cc09e55c7af407c99f1112016e45e9", size = 662163, upload-time = "2026-02-20T20:56:01.295Z" }, - { url = "https://files.pythonhosted.org/packages/cd/ac/85804f74f1ccea31ba518dcc8ee6f14c79f73fe36fa1beba38930806df09/greenlet-3.3.2-cp314-cp314-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:e3cb43ce200f59483eb82949bf1835a99cf43d7571e900d7c8d5c62cdf25d2f9", size = 675371, upload-time = "2026-02-20T21:02:49.664Z" }, { url = 
"https://files.pythonhosted.org/packages/d2/d8/09bfa816572a4d83bccd6750df1926f79158b1c36c5f73786e26dbe4ee38/greenlet-3.3.2-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:63d10328839d1973e5ba35e98cccbca71b232b14051fd957b6f8b6e8e80d0506", size = 664160, upload-time = "2026-02-20T20:21:04.015Z" }, { url = "https://files.pythonhosted.org/packages/48/cf/56832f0c8255d27f6c35d41b5ec91168d74ec721d85f01a12131eec6b93c/greenlet-3.3.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:8e4ab3cfb02993c8cc248ea73d7dae6cec0253e9afa311c9b37e603ca9fad2ce", size = 1619181, upload-time = "2026-02-20T20:49:36.052Z" }, { url = "https://files.pythonhosted.org/packages/0a/23/b90b60a4aabb4cec0796e55f25ffbfb579a907c3898cd2905c8918acaa16/greenlet-3.3.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:94ad81f0fd3c0c0681a018a976e5c2bd2ca2d9d94895f23e7bb1af4e8af4e2d5", size = 1687713, upload-time = "2026-02-20T20:21:11.684Z" }, @@ -323,7 +370,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/98/6d/8f2ef704e614bcf58ed43cfb8d87afa1c285e98194ab2cfad351bf04f81e/greenlet-3.3.2-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:e26e72bec7ab387ac80caa7496e0f908ff954f31065b0ffc1f8ecb1338b11b54", size = 286617, upload-time = "2026-02-20T20:19:29.856Z" }, { url = "https://files.pythonhosted.org/packages/5e/0d/93894161d307c6ea237a43988f27eba0947b360b99ac5239ad3fe09f0b47/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8b466dff7a4ffda6ca975979bab80bdadde979e29fc947ac3be4451428d8b0e4", size = 655189, upload-time = "2026-02-20T20:47:35.742Z" }, { url = "https://files.pythonhosted.org/packages/f5/2c/d2d506ebd8abcb57386ec4f7ba20f4030cbe56eae541bc6fd6ef399c0b41/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:b8bddc5b73c9720bea487b3bffdb1840fe4e3656fba3bd40aa1489e9f37877ff", size = 658225, upload-time = "2026-02-20T20:56:02.527Z" }, - { url = 
"https://files.pythonhosted.org/packages/d1/67/8197b7e7e602150938049d8e7f30de1660cfb87e4c8ee349b42b67bdb2e1/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:59b3e2c40f6706b05a9cd299c836c6aa2378cabe25d021acd80f13abf81181cf", size = 666581, upload-time = "2026-02-20T21:02:51.526Z" }, { url = "https://files.pythonhosted.org/packages/8e/30/3a09155fbf728673a1dea713572d2d31159f824a37c22da82127056c44e4/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b26b0f4428b871a751968285a1ac9648944cea09807177ac639b030bddebcea4", size = 657907, upload-time = "2026-02-20T20:21:05.259Z" }, { url = "https://files.pythonhosted.org/packages/f3/fd/d05a4b7acd0154ed758797f0a43b4c0962a843bedfe980115e842c5b2d08/greenlet-3.3.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:1fb39a11ee2e4d94be9a76671482be9398560955c9e568550de0224e41104727", size = 1618857, upload-time = "2026-02-20T20:49:37.309Z" }, { url = "https://files.pythonhosted.org/packages/6f/e1/50ee92a5db521de8f35075b5eff060dd43d39ebd46c2181a2042f7070385/greenlet-3.3.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:20154044d9085151bc309e7689d6f7ba10027f8f5a8c0676ad398b951913d89e", size = 1680010, upload-time = "2026-02-20T20:21:13.427Z" }, @@ -504,6 +550,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278, upload-time = "2024-11-11T01:41:40.175Z" }, ] +[[package]] +name = "jellyfish" +version = "1.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0b/14/fc5bdb637996df181e5c4fa3b15dcc27d33215e6c41753564ae453bdb40f/jellyfish-1.2.1.tar.gz", hash = "sha256:72d2fda61b23babe862018729be73c8b0dc12e3e6601f36f6e65d905e249f4db", size = 364417, upload-time = 
"2025-10-11T19:36:37.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/e6/75feeda1c3634525296aa56265db151f896005b139e177f8b1a285546a1f/jellyfish-1.2.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:4b3e3223aaad74e18aacc74775e01815e68af810258ceea6fa6a81b19f384312", size = 322958, upload-time = "2025-10-11T19:35:29.906Z" }, + { url = "https://files.pythonhosted.org/packages/0e/66/4b92bb55b545ebefbf085e45cbcda576d2a2a3dc48fd61dae469c27e73a6/jellyfish-1.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e967e67058b78189d2b20a9586c7720a05ec4a580d6a98c796cd5cd2b7b11303", size = 317859, upload-time = "2025-10-11T19:35:31.312Z" }, + { url = "https://files.pythonhosted.org/packages/fe/8e/9d0055f921c884605bf22a96e376b016993928126e8a4c7fd8698260fb4e/jellyfish-1.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32581c50b34a09889b2d96796170e53da313a1e7fde32be63c82e50e7e791e3c", size = 353222, upload-time = "2025-10-11T19:35:32.352Z" }, + { url = "https://files.pythonhosted.org/packages/4f/d2/deca58a62e57f7e2b2172ab39f522831279ee08ec0943fc0d0e33cd6e6f9/jellyfish-1.2.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:07b022412ebece96759006cb015d46b8218d7f896d8b327c6bbee784ddf38ed9", size = 362392, upload-time = "2025-10-11T19:35:33.305Z" }, + { url = "https://files.pythonhosted.org/packages/12/40/9a7f62d367f5a862950ce3598188fe0e22e11d1f5d6eaad6eda5adc354b0/jellyfish-1.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80a49eb817eaa6591f43a31e5c93d79904de62537f029907ef88c050d781a638", size = 360358, upload-time = "2025-10-11T19:35:34.585Z" }, + { url = "https://files.pythonhosted.org/packages/a5/e5/6b44a1058df3dfa3dd1174c9f86685c78f780d0b68851a057075aea14587/jellyfish-1.2.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:e1b990fb15985571616f7f40a12d6fa062897b19fb5359b6dec3cd811d802c24", size = 533945, upload-time = "2025-10-11T19:35:35.764Z" }, + { url = 
"https://files.pythonhosted.org/packages/50/4c/2397f43ad2692a1052299607838b41a4c2dd5707fde4ce459d686e763eb1/jellyfish-1.2.1-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:dd895cf63fac0a9f11b524fff810d9a6081dcf3c518b34172ac8684eb504dd43", size = 553707, upload-time = "2025-10-11T19:35:36.926Z" }, + { url = "https://files.pythonhosted.org/packages/de/aa/dc7cf053c8c40035791de1dc2f45b1f57772a14b0dc53318720e87073831/jellyfish-1.2.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:6d2bac5982d7a08759ea487bfa00149e6aa8a3be7cd43c4ed1be1e3505425c69", size = 523323, upload-time = "2025-10-11T19:35:37.981Z" }, + { url = "https://files.pythonhosted.org/packages/2b/1a/610c7f1f7777646322f489b5ed1e4631370c9fa4fb40a8246af71b496b6d/jellyfish-1.2.1-cp313-cp313-win32.whl", hash = "sha256:509355ebedec69a8bf0cc113a6bf9c01820d12fe2eea44f47dfa809faf2d5463", size = 209143, upload-time = "2025-10-11T19:35:39.276Z" }, + { url = "https://files.pythonhosted.org/packages/80/9a/6102b23b03a6df779fee76c979c0eb819b300c83b468900df78bb574b944/jellyfish-1.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:9c747ae5c0fb4bd519f6abbfe4bd704b2f1c63fd4dd3dbb8d8864478974e1571", size = 213466, upload-time = "2025-10-11T19:35:40.24Z" }, + { url = "https://files.pythonhosted.org/packages/89/c3/92190ff494881008ff127d67aba80245a5071ec7c3ff1181ceddc6c9d636/jellyfish-1.2.1-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:212aaf177236192a735bbbf5938717aa8518d14a25b08b015e47e783e70be060", size = 322379, upload-time = "2025-10-11T19:35:41.21Z" }, + { url = "https://files.pythonhosted.org/packages/d4/db/993c81f3e95e06e2a5cb71aaf9af063d8798a34c9715c8059707ddc12b86/jellyfish-1.2.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:b8986d9768daddd5e87abf513ae168ea0afe690a444d4c82d5b1b14b0d045820", size = 317270, upload-time = "2025-10-11T19:35:43.367Z" }, + { url = 
"https://files.pythonhosted.org/packages/fc/6a/0f521b098e136c43c7ae1e77db4a792f9e65167fe818820502996488b926/jellyfish-1.2.1-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5fa0ba0946f3c274f6a87aaa3c631dc70a363bd46cceea828ce777e8db653b6f", size = 352931, upload-time = "2025-10-11T19:35:44.402Z" }, + { url = "https://files.pythonhosted.org/packages/a0/c4/5d2242a650f890384b435610ef2962b1ac6091c070912a81a97020d2502a/jellyfish-1.2.1-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6e76b23431a667cd485fb562428d1ad29bae9fdd0fcdfb5a51cc8087bae0e88c", size = 362473, upload-time = "2025-10-11T19:35:45.427Z" }, + { url = "https://files.pythonhosted.org/packages/d5/fe/831fc45a4d3e497bccc4735809551320968360d14b89eb3d7cb892549316/jellyfish-1.2.1-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a058f4c6a591d5e5a47569f5648a26303ba19c76a960fef7e0beba2aa959e52e", size = 359772, upload-time = "2025-10-11T19:35:46.65Z" }, + { url = "https://files.pythonhosted.org/packages/b4/0f/d132265e299947e4462c1485f829a08a513c97c41bdfe758754e4a5c1dfe/jellyfish-1.2.1-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:6a49ce2a580edd3b16b69421137deef464e2f8907f9ef906d49950b1a52908c1", size = 533628, upload-time = "2025-10-11T19:35:47.691Z" }, + { url = "https://files.pythonhosted.org/packages/52/2a/d51dbf0aceb9b141dd8318ce6a41ab08a5deaae56be16a8bf3d8685ac817/jellyfish-1.2.1-cp314-cp314-musllinux_1_1_i686.whl", hash = "sha256:c85aa2bc76a36d92a3197f406f86636664d5b323727dfec4fa2842a8a24a06ae", size = 553614, upload-time = "2025-10-11T19:35:52.928Z" }, + { url = "https://files.pythonhosted.org/packages/f9/e1/fcc7c5919d871537942425f707b764af65b76c7b88377aa71083c5280e37/jellyfish-1.2.1-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:29cfa8bfb72aacf2d611a3313b358ed4d4140fa3d3efcffea750c8e7f8acb1aa", size = 523057, upload-time = "2025-10-11T19:35:54.423Z" }, + { url = 
"https://files.pythonhosted.org/packages/95/65/ee5289540b2015643493cc29b50350dbe63ca1977a902de5295a4df8c25a/jellyfish-1.2.1-cp314-cp314-win32.whl", hash = "sha256:f121218dc33fb318c34ddd889dc7362606ce1316af2bb63b73cc1df81523ca34", size = 209340, upload-time = "2025-10-11T19:35:55.69Z" }, + { url = "https://files.pythonhosted.org/packages/bc/e2/fa5de38380b0f5bd531b27a78acb0dc6118dab0b21f56d36008b829aa7de/jellyfish-1.2.1-cp314-cp314-win_amd64.whl", hash = "sha256:9a73b5c6425a70ebd440579a677eb4f03b327b2f59090db34e6c937aeea5aabd", size = 213399, upload-time = "2025-10-11T19:35:56.776Z" }, +] + [[package]] name = "jinja2" version = "3.1.6" @@ -516,6 +590,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, ] +[[package]] +name = "jmespath" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d3/59/322338183ecda247fb5d1763a6cbe46eff7222eaeebafd9fa65d4bf5cb11/jmespath-1.1.0.tar.gz", hash = "sha256:472c87d80f36026ae83c6ddd0f1d05d4e510134ed462851fd5f754c8c3cbb88d", size = 27377, upload-time = "2026-01-22T16:35:26.279Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/2f/967ba146e6d58cf6a652da73885f52fc68001525b4197effc174321d70b4/jmespath-1.1.0-py3-none-any.whl", hash = "sha256:a5663118de4908c91729bea0acadca56526eb2698e83de10cd116ae0f4e97c64", size = 20419, upload-time = "2026-01-22T16:35:24.919Z" }, +] + [[package]] name = "joblib" version = "1.5.3" @@ -701,8 +784,8 @@ wheels = [ [[package]] name = "microplex" -version = "0.1.0" -source = { editable = "../microplex" } +version = "0.2.0" +source = { git = 
"https://github.com/PolicyEngine/microplex.git?rev=1e0627182f9df40aacd7043c96956c2895bf9d30#1e0627182f9df40aacd7043c96956c2895bf9d30" } dependencies = [ { name = "httpx" }, { name = "huggingface-hub" }, @@ -724,41 +807,6 @@ calibrate = [ { name = "microcalibrate" }, ] -[package.metadata] -requires-dist = [ - { name = "cvxpy", marker = "extra == 'cvxpy'", specifier = ">=1.3" }, - { name = "httpx", specifier = ">=0.25" }, - { name = "huggingface-hub", specifier = ">=0.20" }, - { name = "jupyter-book", marker = "extra == 'docs'", specifier = ">=0.15,<2" }, - { name = "l0-python", marker = "extra == 'l0'", specifier = ">=0.4" }, - { name = "matplotlib", marker = "extra == 'benchmark'", specifier = ">=3.7" }, - { name = "microcalibrate", marker = "python_full_version >= '3.13' and extra == 'calibrate'", specifier = ">=0.22" }, - { name = "microplex", extras = ["dev", "benchmark", "docs", "calibrate"], marker = "extra == 'all'" }, - { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.0" }, - { name = "myst-nb", marker = "extra == 'docs'", specifier = ">=0.17" }, - { name = "numpy", specifier = ">=1.24" }, - { name = "pandas", specifier = ">=2.0" }, - { name = "polars", specifier = ">=0.20" }, - { name = "prdc", specifier = ">=0.1" }, - { name = "pyarrow", specifier = ">=14.0" }, - { name = "pydantic", specifier = ">=2.0" }, - { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0" }, - { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0" }, - { name = "pyyaml", specifier = ">=6.0" }, - { name = "quantile-forest", specifier = ">=1.3" }, - { name = "responses", marker = "extra == 'dev'", specifier = ">=0.20" }, - { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1" }, - { name = "scikit-learn", specifier = ">=1.3" }, - { name = "scikit-learn", marker = "extra == 'benchmark'", specifier = ">=1.3" }, - { name = "scipy", specifier = ">=1.10" }, - { name = "sdv", marker = "extra == 'benchmark'", specifier = ">=1.0" }, - { name = 
"seaborn", marker = "extra == 'benchmark'", specifier = ">=0.12" }, - { name = "sphinx", marker = "extra == 'docs'", specifier = ">=6.0" }, - { name = "sphinx-autodoc-typehints", marker = "extra == 'docs'", specifier = ">=1.23" }, - { name = "torch", specifier = ">=2.0" }, -] -provides-extras = ["dev", "cvxpy", "statmatch", "l0", "calibrate", "benchmark", "docs", "all"] - [[package]] name = "microplex-us" version = "0.2.0" @@ -777,19 +825,25 @@ dev = [ policyengine = [ { name = "microimpute", marker = "python_full_version < '3.15'" }, { name = "policyengine-us", marker = "python_full_version < '3.15'" }, + { name = "spm-calculator" }, +] +r2 = [ + { name = "boto3" }, ] [package.metadata] requires-dist = [ + { name = "boto3", marker = "extra == 'r2'", specifier = ">=1.34" }, { name = "duckdb", specifier = ">=1.2" }, { name = "microimpute", marker = "python_full_version >= '3.12' and python_full_version < '3.15' and extra == 'policyengine'", specifier = "==1.15.1" }, - { name = "microplex", extras = ["calibrate"], editable = "../microplex" }, + { name = "microplex", extras = ["calibrate"], git = "https://github.com/PolicyEngine/microplex.git?rev=1e0627182f9df40aacd7043c96956c2895bf9d30" }, { name = "policyengine-us", marker = "python_full_version >= '3.11' and python_full_version < '3.15' and extra == 'policyengine'", specifier = "==1.587.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0" }, { name = "requests", specifier = ">=2.31" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1" }, + { name = "spm-calculator", marker = "extra == 'policyengine'", specifier = ">=0.3.1" }, ] -provides-extras = ["dev", "policyengine"] +provides-extras = ["dev", "r2", "policyengine"] [[package]] name = "mpmath" @@ -1051,6 +1105,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a8/64/3708a90d1ebe202ffdeb7185f878a3c84d15c2b2c31858da2ce0583e2def/nvidia_nvtx-13.0.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = 
"sha256:cb7780edb6b14107373c835bf8b72e7a178bac7367e23da7acb108f973f157a6", size = 148878, upload-time = "2025-09-04T08:28:53.627Z" }, ] +[[package]] +name = "openpyxl" +version = "3.1.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "et-xmlfile" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3d/f9/88d94a75de065ea32619465d2f77b29a0469500e99012523b91cc4141cd1/openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050", size = 186464, upload-time = "2024-06-28T14:03:44.161Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/da/977ded879c29cbd04de313843e76868e6e13408a94ed6b987245dc7c8506/openpyxl-3.1.5-py2.py3-none-any.whl", hash = "sha256:5282c12b107bffeef825f4617dc029afaf41d0ea60823bbb665ef3079dc79de2", size = 250910, upload-time = "2024-06-28T14:03:41.161Z" }, +] + [[package]] name = "optuna" version = "4.8.0" @@ -1569,6 +1635,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8f/e8/726643a3ea68c727da31570bde48c7a10f1aa60eddd628d94078fec586ff/ruff-0.15.7-py3-none-win_arm64.whl", hash = "sha256:18e8d73f1c3fdf27931497972250340f92e8c861722161a9caeb89a58ead6ed2", size = 11023304, upload-time = "2026-03-19T16:26:51.669Z" }, ] +[[package]] +name = "s3transfer" +version = "0.17.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9b/ec/7c692cde9125b77e84b307354d4fb705f98b8ccad59a036d5957ca75bfc3/s3transfer-0.17.0.tar.gz", hash = "sha256:9edeb6d1c3c2f89d6050348548834ad8289610d886e5bf7b7207728bd43ce33a", size = 155337, upload-time = "2026-04-29T22:07:36.33Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/72/c6c32d2b657fa3dad1de340254e14390b1e334ce38268b7ad51abda3c8c2/s3transfer-0.17.0-py3-none-any.whl", hash = "sha256:ce3801712acf4ad3e89fb9990df97b4972e93f4b3b0004d214be5bce12814c20", size = 86811, upload-time = 
"2026-04-29T22:07:34.966Z" }, +] + [[package]] name = "scikit-learn" version = "1.8.0" @@ -1694,6 +1772,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload-time = "2021-05-16T22:03:41.177Z" }, ] +[[package]] +name = "spm-calculator" +version = "0.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "census" }, + { name = "numpy" }, + { name = "openpyxl" }, + { name = "pandas" }, + { name = "requests" }, + { name = "us" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/54/3b/b805c7e3e18c5b5c00f61b60112f9690d084c910e2481bc020f35390d8fd/spm_calculator-0.3.1.tar.gz", hash = "sha256:41f2f4d00d8c03422a7d57b800052e7760b88e463a5884802f83ed58d35c18c1", size = 75945, upload-time = "2026-04-17T19:52:39.707Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/1b/29f705f8a96fc7f55f2c07dfcddbbae78efdc6f174d25d4a0560fc3f5cf9/spm_calculator-0.3.1-py3-none-any.whl", hash = "sha256:52c57ecc5a240ec941b0f2b0d93bc4fa437ef6250e233baed8e11916fa9c1150", size = 57826, upload-time = "2026-04-17T19:52:38.444Z" }, +] + [[package]] name = "sqlalchemy" version = "2.0.49" @@ -1942,6 +2037,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" }, ] +[[package]] +name = "us" +version = "3.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jellyfish" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/35/12/06f87be706ccc5794569d14f903c2f755aa98e1a9d53e4e7e17d9986e9d1/us-3.2.0.tar.gz", hash = 
"sha256:cb223e85393dcc5171ead0dd212badc47f9667b23700fea3e7ea5f310d545338", size = 16046, upload-time = "2024-07-22T01:09:42.736Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/65/a8/1791660a87f03d10a3bce00401a66035999c91f5a9a6987569b84df5719d/us-3.2.0-py3-none-any.whl", hash = "sha256:571714ad6d473c72bbd2058a53404cdf4ecc0129e4f19adfcbeb4e2d7e3dc3e7", size = 13775, upload-time = "2024-07-22T01:09:41.432Z" }, +] + [[package]] name = "wcwidth" version = "0.6.0"