Skip to content
This repository was archived by the owner on Jun 14, 2026. It is now read-only.

Commit 40f3e38

Browse files
authored
Add US Supabase target loader (#5)
1 parent ee65bd6 commit 40f3e38

5 files changed

Lines changed: 545 additions & 1 deletion

File tree

.github/workflows/site-snapshot.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ jobs:
4141
run: |
4242
uv run --extra dev --with pydantic --with-editable ../microplex pytest -q \
4343
tests/test_package_imports.py \
44+
tests/targets/test_supabase.py \
4445
tests/pipelines/test_check_site_snapshot.py \
4546
tests/pipelines/test_imputation_ablation.py \
4647
tests/pipelines/test_site_snapshot.py \

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ requires-python = ">=3.13"
1515
dependencies = [
1616
"microplex[calibrate]",
1717
"duckdb>=1.2",
18+
"requests>=2.31",
1819
]
1920

2021
[project.optional-dependencies]
Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
"""US Supabase calibration target loader."""
2+
3+
from __future__ import annotations
4+
5+
import os
6+
from typing import Any
7+
8+
import requests
9+
10+
11+
class SupabaseTargetLoader:
    """Load US calibration targets from the microplex Supabase schema.

    The loader talks to the Supabase REST endpoint (``<url>/rest/v1``) with
    ``requests`` and returns the decoded JSON rows as plain dictionaries.
    On top of the raw loaders it can translate target rows into a flat
    ``{cps_column: value}`` constraint dictionary.
    """

    # Mapping from Supabase variable names to CPS column names.
    # NOTE(review): several variables share one CPS column (for example
    # "snap_spending" and "snap_households" both map to "snap"); callers that
    # need to tell them apart should filter on ``target_types``.
    CPS_COLUMN_MAP = {
        "employment_income": "employment_income",
        "self_employment_income": "self_employment_income",
        "dividend_income": "dividend_income",
        "interest_income": "interest_income",
        "rental_income": "rental_income",
        "social_security": "social_security",
        "unemployment_compensation": "unemployment_compensation",
        "taxable_pension_income": "taxable_pension_income",
        "tax_exempt_pension_income": "tax_exempt_pension_income",
        "long_term_capital_gains": "long_term_capital_gains",
        "short_term_capital_gains": "short_term_capital_gains",
        "partnership_s_corp_income": "partnership_s_corp_income",
        "farm_income": "farm_income",
        "alimony_income": "alimony_income",
        "snap_spending": "snap",
        "ssi_spending": "ssi",
        "eitc_spending": "eitc",
        "social_security_spending": "social_security",
        "unemployment_spending": "unemployment_compensation",
        "medicaid_enrollment": "medicaid",
        "aca_enrollment": "aca",
        "snap_households": "snap",
        "health_insurance_premiums": "health_insurance_premiums",
        "other_medical_expenses": "medical_expenses",
    }

    # Two-digit state FIPS code -> lowercase two-letter state abbreviation.
    STATE_FIPS = {
        "01": "al",
        "02": "ak",
        "04": "az",
        "05": "ar",
        "06": "ca",
        "08": "co",
        "09": "ct",
        "10": "de",
        "11": "dc",
        "12": "fl",
        "13": "ga",
        "15": "hi",
        "16": "id",
        "17": "il",
        "18": "in",
        "19": "ia",
        "20": "ks",
        "21": "ky",
        "22": "la",
        "23": "me",
        "24": "md",
        "25": "ma",
        "26": "mi",
        "27": "mn",
        "28": "ms",
        "29": "mo",
        "30": "mt",
        "31": "ne",
        "32": "nv",
        "33": "nh",
        "34": "nj",
        "35": "nm",
        "36": "ny",
        "37": "nc",
        "38": "nd",
        "39": "oh",
        "40": "ok",
        "41": "or",
        "42": "pa",
        "44": "ri",
        "45": "sc",
        "46": "sd",
        "47": "tn",
        "48": "tx",
        "49": "ut",
        "50": "vt",
        "51": "va",
        "53": "wa",
        "54": "wv",
        "55": "wi",
        "56": "wy",
    }

    # Column selection shared by every query against the "targets" endpoint;
    # embeds the related source and stratum rows under "source" / "stratum".
    _TARGET_SELECT = (
        "id,variable,value,target_type,period,notes,"
        "source:sources(id,name,institution),"
        "stratum:strata(id,name,jurisdiction)"
    )
    # Rows requested per page when paginating.
    _PAGE_SIZE = 1000
    # Per-request timeout, in seconds, handed to ``requests.get``.
    _TIMEOUT_SECONDS = 30

    def __init__(
        self,
        url: str | None = None,
        key: str | None = None,
        schema: str = "microplex",
    ) -> None:
        """Initialize the loader.

        Args:
            url: Supabase URL. Defaults to SUPABASE_URL env var, then to the
                built-in project URL.
            key: Supabase key. Defaults to COSILICO_SUPABASE_SERVICE_KEY env var.
            schema: Schema to use. Defaults to 'microplex'.

        Raises:
            ValueError: If no key is given and the env var is unset or empty.
        """
        self.url = url or os.environ.get(
            "SUPABASE_URL",
            "https://nsupqhfchdtqclomlrgs.supabase.co",
        )
        self.key = key or os.environ.get("COSILICO_SUPABASE_SERVICE_KEY")
        if not self.key:
            raise ValueError(
                "Supabase service key must be provided via the key argument or "
                "COSILICO_SUPABASE_SERVICE_KEY."
            )
        # Strip a trailing slash so the REST path never contains "//".
        self.base_url = f"{self.url.rstrip('/')}/rest/v1"
        self.headers = {
            "apikey": self.key,
            "Authorization": f"Bearer {self.key}",
            "Content-Type": "application/json",
            # The profile headers carry the schema name to the REST API.
            "Accept-Profile": schema,
            "Content-Profile": schema,
        }
        # Not read anywhere in this class; kept for callers that may use it.
        self._cache: dict[str, Any] = {}

    def _request_json(self, url: str, params: dict[str, Any]) -> Any:
        """Issue one GET request and return the decoded JSON body.

        Raises:
            requests.HTTPError: If the response status indicates an error.
        """
        response = requests.get(
            url,
            headers=self.headers,
            params=params,
            timeout=self._TIMEOUT_SECONDS,
        )
        response.raise_for_status()
        return response.json()

    def _get(
        self,
        endpoint: str,
        params: dict[str, Any] | None = None,
        paginate: bool = True,
    ) -> list[dict[str, Any]]:
        """Make a GET request to Supabase with optional pagination.

        Args:
            endpoint: Path below ``base_url``, e.g. ``"targets"``.
            params: Query parameters. The mapping is never mutated.
            paginate: When true, fetch pages of ``_PAGE_SIZE`` rows using
                ``limit``/``offset`` until a short (or empty) page arrives.

        Returns:
            The decoded JSON rows, concatenated across pages when paginating.

        Raises:
            requests.HTTPError: If any response status indicates an error.
        """
        url = f"{self.base_url}/{endpoint}"
        params = params or {}

        if not paginate:
            return self._request_json(url, params)

        # NOTE(review): no explicit "order" is sent, so page boundaries rely
        # on the server returning rows in a stable order -- confirm, or have
        # callers pass an "order" parameter.
        all_results: list[dict[str, Any]] = []
        offset = 0
        while True:
            page = self._request_json(
                url,
                {**params, "limit": self._PAGE_SIZE, "offset": offset},
            )
            all_results.extend(page)
            # A short page (including an empty one) is the last page.
            if len(page) < self._PAGE_SIZE:
                break
            offset += self._PAGE_SIZE

        return all_results

    def load_all(self, period: int | None = None) -> list[dict[str, Any]]:
        """Load all targets with source and stratum info.

        Args:
            period: When given, only targets whose ``period`` equals it.

        Returns:
            Target rows with embedded ``source`` and ``stratum`` objects.
        """
        params = {"select": self._TARGET_SELECT}
        if period is not None:
            params["period"] = f"eq.{period}"

        return self._get("targets", params)

    def load_by_institution(
        self,
        institution: str,
        period: int | None = None,
    ) -> list[dict[str, Any]]:
        """Load targets from a specific source institution.

        Args:
            institution: Exact ``institution`` value of the wanted sources.
            period: When given, only targets whose ``period`` equals it.

        Returns:
            Matching target rows, or an empty list when no source matches.
        """
        sources = self._get("sources", {"institution": f"eq.{institution}"})
        # Ids may be integers or strings; stringify so the join cannot fail.
        source_ids = [str(source["id"]) for source in sources]

        if not source_ids:
            return []

        params = {
            "select": self._TARGET_SELECT,
            "source_id": f"in.({','.join(source_ids)})",
        }
        if period is not None:
            params["period"] = f"eq.{period}"

        return self._get("targets", params)

    def load_by_period(self, period: int) -> list[dict[str, Any]]:
        """Load targets for a specific year."""
        return self.load_all(period=period)

    def get_cps_column_map(self) -> dict[str, str]:
        """Get a copy of the mapping from Supabase variable names to CPS columns."""
        return self.CPS_COLUMN_MAP.copy()

    def _parse_jurisdiction(self, jurisdiction: str) -> str | None:
        """Parse jurisdiction to get the state code when applicable.

        Recognised sub-national forms are ``us-XX`` where ``XX`` is either a
        two-letter abbreviation (``us-ca`` or ``us-CA`` -> ``"ca"``) or a
        two-digit FIPS code (``us-06`` -> ``"ca"``).

        Args:
            jurisdiction: Jurisdiction string taken from a stratum row.

        Returns:
            The lowercase state code, or ``None`` for national jurisdictions
            (``"us"``, ``"us-national"``), non-string input, and any value
            that is not of the ``us-XX`` form. A numeric suffix missing from
            ``STATE_FIPS`` is returned unchanged so the jurisdiction is still
            treated as sub-national rather than national.
        """
        if not isinstance(jurisdiction, str):
            return None

        if jurisdiction in {"us", "us-national"}:
            return None

        if not (jurisdiction.startswith("us-") and len(jurisdiction) == 5):
            return None

        suffix = jurisdiction[3:]
        if suffix.isdigit():
            # FIPS form: translate to the abbreviation when the code is known.
            return self.STATE_FIPS.get(suffix, suffix)
        return suffix.lower()

    def build_calibration_constraints(
        self,
        period: int = 2024,
        include_states: bool = False,
        target_types: list[str] | None = None,
    ) -> dict[str, float]:
        """Build a CPS-column calibration constraint dict from Supabase targets.

        Args:
            period: Year whose targets are loaded.
            include_states: When true, state-level targets are emitted under
                ``"<cps_column>_<state>"``; otherwise they are dropped.
            target_types: When non-empty, only targets whose ``target_type``
                is in this list are used.

        Returns:
            Mapping from constraint name to target value. Variables missing
            from ``CPS_COLUMN_MAP`` are skipped. For national constraints the
            first target seen for a CPS column wins; for state constraints
            the last one wins.
        """
        targets = self.load_all(period=period)
        constraints: dict[str, float] = {}

        for target in targets:
            variable = target["variable"]
            value = target["value"]
            target_type = target.get("target_type", "amount")
            # The embedded stratum may be absent or an explicit null.
            stratum = target.get("stratum") or {}
            jurisdiction = stratum.get("jurisdiction") or "us"

            if target_types and target_type not in target_types:
                continue

            cps_col = self.CPS_COLUMN_MAP.get(variable)
            if not cps_col:
                continue

            state = self._parse_jurisdiction(jurisdiction)

            if state and include_states:
                constraints[f"{cps_col}_{state}"] = value
            elif not state and cps_col not in constraints:
                constraints[cps_col] = value

        return constraints

    def get_summary(self) -> dict[str, Any]:
        """Get summary counts for available targets in Supabase.

        Returns:
            A dict with ``"total"`` (number of targets) plus per-key count
            dicts ``"by_institution"``, ``"by_variable"`` and ``"by_type"``.
            Targets without source information count under ``"Unknown"``.
        """
        targets = self.load_all()

        by_institution: dict[str, int] = {}
        by_variable: dict[str, int] = {}
        by_type: dict[str, int] = {}

        for target in targets:
            # The embedded source may be absent or an explicit null.
            source = target.get("source") or {}
            institution = source.get("institution", "Unknown")
            by_institution[institution] = by_institution.get(institution, 0) + 1

            variable = target["variable"]
            by_variable[variable] = by_variable.get(variable, 0) + 1

            target_type = target.get("target_type", "amount")
            by_type[target_type] = by_type.get(target_type, 0) + 1

        return {
            "total": len(targets),
            "by_institution": by_institution,
            "by_variable": by_variable,
            "by_type": by_type,
        }
287+
288+
289+
# Names exported by `from <module> import *`; the loader class is the only public name.
__all__ = ["SupabaseTargetLoader"]

0 commit comments

Comments
 (0)