1- """Seed database with UK and US tax-benefit models, variables, parameters, and datasets."""
1+ """Seed database with UK and US models, variables, parameters, datasets."""
22
3- import sys
4- from pathlib import Path
53import logging
4+ import sys
65import warnings
6+ from pathlib import Path
7+
78import logfire
89
910# Disable all SQLAlchemy and database logging BEFORE any imports
1011logging .basicConfig (level = logging .ERROR )
11- logging .getLogger (' sqlalchemy' ).setLevel (logging .ERROR )
12- warnings .filterwarnings (' ignore' )
12+ logging .getLogger (" sqlalchemy" ).setLevel (logging .ERROR )
13+ warnings .filterwarnings (" ignore" )
1314
1415# Add src to path
1516sys .path .insert (0 , str (Path (__file__ ).parent .parent / "src" ))
1617
17- from policyengine .tax_benefit_models .uk import uk_latest
18- from policyengine .tax_benefit_models .uk .datasets import ensure_datasets as ensure_uk_datasets
19- from policyengine .tax_benefit_models .us import us_latest
20- from policyengine .tax_benefit_models .us .datasets import ensure_datasets as ensure_us_datasets
21- from policyengine_api .models import (
18+ from policyengine .tax_benefit_models .uk import uk_latest # noqa: E402
19+ from policyengine .tax_benefit_models .uk .datasets import ( # noqa: E402
20+ ensure_datasets as ensure_uk_datasets ,
21+ )
22+ from policyengine .tax_benefit_models .us import us_latest # noqa: E402
23+ from policyengine .tax_benefit_models .us .datasets import ( # noqa: E402
24+ ensure_datasets as ensure_us_datasets ,
25+ )
26+ from rich .console import Console # noqa: E402
27+ from rich .progress import track # noqa: E402
28+ from sqlmodel import Session , create_engine , select # noqa: E402
29+
30+ from policyengine_api .config .settings import settings # noqa: E402
31+ from policyengine_api .models import ( # noqa: E402
32+ Dataset ,
33+ Parameter ,
34+ ParameterValue ,
2235 TaxBenefitModel ,
2336 TaxBenefitModelVersion ,
2437 Variable ,
25- Parameter ,
26- ParameterValue ,
27- Dataset ,
2838)
29- from policyengine_api .services .storage import upload_dataset_for_seeding
30- from sqlmodel import Session , create_engine , select
31- from policyengine_api .config .settings import settings
32- from rich .console import Console
33- from rich .progress import track
39+ from policyengine_api .services .storage import ( # noqa: E402
40+ upload_dataset_for_seeding ,
41+ )
3442
3543# Configure logfire
3644if settings .logfire_token :
37- logfire .configure (token = settings .logfire_token , environment = settings .logfire_environment )
45+ logfire .configure (
46+ token = settings .logfire_token ,
47+ environment = settings .logfire_environment ,
48+ )
3849
3950console = Console ()
4051
@@ -49,12 +60,18 @@ def get_quiet_session():
4960def seed_model (model_version , session ) -> TaxBenefitModelVersion :
5061 """Seed a tax-benefit model with its variables and parameters."""
5162
52- with logfire .span (f"seed_model" , model = model_version .model .id , version = model_version .version ):
63+ with logfire .span (
64+ "seed_model" ,
65+ model = model_version .model .id ,
66+ version = model_version .version ,
67+ ):
5368 # Create or get the model
5469 console .print (f"[bold blue]Seeding { model_version .model .id } ..." )
5570
5671 existing_model = session .exec (
57- select (TaxBenefitModel ).where (TaxBenefitModel .name == model_version .model .id )
72+ select (TaxBenefitModel ).where (
73+ TaxBenefitModel .name == model_version .model .id
74+ )
5875 ).first ()
5976
6077 if existing_model :
@@ -79,7 +96,9 @@ def seed_model(model_version, session) -> TaxBenefitModelVersion:
7996 ).first ()
8097
8198 if existing_version :
82- console .print (f" Model version { model_version .version } already exists, skipping" )
99+ console .print (
100+ f" Model version { model_version .version } already exists, skipping"
101+ )
83102 return existing_version
84103
85104 db_version = TaxBenefitModelVersion (
@@ -100,19 +119,26 @@ def seed_model(model_version, session) -> TaxBenefitModelVersion:
100119 name = var .name ,
101120 entity = var .entity ,
102121 description = var .description or "" ,
103- data_type = var .data_type .__name__ if hasattr (var .data_type , '__name__' ) else str (var .data_type ),
122+ data_type = var .data_type .__name__
123+ if hasattr (var .data_type , "__name__" )
124+ else str (var .data_type ),
104125 tax_benefit_model_version_id = db_version .id ,
105126 )
106127 session .add (db_var )
107128
108129 session .commit ()
109- console .print (f" [green]✓[/green] Added { len (model_version .variables )} variables" )
130+ console .print (
131+ f" [green]✓[/green] Added { len (model_version .variables )} variables"
132+ )
110133
111134 # Add parameters (creating a lookup for parameter values later)
112135 parameters_to_add = model_version .parameters
113136 if settings .limit_seed_parameters :
114137 parameters_to_add = model_version .parameters [:10_000 ]
115- console .print (f" [yellow]Limiting to { len (parameters_to_add )} parameters (LIMIT_SEED_PARAMETERS=true)[/yellow]" )
138+ console .print (
139+ f" [yellow]Limiting to { len (parameters_to_add )} parameters "
140+ f"(LIMIT_SEED_PARAMETERS=true)[/yellow]"
141+ )
116142
117143 with logfire .span ("add_parameters" , count = len (parameters_to_add )):
118144 console .print (f" Adding { len (parameters_to_add )} parameters..." )
@@ -121,9 +147,11 @@ def seed_model(model_version, session) -> TaxBenefitModelVersion:
121147 for param in track (parameters_to_add , description = "Parameters" ):
122148 db_param = Parameter (
123149 name = param .name ,
124- label = param .label if hasattr (param , ' label' ) else None ,
150+ label = param .label if hasattr (param , " label" ) else None ,
125151 description = param .description or "" ,
126- data_type = param .data_type .__name__ if hasattr (param .data_type , '__name__' ) else str (param .data_type ),
152+ data_type = param .data_type .__name__
153+ if hasattr (param .data_type , "__name__" )
154+ else str (param .data_type ),
127155 unit = param .unit ,
128156 tax_benefit_model_version_id = db_version .id ,
129157 )
@@ -132,22 +160,29 @@ def seed_model(model_version, session) -> TaxBenefitModelVersion:
132160 session .refresh (db_param )
133161 param_id_map [param .id ] = db_param .id
134162
135- console .print (f" [green]✓[/green] Added { len (parameters_to_add )} parameters" )
163+ console .print (
164+ f" [green]✓[/green] Added { len (parameters_to_add )} parameters"
165+ )
136166
137167 # Add parameter values
138168 # Filter to only include values for parameters we actually added
139169 parameter_values_to_add = [
140- pv for pv in model_version .parameter_values
170+ pv
171+ for pv in model_version .parameter_values
141172 if pv .parameter .id in param_id_map
142173 ]
143174
144175 with logfire .span ("add_parameter_values" , count = len (parameter_values_to_add )):
145- console .print (f" Adding { len (parameter_values_to_add )} parameter values..." )
176+ console .print (
177+ f" Adding { len (parameter_values_to_add )} parameter values..."
178+ )
146179 import math
147180
148181 for pv in track (parameter_values_to_add , description = "Parameter values" ):
149182 # Handle Infinity values - skip them as they can't be stored in JSON
150- if isinstance (pv .value , float ) and (math .isinf (pv .value ) or math .isnan (pv .value )):
183+ if isinstance (pv .value , float ) and (
184+ math .isinf (pv .value ) or math .isnan (pv .value )
185+ ):
151186 continue
152187
153188 db_pv = ParameterValue (
@@ -159,7 +194,10 @@ def seed_model(model_version, session) -> TaxBenefitModelVersion:
159194 session .add (db_pv )
160195
161196 session .commit ()
162- console .print (f" [green]✓[/green] Added { len (parameter_values_to_add )} parameter values" )
197+ console .print (
198+ f" [green]✓[/green] Added { len (parameter_values_to_add )} "
199+ f"parameter values"
200+ )
163201
164202 return db_version
165203
@@ -171,34 +209,46 @@ def seed_datasets(session):
171209
172210 # Get UK and US models
173211 uk_model = session .exec (
174- select (TaxBenefitModel ).where (TaxBenefitModel .name .in_ (["uk" , "policyengine-uk" ]))
212+ select (TaxBenefitModel ).where (
213+ TaxBenefitModel .name .in_ (["uk" , "policyengine-uk" ])
214+ )
175215 ).first ()
176216 us_model = session .exec (
177- select (TaxBenefitModel ).where (TaxBenefitModel .name .in_ (["us" , "policyengine-us" ]))
217+ select (TaxBenefitModel ).where (
218+ TaxBenefitModel .name .in_ (["us" , "policyengine-us" ])
219+ )
178220 ).first ()
179221
180222 if not uk_model or not us_model :
181- console .print ("[red]Error: UK or US model not found. Run seed_model first.[/red]" )
223+ console .print (
224+ "[red]Error: UK or US model not found. Run seed_model first.[/red]"
225+ )
182226 return
183227
184228 # UK datasets
185229 console .print (" Creating UK datasets..." )
186230 uk_datasets = ensure_uk_datasets ()
187231
188232 with logfire .span ("seed_uk_datasets" , count = len (uk_datasets )):
189- for _ , pe_dataset in track (list (uk_datasets .items ()), description = "UK datasets" ):
233+ for _ , pe_dataset in track (
234+ list (uk_datasets .items ()), description = "UK datasets"
235+ ):
190236 # Check if dataset already exists
191237 existing = session .exec (
192238 select (Dataset ).where (Dataset .name == pe_dataset .name )
193239 ).first ()
194240
195241 if existing :
196- console .print (f" Dataset { pe_dataset .name } already exists, skipping" )
242+ console .print (
243+ f" Dataset { pe_dataset .name } already exists, skipping"
244+ )
197245 continue
198246
199247 # Upload to S3
200248 object_name = upload_dataset_for_seeding (pe_dataset .filepath )
201- console .print (f" Uploaded { pe_dataset .filepath } to S3 as { object_name } " )
249+ console .print (
250+ f" Uploaded { pe_dataset .filepath } to S3 as { object_name } "
251+ )
202252
203253 # Create database record
204254 db_dataset = Dataset (
@@ -217,19 +267,25 @@ def seed_datasets(session):
217267 us_datasets = ensure_us_datasets ()
218268
219269 with logfire .span ("seed_us_datasets" , count = len (us_datasets )):
220- for _ , pe_dataset in track (list (us_datasets .items ()), description = "US datasets" ):
270+ for _ , pe_dataset in track (
271+ list (us_datasets .items ()), description = "US datasets"
272+ ):
221273 # Check if dataset already exists
222274 existing = session .exec (
223275 select (Dataset ).where (Dataset .name == pe_dataset .name )
224276 ).first ()
225277
226278 if existing :
227- console .print (f" Dataset { pe_dataset .name } already exists, skipping" )
279+ console .print (
280+ f" Dataset { pe_dataset .name } already exists, skipping"
281+ )
228282 continue
229283
230284 # Upload to S3
231285 object_name = upload_dataset_for_seeding (pe_dataset .filepath )
232- console .print (f" Uploaded { pe_dataset .filepath } to S3 as { object_name } " )
286+ console .print (
287+ f" Uploaded { pe_dataset .filepath } to S3 as { object_name } "
288+ )
233289
234290 # Create database record
235291 db_dataset = Dataset (
@@ -243,7 +299,9 @@ def seed_datasets(session):
243299 session .commit ()
244300 console .print (f" [green]✓[/green] Created dataset: { db_dataset .name } " )
245301
246- console .print (f"[green]✓[/green] Seeded { len (uk_datasets ) + len (us_datasets )} datasets\n " )
302+ console .print (
303+ f"[green]✓[/green] Seeded { len (uk_datasets ) + len (us_datasets )} datasets\n "
304+ )
247305
248306
249307def main ():
0 commit comments