Skip to content
This repository was archived by the owner on Jun 14, 2026. It is now read-only.

Commit 1f014d8

Browse files
authored
Add Supabase target provider (#6)
1 parent 40f3e38 commit 1f014d8

2 files changed

Lines changed: 462 additions & 10 deletions

File tree

src/microplex_us/supabase_targets.py

Lines changed: 313 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,92 @@
66
from typing import Any
77

88
import requests
9+
from microplex.core import EntityType
10+
from microplex.targets import (
11+
FilterOperator,
12+
TargetAggregation,
13+
TargetFilter,
14+
TargetQuery,
15+
TargetSet,
16+
TargetSpec,
17+
apply_target_query,
18+
)
19+
20+
from microplex_us.target_registry import (
21+
US_TARGET_AVAILABLE_KEY,
22+
US_TARGET_CATEGORY_KEY,
23+
US_TARGET_GROUP_KEY,
24+
US_TARGET_IMPUTATION_KEY,
25+
US_TARGET_LEVEL_KEY,
26+
TargetCategory,
27+
TargetLevel,
28+
)
29+
30+
SUPABASE_TARGET_ID_KEY = "supabase_target_id"
31+
SUPABASE_VARIABLE_KEY = "supabase_variable"
32+
SUPABASE_TARGET_TYPE_KEY = "supabase_target_type"
33+
SUPABASE_JURISDICTION_KEY = "supabase_jurisdiction"
34+
SUPABASE_STRATUM_NAME_KEY = "supabase_stratum_name"
35+
SUPABASE_SOURCE_INSTITUTION_KEY = "supabase_source_institution"
36+
SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY = "supabase_supported_by_column_map"
37+
38+
_COUNT_ALL_VARIABLES = {
39+
"family_count",
40+
"household_count",
41+
"person_count",
42+
"spm_unit_count",
43+
"tax_unit_count",
44+
}
45+
46+
_COUNT_ENTITY_MAP = {
47+
"family_count": EntityType.FAMILY,
48+
"household_count": EntityType.HOUSEHOLD,
49+
"person_count": EntityType.PERSON,
50+
"spm_unit_count": EntityType.SPM_UNIT,
51+
"tax_unit_count": EntityType.TAX_UNIT,
52+
}
53+
54+
_INCOME_VARIABLES = {
55+
"alimony_income",
56+
"dividend_income",
57+
"employment_income",
58+
"farm_income",
59+
"interest_income",
60+
"long_term_capital_gains",
61+
"partnership_s_corp_income",
62+
"rental_income",
63+
"self_employment_income",
64+
"short_term_capital_gains",
65+
"social_security",
66+
"tax_exempt_pension_income",
67+
"taxable_pension_income",
68+
"unemployment_compensation",
69+
}
70+
71+
_BENEFIT_VARIABLES = {
72+
"eitc_spending",
73+
"snap_households",
74+
"snap_spending",
75+
"social_security_spending",
76+
"ssi_spending",
77+
"unemployment_spending",
78+
}
79+
80+
_HEALTH_VARIABLES = {
81+
"aca_enrollment",
82+
"health_insurance_premiums",
83+
"medicaid_enrollment",
84+
"other_medical_expenses",
85+
}
86+
87+
_TAX_UNIT_VARIABLES = {
88+
"eitc_spending",
89+
}
90+
91+
_HOUSEHOLD_VARIABLES = {
92+
"snap_households",
93+
"snap_spending",
94+
}
995

1096

1197
class SupabaseTargetLoader:
@@ -217,13 +303,11 @@ def _parse_jurisdiction(self, jurisdiction: str) -> str | None:
217303
return None
218304

219305
if jurisdiction.startswith("us-") and len(jurisdiction) == 5:
220-
state = jurisdiction[3:].lower()
221-
if len(state) == 2:
222-
return state
223-
224-
if jurisdiction.startswith("us-") and len(jurisdiction) == 5:
225-
fips = jurisdiction[3:]
226-
return self.STATE_FIPS.get(fips)
306+
suffix = jurisdiction[3:].lower()
307+
if suffix in self.STATE_FIPS:
308+
return self.STATE_FIPS[suffix]
309+
if suffix in _state_abbr_to_fips(self.STATE_FIPS):
310+
return suffix
227311

228312
return None
229313

@@ -286,4 +370,225 @@ def get_summary(self) -> dict[str, Any]:
286370
}
287371

288372

289-
__all__ = ["SupabaseTargetLoader"]
373+
class SupabaseTargetProvider(SupabaseTargetLoader):
374+
"""Load Supabase targets as canonical core target specs."""
375+
376+
def load_target_set(self, query: TargetQuery | None = None) -> TargetSet:
377+
"""Load a canonical target set through the core provider protocol."""
378+
query = query or TargetQuery()
379+
provider_filters = query.provider_filters
380+
period = _query_period(query.period)
381+
institution = provider_filters.get("institution")
382+
target_types = _as_string_set(provider_filters.get("target_types"))
383+
include_unsupported = bool(provider_filters.get("include_unsupported", True))
384+
include_states = bool(provider_filters.get("include_states", True))
385+
386+
if institution:
387+
rows = self.load_by_institution(str(institution), period=period)
388+
else:
389+
rows = self.load_all(period=period)
390+
391+
specs: list[TargetSpec] = []
392+
for row in rows:
393+
target_type = _target_type(row)
394+
if target_types and target_type not in target_types:
395+
continue
396+
397+
spec = self.target_from_row(row)
398+
if (
399+
not include_states
400+
and spec.metadata.get(US_TARGET_LEVEL_KEY) == TargetLevel.STATE.value
401+
):
402+
continue
403+
if (
404+
not include_unsupported
405+
and not spec.metadata[SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY]
406+
):
407+
continue
408+
specs.append(spec)
409+
410+
return apply_target_query(
411+
TargetSet(specs),
412+
TargetQuery(
413+
period=period if period is not None else query.period,
414+
entity=query.entity,
415+
names=query.names,
416+
metadata_filters=query.metadata_filters,
417+
),
418+
)
419+
420+
def target_from_row(self, row: dict[str, Any]) -> TargetSpec:
421+
"""Translate one Supabase target row into the canonical target IR."""
422+
variable = str(row["variable"])
423+
jurisdiction = _target_jurisdiction(row)
424+
state_fips, state_abbr = _jurisdiction_state(jurisdiction, self.STATE_FIPS)
425+
target_type = _target_type(row)
426+
aggregation = _aggregation_for_target_type(target_type)
427+
measure = self.CPS_COLUMN_MAP.get(variable, variable)
428+
supported = variable in self.CPS_COLUMN_MAP
429+
source = row.get("source") if isinstance(row.get("source"), dict) else {}
430+
source_name = source.get("name") or source.get("institution")
431+
source_institution = source.get("institution")
432+
stratum = row.get("stratum") if isinstance(row.get("stratum"), dict) else {}
433+
category = _category_for_variable(variable)
434+
level = TargetLevel.STATE if state_fips is not None else TargetLevel.NATIONAL
435+
436+
filters: list[TargetFilter] = []
437+
if aggregation is TargetAggregation.COUNT and variable not in _COUNT_ALL_VARIABLES:
438+
filters.append(
439+
TargetFilter(
440+
feature=measure,
441+
operator=FilterOperator.GT,
442+
value=0,
443+
)
444+
)
445+
446+
if state_fips is not None:
447+
filters.append(
448+
TargetFilter(
449+
feature="state_fips",
450+
operator=FilterOperator.EQ,
451+
value=state_fips,
452+
)
453+
)
454+
455+
metadata: dict[str, Any] = {
456+
SUPABASE_TARGET_ID_KEY: row.get("id"),
457+
SUPABASE_VARIABLE_KEY: variable,
458+
SUPABASE_TARGET_TYPE_KEY: target_type,
459+
SUPABASE_JURISDICTION_KEY: jurisdiction,
460+
SUPABASE_STRATUM_NAME_KEY: stratum.get("name"),
461+
SUPABASE_SOURCE_INSTITUTION_KEY: source_institution,
462+
SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY: supported,
463+
US_TARGET_LEVEL_KEY: level.value,
464+
US_TARGET_GROUP_KEY: _group_for_category(category),
465+
US_TARGET_AVAILABLE_KEY: supported,
466+
US_TARGET_IMPUTATION_KEY: not supported,
467+
}
468+
if category is not None:
469+
metadata[US_TARGET_CATEGORY_KEY] = category.value
470+
if state_fips is not None:
471+
metadata["state_fips"] = state_fips
472+
metadata["state_abbr"] = state_abbr
473+
474+
return TargetSpec(
475+
name=_target_name(variable, jurisdiction),
476+
entity=_entity_for_variable(variable),
477+
value=float(row["value"]),
478+
period=int(row["period"]),
479+
measure=None if aggregation is TargetAggregation.COUNT else measure,
480+
aggregation=aggregation,
481+
filters=tuple(filters),
482+
source=source_name,
483+
units=_units_for_target_type(target_type),
484+
description=row.get("notes"),
485+
metadata=metadata,
486+
)
487+
488+
489+
def _target_type(row: dict[str, Any]) -> str:
490+
return str(row.get("target_type") or "amount").lower()
491+
492+
493+
def _aggregation_for_target_type(target_type: str) -> TargetAggregation:
494+
if target_type == "count":
495+
return TargetAggregation.COUNT
496+
if target_type == "mean":
497+
return TargetAggregation.MEAN
498+
return TargetAggregation.SUM
499+
500+
501+
def _target_jurisdiction(row: dict[str, Any]) -> str:
502+
stratum = row.get("stratum") if isinstance(row.get("stratum"), dict) else {}
503+
return str(stratum.get("jurisdiction") or "us")
504+
505+
506+
def _target_name(variable: str, jurisdiction: str) -> str:
507+
if jurisdiction in {"us", "us-national"}:
508+
return variable
509+
return f"{variable}_{jurisdiction.replace('-', '_')}"
510+
511+
512+
def _query_period(period: int | str | None) -> int | None:
513+
if isinstance(period, int):
514+
return period
515+
if isinstance(period, str) and period.isdigit():
516+
return int(period)
517+
return None
518+
519+
520+
def _as_string_set(value: Any) -> set[str]:
521+
if value is None:
522+
return set()
523+
if isinstance(value, str):
524+
return {value}
525+
return {str(item) for item in value}
526+
527+
528+
def _state_abbr_to_fips(state_fips: dict[str, str]) -> dict[str, str]:
529+
return {abbr: fips for fips, abbr in state_fips.items()}
530+
531+
532+
def _jurisdiction_state(
533+
jurisdiction: str,
534+
state_fips: dict[str, str],
535+
) -> tuple[str | None, str | None]:
536+
if not jurisdiction.startswith("us-") or len(jurisdiction) != 5:
537+
return None, None
538+
539+
suffix = jurisdiction[3:].lower()
540+
if suffix in state_fips:
541+
return suffix, state_fips[suffix]
542+
543+
abbr_to_fips = _state_abbr_to_fips(state_fips)
544+
if suffix in abbr_to_fips:
545+
return abbr_to_fips[suffix], suffix
546+
547+
return None, None
548+
549+
550+
def _category_for_variable(variable: str) -> TargetCategory | None:
551+
if variable in _INCOME_VARIABLES:
552+
return TargetCategory.INCOME
553+
if variable in _BENEFIT_VARIABLES:
554+
return TargetCategory.BENEFITS
555+
if variable in _HEALTH_VARIABLES:
556+
return TargetCategory.HEALTH
557+
if variable.endswith("_tax") or variable.endswith("_credit"):
558+
return TargetCategory.TAX
559+
if variable in _COUNT_ALL_VARIABLES:
560+
return TargetCategory.DEMOGRAPHICS
561+
return None
562+
563+
564+
def _entity_for_variable(variable: str) -> EntityType:
565+
if variable in _COUNT_ENTITY_MAP:
566+
return _COUNT_ENTITY_MAP[variable]
567+
if variable in _TAX_UNIT_VARIABLES:
568+
return EntityType.TAX_UNIT
569+
if variable in _HOUSEHOLD_VARIABLES:
570+
return EntityType.HOUSEHOLD
571+
return EntityType.PERSON
572+
573+
574+
def _group_for_category(category: TargetCategory | None) -> str:
575+
if category is None:
576+
return "supabase_targets"
577+
return f"supabase_{category.value}"
578+
579+
580+
def _units_for_target_type(target_type: str) -> str | None:
581+
return "USD" if target_type == "amount" else None
582+
583+
584+
__all__ = [
585+
"SUPABASE_JURISDICTION_KEY",
586+
"SUPABASE_SOURCE_INSTITUTION_KEY",
587+
"SUPABASE_STRATUM_NAME_KEY",
588+
"SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY",
589+
"SUPABASE_TARGET_ID_KEY",
590+
"SUPABASE_TARGET_TYPE_KEY",
591+
"SUPABASE_VARIABLE_KEY",
592+
"SupabaseTargetLoader",
593+
"SupabaseTargetProvider",
594+
]

0 commit comments

Comments
 (0)