Skip to content

Commit bcd58e1

Browse files
Pass format
1 parent 6ccbcd8 commit bcd58e1

14 files changed

Lines changed: 161 additions & 97 deletions

File tree

scripts/create_tables.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,13 @@
1111
# Add src to path
1212
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
1313

14-
from sqlmodel import SQLModel, create_engine, text
1514
from rich.console import Console
15+
from sqlmodel import SQLModel, create_engine
1616

1717
from policyengine_api.config.settings import settings
1818
from policyengine_api.services.storage import get_service_role_client
1919

2020
# Import all models to register them with SQLModel.metadata
21-
from policyengine_api.models import (
22-
Dataset,
23-
DatasetVersion,
24-
Dynamic,
25-
Parameter,
26-
ParameterValue,
27-
Policy,
28-
Simulation,
29-
TaxBenefitModel,
30-
TaxBenefitModelVersion,
31-
Variable,
32-
AggregateOutput,
33-
ChangeAggregate,
34-
)
3521

3622
console = Console()
3723

@@ -126,4 +112,3 @@ def apply_migrations(engine):
126112
engine = create_tables()
127113
apply_migrations(engine)
128114
console.print("\n[bold green]✓ Database setup complete!")
129-

scripts/seed.py

Lines changed: 100 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,51 @@
1-
"""Seed database with UK and US tax-benefit models, variables, parameters, and datasets."""
1+
"""Seed database with UK and US models, variables, parameters, datasets."""
22

3-
import sys
4-
from pathlib import Path
53
import logging
4+
import sys
65
import warnings
6+
from pathlib import Path
7+
78
import logfire
89

910
# Disable all SQLAlchemy and database logging BEFORE any imports
1011
logging.basicConfig(level=logging.ERROR)
11-
logging.getLogger('sqlalchemy').setLevel(logging.ERROR)
12-
warnings.filterwarnings('ignore')
12+
logging.getLogger("sqlalchemy").setLevel(logging.ERROR)
13+
warnings.filterwarnings("ignore")
1314

1415
# Add src to path
1516
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
1617

17-
from policyengine.tax_benefit_models.uk import uk_latest
18-
from policyengine.tax_benefit_models.uk.datasets import ensure_datasets as ensure_uk_datasets
19-
from policyengine.tax_benefit_models.us import us_latest
20-
from policyengine.tax_benefit_models.us.datasets import ensure_datasets as ensure_us_datasets
21-
from policyengine_api.models import (
18+
from policyengine.tax_benefit_models.uk import uk_latest # noqa: E402
19+
from policyengine.tax_benefit_models.uk.datasets import ( # noqa: E402
20+
ensure_datasets as ensure_uk_datasets,
21+
)
22+
from policyengine.tax_benefit_models.us import us_latest # noqa: E402
23+
from policyengine.tax_benefit_models.us.datasets import ( # noqa: E402
24+
ensure_datasets as ensure_us_datasets,
25+
)
26+
from rich.console import Console # noqa: E402
27+
from rich.progress import track # noqa: E402
28+
from sqlmodel import Session, create_engine, select # noqa: E402
29+
30+
from policyengine_api.config.settings import settings # noqa: E402
31+
from policyengine_api.models import ( # noqa: E402
32+
Dataset,
33+
Parameter,
34+
ParameterValue,
2235
TaxBenefitModel,
2336
TaxBenefitModelVersion,
2437
Variable,
25-
Parameter,
26-
ParameterValue,
27-
Dataset,
2838
)
29-
from policyengine_api.services.storage import upload_dataset_for_seeding
30-
from sqlmodel import Session, create_engine, select
31-
from policyengine_api.config.settings import settings
32-
from rich.console import Console
33-
from rich.progress import track
39+
from policyengine_api.services.storage import ( # noqa: E402
40+
upload_dataset_for_seeding,
41+
)
3442

3543
# Configure logfire
3644
if settings.logfire_token:
37-
logfire.configure(token=settings.logfire_token, environment=settings.logfire_environment)
45+
logfire.configure(
46+
token=settings.logfire_token,
47+
environment=settings.logfire_environment,
48+
)
3849

3950
console = Console()
4051

@@ -49,12 +60,18 @@ def get_quiet_session():
4960
def seed_model(model_version, session) -> TaxBenefitModelVersion:
5061
"""Seed a tax-benefit model with its variables and parameters."""
5162

52-
with logfire.span(f"seed_model", model=model_version.model.id, version=model_version.version):
63+
with logfire.span(
64+
"seed_model",
65+
model=model_version.model.id,
66+
version=model_version.version,
67+
):
5368
# Create or get the model
5469
console.print(f"[bold blue]Seeding {model_version.model.id}...")
5570

5671
existing_model = session.exec(
57-
select(TaxBenefitModel).where(TaxBenefitModel.name == model_version.model.id)
72+
select(TaxBenefitModel).where(
73+
TaxBenefitModel.name == model_version.model.id
74+
)
5875
).first()
5976

6077
if existing_model:
@@ -79,7 +96,9 @@ def seed_model(model_version, session) -> TaxBenefitModelVersion:
7996
).first()
8097

8198
if existing_version:
82-
console.print(f" Model version {model_version.version} already exists, skipping")
99+
console.print(
100+
f" Model version {model_version.version} already exists, skipping"
101+
)
83102
return existing_version
84103

85104
db_version = TaxBenefitModelVersion(
@@ -100,19 +119,26 @@ def seed_model(model_version, session) -> TaxBenefitModelVersion:
100119
name=var.name,
101120
entity=var.entity,
102121
description=var.description or "",
103-
data_type=var.data_type.__name__ if hasattr(var.data_type, '__name__') else str(var.data_type),
122+
data_type=var.data_type.__name__
123+
if hasattr(var.data_type, "__name__")
124+
else str(var.data_type),
104125
tax_benefit_model_version_id=db_version.id,
105126
)
106127
session.add(db_var)
107128

108129
session.commit()
109-
console.print(f" [green]✓[/green] Added {len(model_version.variables)} variables")
130+
console.print(
131+
f" [green]✓[/green] Added {len(model_version.variables)} variables"
132+
)
110133

111134
# Add parameters (creating a lookup for parameter values later)
112135
parameters_to_add = model_version.parameters
113136
if settings.limit_seed_parameters:
114137
parameters_to_add = model_version.parameters[:10_000]
115-
console.print(f" [yellow]Limiting to {len(parameters_to_add)} parameters (LIMIT_SEED_PARAMETERS=true)[/yellow]")
138+
console.print(
139+
f" [yellow]Limiting to {len(parameters_to_add)} parameters "
140+
f"(LIMIT_SEED_PARAMETERS=true)[/yellow]"
141+
)
116142

117143
with logfire.span("add_parameters", count=len(parameters_to_add)):
118144
console.print(f" Adding {len(parameters_to_add)} parameters...")
@@ -121,9 +147,11 @@ def seed_model(model_version, session) -> TaxBenefitModelVersion:
121147
for param in track(parameters_to_add, description="Parameters"):
122148
db_param = Parameter(
123149
name=param.name,
124-
label=param.label if hasattr(param, 'label') else None,
150+
label=param.label if hasattr(param, "label") else None,
125151
description=param.description or "",
126-
data_type=param.data_type.__name__ if hasattr(param.data_type, '__name__') else str(param.data_type),
152+
data_type=param.data_type.__name__
153+
if hasattr(param.data_type, "__name__")
154+
else str(param.data_type),
127155
unit=param.unit,
128156
tax_benefit_model_version_id=db_version.id,
129157
)
@@ -132,22 +160,29 @@ def seed_model(model_version, session) -> TaxBenefitModelVersion:
132160
session.refresh(db_param)
133161
param_id_map[param.id] = db_param.id
134162

135-
console.print(f" [green]✓[/green] Added {len(parameters_to_add)} parameters")
163+
console.print(
164+
f" [green]✓[/green] Added {len(parameters_to_add)} parameters"
165+
)
136166

137167
# Add parameter values
138168
# Filter to only include values for parameters we actually added
139169
parameter_values_to_add = [
140-
pv for pv in model_version.parameter_values
170+
pv
171+
for pv in model_version.parameter_values
141172
if pv.parameter.id in param_id_map
142173
]
143174

144175
with logfire.span("add_parameter_values", count=len(parameter_values_to_add)):
145-
console.print(f" Adding {len(parameter_values_to_add)} parameter values...")
176+
console.print(
177+
f" Adding {len(parameter_values_to_add)} parameter values..."
178+
)
146179
import math
147180

148181
for pv in track(parameter_values_to_add, description="Parameter values"):
149182
# Skip infinite and NaN float values, as they can't be stored in JSON
150-
if isinstance(pv.value, float) and (math.isinf(pv.value) or math.isnan(pv.value)):
183+
if isinstance(pv.value, float) and (
184+
math.isinf(pv.value) or math.isnan(pv.value)
185+
):
151186
continue
152187

153188
db_pv = ParameterValue(
@@ -159,7 +194,10 @@ def seed_model(model_version, session) -> TaxBenefitModelVersion:
159194
session.add(db_pv)
160195

161196
session.commit()
162-
console.print(f" [green]✓[/green] Added {len(parameter_values_to_add)} parameter values")
197+
console.print(
198+
f" [green]✓[/green] Added {len(parameter_values_to_add)} "
199+
f"parameter values"
200+
)
163201

164202
return db_version
165203

@@ -171,34 +209,46 @@ def seed_datasets(session):
171209

172210
# Get UK and US models
173211
uk_model = session.exec(
174-
select(TaxBenefitModel).where(TaxBenefitModel.name.in_(["uk", "policyengine-uk"]))
212+
select(TaxBenefitModel).where(
213+
TaxBenefitModel.name.in_(["uk", "policyengine-uk"])
214+
)
175215
).first()
176216
us_model = session.exec(
177-
select(TaxBenefitModel).where(TaxBenefitModel.name.in_(["us", "policyengine-us"]))
217+
select(TaxBenefitModel).where(
218+
TaxBenefitModel.name.in_(["us", "policyengine-us"])
219+
)
178220
).first()
179221

180222
if not uk_model or not us_model:
181-
console.print("[red]Error: UK or US model not found. Run seed_model first.[/red]")
223+
console.print(
224+
"[red]Error: UK or US model not found. Run seed_model first.[/red]"
225+
)
182226
return
183227

184228
# UK datasets
185229
console.print(" Creating UK datasets...")
186230
uk_datasets = ensure_uk_datasets()
187231

188232
with logfire.span("seed_uk_datasets", count=len(uk_datasets)):
189-
for _, pe_dataset in track(list(uk_datasets.items()), description="UK datasets"):
233+
for _, pe_dataset in track(
234+
list(uk_datasets.items()), description="UK datasets"
235+
):
190236
# Check if dataset already exists
191237
existing = session.exec(
192238
select(Dataset).where(Dataset.name == pe_dataset.name)
193239
).first()
194240

195241
if existing:
196-
console.print(f" Dataset {pe_dataset.name} already exists, skipping")
242+
console.print(
243+
f" Dataset {pe_dataset.name} already exists, skipping"
244+
)
197245
continue
198246

199247
# Upload to S3
200248
object_name = upload_dataset_for_seeding(pe_dataset.filepath)
201-
console.print(f" Uploaded {pe_dataset.filepath} to S3 as {object_name}")
249+
console.print(
250+
f" Uploaded {pe_dataset.filepath} to S3 as {object_name}"
251+
)
202252

203253
# Create database record
204254
db_dataset = Dataset(
@@ -217,19 +267,25 @@ def seed_datasets(session):
217267
us_datasets = ensure_us_datasets()
218268

219269
with logfire.span("seed_us_datasets", count=len(us_datasets)):
220-
for _, pe_dataset in track(list(us_datasets.items()), description="US datasets"):
270+
for _, pe_dataset in track(
271+
list(us_datasets.items()), description="US datasets"
272+
):
221273
# Check if dataset already exists
222274
existing = session.exec(
223275
select(Dataset).where(Dataset.name == pe_dataset.name)
224276
).first()
225277

226278
if existing:
227-
console.print(f" Dataset {pe_dataset.name} already exists, skipping")
279+
console.print(
280+
f" Dataset {pe_dataset.name} already exists, skipping"
281+
)
228282
continue
229283

230284
# Upload to S3
231285
object_name = upload_dataset_for_seeding(pe_dataset.filepath)
232-
console.print(f" Uploaded {pe_dataset.filepath} to S3 as {object_name}")
286+
console.print(
287+
f" Uploaded {pe_dataset.filepath} to S3 as {object_name}"
288+
)
233289

234290
# Create database record
235291
db_dataset = Dataset(
@@ -243,7 +299,9 @@ def seed_datasets(session):
243299
session.commit()
244300
console.print(f" [green]✓[/green] Created dataset: {db_dataset.name}")
245301

246-
console.print(f"[green]✓[/green] Seeded {len(uk_datasets) + len(us_datasets)} datasets\n")
302+
console.print(
303+
f"[green]✓[/green] Seeded {len(uk_datasets) + len(us_datasets)} datasets\n"
304+
)
247305

248306

249307
def main():

src/policyengine_api/config/settings.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,13 @@ class Settings(BaseSettings):
3838
@property
3939
def database_url(self) -> str:
4040
"""Get database URL from Supabase."""
41-
return self.supabase_db_url or self.supabase_url.replace(
42-
"http://", "postgresql://postgres:postgres@"
43-
).replace("https://", "postgresql://postgres:postgres@") + "/postgres"
41+
return (
42+
self.supabase_db_url
43+
or self.supabase_url.replace(
44+
"http://", "postgresql://postgres:postgres@"
45+
).replace("https://", "postgresql://postgres:postgres@")
46+
+ "/postgres"
47+
)
4448

4549

4650
settings = Settings()

src/policyengine_api/main.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from contextlib import asynccontextmanager
22

3+
import logfire
34
from fastapi import FastAPI
45
from fastapi_cache import FastAPICache
56
from fastapi_cache.backends.redis import RedisBackend
67
from redis import asyncio as aioredis
78
from rich.console import Console
8-
import logfire
99

1010
from policyengine_api.api import api_router
1111
from policyengine_api.config.settings import settings
@@ -14,12 +14,10 @@
1414
console = Console()
1515

1616
# Configure Logfire
17-
from policyengine_api.config.settings import settings as app_settings
18-
1917
logfire.configure(
2018
service_name="policyengine-api",
21-
token=app_settings.logfire_token if app_settings.logfire_token else None,
22-
environment=app_settings.logfire_environment,
19+
token=settings.logfire_token if settings.logfire_token else None,
20+
environment=settings.logfire_environment,
2321
)
2422
logfire.instrument_httpx()
2523

@@ -32,7 +30,9 @@ async def lifespan(app: FastAPI):
3230
console.print("[bold green]Database initialized[/bold green]")
3331

3432
console.print("[bold green]Initializing cache...[/bold green]")
35-
redis = aioredis.from_url(settings.redis_url, encoding="utf8", decode_responses=True)
33+
redis = aioredis.from_url(
34+
settings.redis_url, encoding="utf8", decode_responses=True
35+
)
3636
FastAPICache.init(RedisBackend(redis), prefix="fastapi-cache")
3737
console.print("[bold green]Cache initialized[/bold green]")
3838

src/policyengine_api/models/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from sqlmodel import Field, Relationship, SQLModel
66

77
if TYPE_CHECKING:
8-
from .tax_benefit_model import TaxBenefitModel
98
from .dataset_version import DatasetVersion
9+
from .tax_benefit_model import TaxBenefitModel
1010

1111

1212
class DatasetBase(SQLModel):

0 commit comments

Comments
 (0)