44import numpy as np
55import pandas as pd
66
7- from sqlmodel import Session , create_engine
7+ from sqlmodel import Session , create_engine , select
88
99from policyengine_us_data .storage import STORAGE_FOLDER
1010from policyengine_us_data .db .create_database_tables import (
@@ -313,16 +313,14 @@ def _upsert_target(
313313 source : str ,
314314 notes : Optional [str ] = None ,
315315) -> None :
316- existing_target = (
317- session .query (Target )
318- .filter (
316+ existing_target = session .exec (
317+ select (Target ).where (
319318 Target .stratum_id == stratum_id ,
320319 Target .variable == variable ,
321320 Target .period == period ,
322321 Target .reform_id == 0 ,
323322 )
324- .first ()
325- )
323+ ).first ()
326324 if existing_target :
327325 existing_target .value = value
328326 existing_target .source = source
@@ -347,14 +345,12 @@ def _get_or_create_national_domain_stratum(
347345 session : Session , national_filer_stratum_id : int , variable : str
348346) -> Stratum :
349347 note = f"National filers with { variable } > 0"
350- stratum = (
351- session .query (Stratum )
352- .filter (
348+ stratum = session .exec (
349+ select (Stratum ).where (
353350 Stratum .parent_stratum_id == national_filer_stratum_id ,
354351 Stratum .notes == note ,
355352 )
356- .first ()
357- )
353+ ).first ()
358354 if stratum :
359355 return stratum
360356
@@ -751,14 +747,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
751747 filer_strata = {"national" : None , "state" : {}, "district" : {}}
752748
753749 # National filer stratum - check if it exists first
754- national_filer_stratum = (
755- session .query (Stratum )
756- .filter (
750+ national_filer_stratum = session .exec (
751+ select (Stratum ).where (
757752 Stratum .parent_stratum_id == geo_strata ["national" ],
758753 Stratum .notes == "United States - Tax Filers" ,
759754 )
760- .first ()
761- )
755+ ).first ()
762756
763757 if not national_filer_stratum :
764758 national_filer_stratum = Stratum (
@@ -780,14 +774,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
780774 # State filer strata
781775 for state_fips , state_geo_stratum_id in geo_strata ["state" ].items ():
782776 # Check if state filer stratum exists
783- state_filer_stratum = (
784- session .query (Stratum )
785- .filter (
777+ state_filer_stratum = session .exec (
778+ select (Stratum ).where (
786779 Stratum .parent_stratum_id == state_geo_stratum_id ,
787780 Stratum .notes == f"State FIPS { state_fips } - Tax Filers" ,
788781 )
789- .first ()
790- )
782+ ).first ()
791783
792784 if not state_filer_stratum :
793785 state_filer_stratum = Stratum (
@@ -814,15 +806,13 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
814806 # District filer strata
815807 for district_geoid , district_geo_stratum_id in geo_strata ["district" ].items ():
816808 # Check if district filer stratum exists
817- district_filer_stratum = (
818- session .query (Stratum )
819- .filter (
809+ district_filer_stratum = session .exec (
810+ select (Stratum ).where (
820811 Stratum .parent_stratum_id == district_geo_stratum_id ,
821812 Stratum .notes
822813 == f"Congressional District { district_geoid } - Tax Filers" ,
823814 )
824- .first ()
825- )
815+ ).first ()
826816
827817 if not district_filer_stratum :
828818 district_filer_stratum = Stratum (
@@ -917,14 +907,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
917907 ]
918908
919909 # Check if stratum already exists
920- existing_stratum = (
921- session .query (Stratum )
922- .filter (
910+ existing_stratum = session .exec (
911+ select (Stratum ).where (
923912 Stratum .parent_stratum_id == parent_stratum_id ,
924913 Stratum .notes == note ,
925914 )
926- .first ()
927- )
915+ ).first ()
928916
929917 if existing_stratum :
930918 new_stratum = existing_stratum
@@ -964,15 +952,13 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
964952 ("tax_unit_count" , count_value ),
965953 ("eitc" , amount_value ),
966954 ]:
967- existing_target = (
968- session .query (Target )
969- .filter (
955+ existing_target = session .exec (
956+ select (Target ).where (
970957 Target .stratum_id == new_stratum .stratum_id ,
971958 Target .variable == variable ,
972959 Target .period == year ,
973960 )
974- .first ()
975- )
961+ ).first ()
976962
977963 if existing_target :
978964 existing_target .value = value
@@ -1047,14 +1033,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
10471033 note = f"{ geo_description } filers with { amount_variable_name } > 0"
10481034
10491035 # Check if child stratum already exists
1050- existing_stratum = (
1051- session .query (Stratum )
1052- .filter (
1036+ existing_stratum = session .exec (
1037+ select (Stratum ).where (
10531038 Stratum .parent_stratum_id == parent_stratum_id ,
10541039 Stratum .notes == note ,
10551040 )
1056- .first ()
1057- )
1041+ ).first ()
10581042
10591043 if existing_stratum :
10601044 child_stratum = existing_stratum
@@ -1119,15 +1103,13 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
11191103 (count_variable_name , count_value ),
11201104 (amount_variable_name , amount_value ),
11211105 ]:
1122- existing_target = (
1123- session .query (Target )
1124- .filter (
1106+ existing_target = session .exec (
1107+ select (Target ).where (
11251108 Target .stratum_id == child_stratum .stratum_id ,
11261109 Target .variable == variable ,
11271110 Target .period == year ,
11281111 )
1129- .first ()
1130- )
1112+ ).first ()
11311113
11321114 if existing_target :
11331115 existing_target .value = value
@@ -1170,15 +1152,13 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
11701152 )
11711153
11721154 # Check if target already exists
1173- existing_target = (
1174- session .query (Target )
1175- .filter (
1155+ existing_target = session .exec (
1156+ select (Target ).where (
11761157 Target .stratum_id == stratum .stratum_id ,
11771158 Target .variable == "adjusted_gross_income" ,
11781159 Target .period == year ,
11791160 )
1180- .first ()
1181- )
1161+ ).first ()
11821162
11831163 if existing_target :
11841164 existing_target .value = agi_values .iloc [i ][["target_value" ]].values [0 ]
@@ -1211,14 +1191,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
12111191 note = f"National filers, AGI >= { agi_income_lower } , AGI < { agi_income_upper } "
12121192
12131193 # Check if national AGI stratum already exists
1214- nat_stratum = (
1215- session .query (Stratum )
1216- .filter (
1194+ nat_stratum = session .exec (
1195+ select (Stratum ).where (
12171196 Stratum .parent_stratum_id == filer_strata ["national" ],
12181197 Stratum .notes == note ,
12191198 )
1220- .first ()
1221- )
1199+ ).first ()
12221200
12231201 if not nat_stratum :
12241202 nat_stratum = Stratum (
@@ -1296,14 +1274,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
12961274 continue # Skip if not state or district (shouldn't happen, but defensive)
12971275
12981276 # Check if stratum already exists
1299- existing_stratum = (
1300- session .query (Stratum )
1301- .filter (
1277+ existing_stratum = session .exec (
1278+ select (Stratum ).where (
13021279 Stratum .parent_stratum_id == parent_stratum_id ,
13031280 Stratum .notes == note ,
13041281 )
1305- .first ()
1306- )
1282+ ).first ()
13071283
13081284 if existing_stratum :
13091285 new_stratum = existing_stratum
@@ -1331,15 +1307,13 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
13311307 session .flush ()
13321308
13331309 # Check if target already exists and update or create it
1334- existing_target = (
1335- session .query (Target )
1336- .filter (
1310+ existing_target = session .exec (
1311+ select (Target ).where (
13371312 Target .stratum_id == new_stratum .stratum_id ,
13381313 Target .variable == "person_count" ,
13391314 Target .period == year ,
13401315 )
1341- .first ()
1342- )
1316+ ).first ()
13431317
13441318 if existing_target :
13451319 existing_target .value = person_count
0 commit comments