|
6 | 6 | from typing import Any |
7 | 7 |
|
8 | 8 | import requests |
| 9 | +from microplex.core import EntityType |
| 10 | +from microplex.targets import ( |
| 11 | + FilterOperator, |
| 12 | + TargetAggregation, |
| 13 | + TargetFilter, |
| 14 | + TargetQuery, |
| 15 | + TargetSet, |
| 16 | + TargetSpec, |
| 17 | + apply_target_query, |
| 18 | +) |
| 19 | + |
| 20 | +from microplex_us.target_registry import ( |
| 21 | + US_TARGET_AVAILABLE_KEY, |
| 22 | + US_TARGET_CATEGORY_KEY, |
| 23 | + US_TARGET_GROUP_KEY, |
| 24 | + US_TARGET_IMPUTATION_KEY, |
| 25 | + US_TARGET_LEVEL_KEY, |
| 26 | + TargetCategory, |
| 27 | + TargetLevel, |
| 28 | +) |
| 29 | + |
# Metadata keys under which the provider records Supabase-specific details
# (row id, raw variable name, target type, jurisdiction, stratum name, source
# institution, and whether the variable has a column mapping) on each target.
SUPABASE_TARGET_ID_KEY = "supabase_target_id"
SUPABASE_VARIABLE_KEY = "supabase_variable"
SUPABASE_TARGET_TYPE_KEY = "supabase_target_type"
SUPABASE_JURISDICTION_KEY = "supabase_jurisdiction"
SUPABASE_STRATUM_NAME_KEY = "supabase_stratum_name"
SUPABASE_SOURCE_INSTITUTION_KEY = "supabase_source_institution"
SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY = "supabase_supported_by_column_map"
| 37 | + |
# Count variables that tally every record of an entity: count targets on these
# variables get no "measure > 0" filter when a row is translated.
_COUNT_ALL_VARIABLES = {
    "family_count",
    "household_count",
    "person_count",
    "spm_unit_count",
    "tax_unit_count",
}

# Entity each of the unconditional count variables is measured on.
_COUNT_ENTITY_MAP = {
    "family_count": EntityType.FAMILY,
    "household_count": EntityType.HOUSEHOLD,
    "person_count": EntityType.PERSON,
    "spm_unit_count": EntityType.SPM_UNIT,
    "tax_unit_count": EntityType.TAX_UNIT,
}
| 53 | + |
| 54 | +_INCOME_VARIABLES = { |
| 55 | + "alimony_income", |
| 56 | + "dividend_income", |
| 57 | + "employment_income", |
| 58 | + "farm_income", |
| 59 | + "interest_income", |
| 60 | + "long_term_capital_gains", |
| 61 | + "partnership_s_corp_income", |
| 62 | + "rental_income", |
| 63 | + "self_employment_income", |
| 64 | + "short_term_capital_gains", |
| 65 | + "social_security", |
| 66 | + "tax_exempt_pension_income", |
| 67 | + "taxable_pension_income", |
| 68 | + "unemployment_compensation", |
| 69 | +} |
| 70 | + |
| 71 | +_BENEFIT_VARIABLES = { |
| 72 | + "eitc_spending", |
| 73 | + "snap_households", |
| 74 | + "snap_spending", |
| 75 | + "social_security_spending", |
| 76 | + "ssi_spending", |
| 77 | + "unemployment_spending", |
| 78 | +} |
| 79 | + |
| 80 | +_HEALTH_VARIABLES = { |
| 81 | + "aca_enrollment", |
| 82 | + "health_insurance_premiums", |
| 83 | + "medicaid_enrollment", |
| 84 | + "other_medical_expenses", |
| 85 | +} |
| 86 | + |
| 87 | +_TAX_UNIT_VARIABLES = { |
| 88 | + "eitc_spending", |
| 89 | +} |
| 90 | + |
| 91 | +_HOUSEHOLD_VARIABLES = { |
| 92 | + "snap_households", |
| 93 | + "snap_spending", |
| 94 | +} |
9 | 95 |
|
10 | 96 |
|
11 | 97 | class SupabaseTargetLoader: |
@@ -217,13 +303,11 @@ def _parse_jurisdiction(self, jurisdiction: str) -> str | None: |
217 | 303 | return None |
218 | 304 |
|
219 | 305 | if jurisdiction.startswith("us-") and len(jurisdiction) == 5: |
220 | | - state = jurisdiction[3:].lower() |
221 | | - if len(state) == 2: |
222 | | - return state |
223 | | - |
224 | | - if jurisdiction.startswith("us-") and len(jurisdiction) == 5: |
225 | | - fips = jurisdiction[3:] |
226 | | - return self.STATE_FIPS.get(fips) |
| 306 | + suffix = jurisdiction[3:].lower() |
| 307 | + if suffix in self.STATE_FIPS: |
| 308 | + return self.STATE_FIPS[suffix] |
| 309 | + if suffix in _state_abbr_to_fips(self.STATE_FIPS): |
| 310 | + return suffix |
227 | 311 |
|
228 | 312 | return None |
229 | 313 |
|
@@ -286,4 +370,225 @@ def get_summary(self) -> dict[str, Any]: |
286 | 370 | } |
287 | 371 |
|
288 | 372 |
|
289 | | -__all__ = ["SupabaseTargetLoader"] |
class SupabaseTargetProvider(SupabaseTargetLoader):
    """Load Supabase targets as canonical core target specs."""

    def load_target_set(self, query: TargetQuery | None = None) -> TargetSet:
        """Load a canonical target set through the core provider protocol.

        Rows are fetched with ``load_by_institution`` or ``load_all``,
        translated with :meth:`target_from_row`, filtered by the provider
        filters below, and finally passed through ``apply_target_query`` with
        the query's period, entity, names and metadata filters.

        Recognised ``query.provider_filters`` keys:

        * ``institution`` -- when truthy, only load that institution's rows.
        * ``target_types`` -- a string or iterable of target types to keep,
          matched case-insensitively; empty/missing keeps every type.
        * ``include_unsupported`` -- default ``True``; when false, drop
          targets whose variable is not in ``CPS_COLUMN_MAP``.
        * ``include_states`` -- default ``True``; when false, drop
          state-level targets.

        Args:
            query: Target query; ``None`` is treated as an empty query.

        Returns:
            The filtered target set.
        """
        query = query or TargetQuery()
        # Tolerate a query whose provider filters are unset.
        provider_filters = query.provider_filters or {}
        period = _query_period(query.period)
        institution = provider_filters.get("institution")
        # Row target types are lower-cased by _target_type(); normalise the
        # requested types the same way so "Count" does not silently match
        # nothing.
        target_types = {
            requested.lower()
            for requested in _as_string_set(provider_filters.get("target_types"))
        }
        include_unsupported = bool(provider_filters.get("include_unsupported", True))
        include_states = bool(provider_filters.get("include_states", True))

        if institution:
            rows = self.load_by_institution(str(institution), period=period)
        else:
            rows = self.load_all(period=period)

        specs: list[TargetSpec] = []
        for row in rows:
            # Cheap type check first so skipped rows are never translated.
            if target_types and _target_type(row) not in target_types:
                continue

            spec = self.target_from_row(row)
            if (
                not include_states
                and spec.metadata.get(US_TARGET_LEVEL_KEY) == TargetLevel.STATE.value
            ):
                continue
            if (
                not include_unsupported
                and not spec.metadata[SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY]
            ):
                continue
            specs.append(spec)

        # Provider filters were applied above, so only the generic query
        # fields are forwarded. Fall back to the raw query period when it
        # could not be parsed as an integer year.
        return apply_target_query(
            TargetSet(specs),
            TargetQuery(
                period=period if period is not None else query.period,
                entity=query.entity,
                names=query.names,
                metadata_filters=query.metadata_filters,
            ),
        )

    def target_from_row(self, row: dict[str, Any]) -> TargetSpec:
        """Translate one Supabase target row into the canonical target IR.

        Args:
            row: Supabase row. ``variable``, ``value`` and ``period`` are
                required; ``id``, ``target_type``, ``notes``, ``source``
                (dict with ``name``/``institution``) and ``stratum`` (dict
                with ``name``/``jurisdiction``) are optional.

        Returns:
            The equivalent ``TargetSpec``, with Supabase provenance and US
            registry keys in its metadata.

        Raises:
            KeyError: If ``variable``, ``value`` or ``period`` is missing.
            ValueError: If ``value`` or ``period`` cannot be converted to a
                number.
        """
        variable = str(row["variable"])
        jurisdiction = _target_jurisdiction(row)
        state_fips, state_abbr = _jurisdiction_state(jurisdiction, self.STATE_FIPS)
        target_type = _target_type(row)
        aggregation = _aggregation_for_target_type(target_type)
        # Unmapped variables keep their Supabase name as the measure and are
        # flagged as unsupported in the metadata.
        measure = self.CPS_COLUMN_MAP.get(variable, variable)
        supported = variable in self.CPS_COLUMN_MAP

        raw_source = row.get("source")
        source = raw_source if isinstance(raw_source, dict) else {}
        source_name = source.get("name") or source.get("institution")
        source_institution = source.get("institution")
        raw_stratum = row.get("stratum")
        stratum = raw_stratum if isinstance(raw_stratum, dict) else {}

        category = _category_for_variable(variable)
        level = TargetLevel.STATE if state_fips is not None else TargetLevel.NATIONAL

        filters: list[TargetFilter] = []
        # A count of anything other than "all records of an entity" counts
        # records with a positive value of the measure.
        if aggregation is TargetAggregation.COUNT and variable not in _COUNT_ALL_VARIABLES:
            filters.append(
                TargetFilter(
                    feature=measure,
                    operator=FilterOperator.GT,
                    value=0,
                )
            )

        if state_fips is not None:
            filters.append(
                TargetFilter(
                    feature="state_fips",
                    operator=FilterOperator.EQ,
                    value=state_fips,
                )
            )

        metadata: dict[str, Any] = {
            SUPABASE_TARGET_ID_KEY: row.get("id"),
            SUPABASE_VARIABLE_KEY: variable,
            SUPABASE_TARGET_TYPE_KEY: target_type,
            SUPABASE_JURISDICTION_KEY: jurisdiction,
            SUPABASE_STRATUM_NAME_KEY: stratum.get("name"),
            SUPABASE_SOURCE_INSTITUTION_KEY: source_institution,
            SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY: supported,
            US_TARGET_LEVEL_KEY: level.value,
            US_TARGET_GROUP_KEY: _group_for_category(category),
            US_TARGET_AVAILABLE_KEY: supported,
            US_TARGET_IMPUTATION_KEY: not supported,
        }
        if category is not None:
            metadata[US_TARGET_CATEGORY_KEY] = category.value
        if state_fips is not None:
            metadata["state_fips"] = state_fips
            metadata["state_abbr"] = state_abbr

        return TargetSpec(
            name=_target_name(variable, jurisdiction),
            entity=_entity_for_variable(variable),
            value=float(row["value"]),
            period=int(row["period"]),
            # Count targets carry their condition in the filters, not a measure.
            measure=None if aggregation is TargetAggregation.COUNT else measure,
            aggregation=aggregation,
            filters=tuple(filters),
            source=source_name,
            units=_units_for_target_type(target_type),
            description=row.get("notes"),
            metadata=metadata,
        )
| 487 | + |
| 488 | + |
| 489 | +def _target_type(row: dict[str, Any]) -> str: |
| 490 | + return str(row.get("target_type") or "amount").lower() |
| 491 | + |
| 492 | + |
def _aggregation_for_target_type(target_type: str) -> TargetAggregation:
    """Map a target type onto an aggregation; anything unrecognised sums."""
    if target_type == "count":
        chosen = TargetAggregation.COUNT
    elif target_type == "mean":
        chosen = TargetAggregation.MEAN
    else:
        chosen = TargetAggregation.SUM
    return chosen
| 499 | + |
| 500 | + |
| 501 | +def _target_jurisdiction(row: dict[str, Any]) -> str: |
| 502 | + stratum = row.get("stratum") if isinstance(row.get("stratum"), dict) else {} |
| 503 | + return str(stratum.get("jurisdiction") or "us") |
| 504 | + |
| 505 | + |
| 506 | +def _target_name(variable: str, jurisdiction: str) -> str: |
| 507 | + if jurisdiction in {"us", "us-national"}: |
| 508 | + return variable |
| 509 | + return f"{variable}_{jurisdiction.replace('-', '_')}" |
| 510 | + |
| 511 | + |
| 512 | +def _query_period(period: int | str | None) -> int | None: |
| 513 | + if isinstance(period, int): |
| 514 | + return period |
| 515 | + if isinstance(period, str) and period.isdigit(): |
| 516 | + return int(period) |
| 517 | + return None |
| 518 | + |
| 519 | + |
| 520 | +def _as_string_set(value: Any) -> set[str]: |
| 521 | + if value is None: |
| 522 | + return set() |
| 523 | + if isinstance(value, str): |
| 524 | + return {value} |
| 525 | + return {str(item) for item in value} |
| 526 | + |
| 527 | + |
| 528 | +def _state_abbr_to_fips(state_fips: dict[str, str]) -> dict[str, str]: |
| 529 | + return {abbr: fips for fips, abbr in state_fips.items()} |
| 530 | + |
| 531 | + |
| 532 | +def _jurisdiction_state( |
| 533 | + jurisdiction: str, |
| 534 | + state_fips: dict[str, str], |
| 535 | +) -> tuple[str | None, str | None]: |
| 536 | + if not jurisdiction.startswith("us-") or len(jurisdiction) != 5: |
| 537 | + return None, None |
| 538 | + |
| 539 | + suffix = jurisdiction[3:].lower() |
| 540 | + if suffix in state_fips: |
| 541 | + return suffix, state_fips[suffix] |
| 542 | + |
| 543 | + abbr_to_fips = _state_abbr_to_fips(state_fips) |
| 544 | + if suffix in abbr_to_fips: |
| 545 | + return abbr_to_fips[suffix], suffix |
| 546 | + |
| 547 | + return None, None |
| 548 | + |
| 549 | + |
def _category_for_variable(variable: str) -> TargetCategory | None:
    """Classify a variable into a target category, or ``None`` if unknown.

    The explicit income, benefit and health lists are checked first, then the
    ``_tax`` / ``_credit`` name suffixes, then the unconditional count
    variables.
    """
    explicit_categories = (
        (_INCOME_VARIABLES, TargetCategory.INCOME),
        (_BENEFIT_VARIABLES, TargetCategory.BENEFITS),
        (_HEALTH_VARIABLES, TargetCategory.HEALTH),
    )
    for members, category in explicit_categories:
        if variable in members:
            return category

    if variable.endswith(("_tax", "_credit")):
        return TargetCategory.TAX

    return TargetCategory.DEMOGRAPHICS if variable in _COUNT_ALL_VARIABLES else None
| 562 | + |
| 563 | + |
def _entity_for_variable(variable: str) -> EntityType:
    """Return the entity a variable is measured on; person unless listed."""
    if variable in _COUNT_ENTITY_MAP:
        entity = _COUNT_ENTITY_MAP[variable]
    elif variable in _TAX_UNIT_VARIABLES:
        entity = EntityType.TAX_UNIT
    elif variable in _HOUSEHOLD_VARIABLES:
        entity = EntityType.HOUSEHOLD
    else:
        entity = EntityType.PERSON
    return entity
| 572 | + |
| 573 | + |
| 574 | +def _group_for_category(category: TargetCategory | None) -> str: |
| 575 | + if category is None: |
| 576 | + return "supabase_targets" |
| 577 | + return f"supabase_{category.value}" |
| 578 | + |
| 579 | + |
| 580 | +def _units_for_target_type(target_type: str) -> str | None: |
| 581 | + return "USD" if target_type == "amount" else None |
| 582 | + |
| 583 | + |
# Public API: the metadata key constants plus the loader and provider classes.
__all__ = [
    "SUPABASE_JURISDICTION_KEY",
    "SUPABASE_SOURCE_INSTITUTION_KEY",
    "SUPABASE_STRATUM_NAME_KEY",
    "SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY",
    "SUPABASE_TARGET_ID_KEY",
    "SUPABASE_TARGET_TYPE_KEY",
    "SUPABASE_VARIABLE_KEY",
    "SupabaseTargetLoader",
    "SupabaseTargetProvider",
]
0 commit comments