|
204 | 204 | } |
205 | 205 | ) |
206 | 206 |
|
| 207 | +ARCH_COMPONENT_SUM_TARGETS = { |
| 208 | + "salt_amount": ( |
| 209 | + "state_local_income_or_sales_tax_amount", |
| 210 | + "real_estate_taxes_amount", |
| 211 | + ), |
| 212 | +} |
| 213 | + |
207 | 214 | ARCH_NATIONAL_ROLLUP_STATE_FIPS = frozenset( |
208 | 215 | state_fips for state_fips in US_STATE_ABBR_BY_FIPS if state_fips != "72" |
209 | 216 | ) |
@@ -2462,10 +2469,146 @@ def _compose_arch_model_year_records( |
2462 | 2469 | def _with_state_to_national_rollup_records( |
2463 | 2470 | records: list[ArchTargetRecord], |
2464 | 2471 | ) -> list[ArchTargetRecord]: |
2465 | | - rollups = _state_to_national_rollup_records(records) |
| 2472 | + expanded_records = _with_component_sum_records(records) |
| 2473 | + rollups = _state_to_national_rollup_records(expanded_records) |
2466 | 2474 | if not rollups: |
| 2475 | + return expanded_records |
| 2476 | + return [*expanded_records, *rollups] |
| 2477 | + |
| 2478 | + |
| 2479 | +def _with_component_sum_records( |
| 2480 | + records: list[ArchTargetRecord], |
| 2481 | +) -> list[ArchTargetRecord]: |
| 2482 | + component_records = _component_sum_records(records) |
| 2483 | + if not component_records: |
2467 | 2484 | return records |
2468 | | - return [*records, *rollups] |
| 2485 | + return [*records, *component_records] |
| 2486 | + |
| 2487 | + |
| 2488 | +def _component_sum_records( |
| 2489 | + records: list[ArchTargetRecord], |
| 2490 | +) -> list[ArchTargetRecord]: |
| 2491 | + existing_keys = { |
| 2492 | + _component_sum_record_key(record, output_variable=record.variable) |
| 2493 | + for record in records |
| 2494 | + if record.target_type == "AMOUNT" |
| 2495 | + } |
| 2496 | + grouped: dict[ |
| 2497 | + tuple[Any, ...], |
| 2498 | + dict[str, ArchTargetRecord], |
| 2499 | + ] = {} |
| 2500 | + for record in records: |
| 2501 | + if record.target_type != "AMOUNT": |
| 2502 | + continue |
| 2503 | + for output_variable, component_variables in ARCH_COMPONENT_SUM_TARGETS.items(): |
| 2504 | + if record.variable not in component_variables: |
| 2505 | + continue |
| 2506 | + key = _component_sum_record_key(record, output_variable=output_variable) |
| 2507 | + if key in existing_keys: |
| 2508 | + continue |
| 2509 | + components = grouped.setdefault(key, {}) |
| 2510 | + if record.variable in components: |
| 2511 | + components.clear() |
| 2512 | + break |
| 2513 | + components[record.variable] = record |
| 2514 | + |
| 2515 | + composite_records: list[ArchTargetRecord] = [] |
| 2516 | + for key, components_by_variable in grouped.items(): |
| 2517 | + output_variable = str(key[0]) |
| 2518 | + component_variables = ARCH_COMPONENT_SUM_TARGETS[output_variable] |
| 2519 | + if set(components_by_variable) != set(component_variables): |
| 2520 | + continue |
| 2521 | + composite_records.append( |
| 2522 | + _component_records_to_sum_record( |
| 2523 | + key, |
| 2524 | + [ |
| 2525 | + components_by_variable[component_variable] |
| 2526 | + for component_variable in component_variables |
| 2527 | + ], |
| 2528 | + ) |
| 2529 | + ) |
| 2530 | + return composite_records |
| 2531 | + |
| 2532 | + |
| 2533 | +def _component_sum_record_key( |
| 2534 | + record: ArchTargetRecord, |
| 2535 | + *, |
| 2536 | + output_variable: str, |
| 2537 | +) -> tuple[Any, ...]: |
| 2538 | + return ( |
| 2539 | + output_variable, |
| 2540 | + record.target_type, |
| 2541 | + record.period, |
| 2542 | + _arch_record_geo_level(record), |
| 2543 | + record.geography_id, |
| 2544 | + tuple(sorted(record.constraints)), |
| 2545 | + _normalize_arch_source(record.source), |
| 2546 | + record.source_period, |
| 2547 | + record.aging_factors, |
| 2548 | + record.unit, |
| 2549 | + ) |
| 2550 | + |
| 2551 | + |
| 2552 | +def _component_records_to_sum_record( |
| 2553 | + key: tuple[Any, ...], |
| 2554 | + records: list[ArchTargetRecord], |
| 2555 | +) -> ArchTargetRecord: |
| 2556 | + first = records[0] |
| 2557 | + digest = sha1(repr(key).encode("utf-8")).hexdigest() |
| 2558 | + component_labels = ", ".join(record.variable for record in records) |
| 2559 | + source_tables = tuple( |
| 2560 | + dict.fromkeys(record.source_table for record in records if record.source_table) |
| 2561 | + ) |
| 2562 | + source_urls = tuple( |
| 2563 | + dict.fromkeys(record.source_url for record in records if record.source_url) |
| 2564 | + ) |
| 2565 | + source_row_keys = tuple( |
| 2566 | + dict.fromkeys( |
| 2567 | + source_row_key |
| 2568 | + for record in records |
| 2569 | + for source_row_key in ( |
| 2570 | + record.source_row_keys |
| 2571 | + or (str(record.source_target_id or record.target_id),) |
| 2572 | + ) |
| 2573 | + ) |
| 2574 | + ) |
| 2575 | + source_cell_keys = tuple( |
| 2576 | + dict.fromkeys( |
| 2577 | + source_cell_key |
| 2578 | + for record in records |
| 2579 | + for source_cell_key in record.source_cell_keys |
| 2580 | + ) |
| 2581 | + ) |
| 2582 | + notes = ( |
| 2583 | + "Microplex component sum matching PolicyEngine salt sources: " |
| 2584 | + f"{component_labels}." |
| 2585 | + ) |
| 2586 | + return replace( |
| 2587 | + first, |
| 2588 | + target_id=-int(digest[:12], 16), |
| 2589 | + stratum_id=-int(digest[12:20], 16), |
| 2590 | + variable=str(key[0]), |
| 2591 | + value=sum(record.value for record in records), |
| 2592 | + source_table=( |
| 2593 | + source_tables[0] |
| 2594 | + if len(source_tables) == 1 |
| 2595 | + else "Microplex component sum from Arch source tables" |
| 2596 | + ), |
| 2597 | + source_url=source_urls[0] if len(source_urls) == 1 else None, |
| 2598 | + notes=f"{first.notes} {notes}" if first.notes else notes, |
| 2599 | + source_record_id=f"microplex_component_sum:{digest[:16]}", |
| 2600 | + source_cell_keys=source_cell_keys, |
| 2601 | + source_row_keys=source_row_keys, |
| 2602 | + aggregate_fact_key=None, |
| 2603 | + semantic_fact_key=None, |
| 2604 | + source_target_id=None, |
| 2605 | + source_stratum_id=None, |
| 2606 | + concept=None, |
| 2607 | + source_concept=None, |
| 2608 | + concept_relation="sum_of_components", |
| 2609 | + concept_authority="policyengine_us", |
| 2610 | + concept_evidence_notes=notes, |
| 2611 | + ) |
2469 | 2612 |
|
2470 | 2613 |
|
2471 | 2614 | def _state_to_national_rollup_records( |
|
0 commit comments