Skip to content

Commit a4f9f3d

Browse files
authored
Merge pull request #693 from PolicyEngine/codex/fix-ci-warnings
Remove SQLModel query deprecation warnings
2 parents 82caa97 + ec27d98 commit a4f9f3d

4 files changed

Lines changed: 89 additions & 136 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Replace legacy SQLModel `session.query(...)` lookups in the SOI ETL loaders and their focused tests with `session.exec(select(...))` to remove deprecation warnings in CI.

policyengine_us_data/db/etl_irs_soi.py

Lines changed: 40 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
import pandas as pd
66

7-
from sqlmodel import Session, create_engine
7+
from sqlmodel import Session, create_engine, select
88

99
from policyengine_us_data.storage import STORAGE_FOLDER
1010
from policyengine_us_data.db.create_database_tables import (
@@ -313,16 +313,14 @@ def _upsert_target(
313313
source: str,
314314
notes: Optional[str] = None,
315315
) -> None:
316-
existing_target = (
317-
session.query(Target)
318-
.filter(
316+
existing_target = session.exec(
317+
select(Target).where(
319318
Target.stratum_id == stratum_id,
320319
Target.variable == variable,
321320
Target.period == period,
322321
Target.reform_id == 0,
323322
)
324-
.first()
325-
)
323+
).first()
326324
if existing_target:
327325
existing_target.value = value
328326
existing_target.source = source
@@ -347,14 +345,12 @@ def _get_or_create_national_domain_stratum(
347345
session: Session, national_filer_stratum_id: int, variable: str
348346
) -> Stratum:
349347
note = f"National filers with {variable} > 0"
350-
stratum = (
351-
session.query(Stratum)
352-
.filter(
348+
stratum = session.exec(
349+
select(Stratum).where(
353350
Stratum.parent_stratum_id == national_filer_stratum_id,
354351
Stratum.notes == note,
355352
)
356-
.first()
357-
)
353+
).first()
358354
if stratum:
359355
return stratum
360356

@@ -751,14 +747,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
751747
filer_strata = {"national": None, "state": {}, "district": {}}
752748

753749
# National filer stratum - check if it exists first
754-
national_filer_stratum = (
755-
session.query(Stratum)
756-
.filter(
750+
national_filer_stratum = session.exec(
751+
select(Stratum).where(
757752
Stratum.parent_stratum_id == geo_strata["national"],
758753
Stratum.notes == "United States - Tax Filers",
759754
)
760-
.first()
761-
)
755+
).first()
762756

763757
if not national_filer_stratum:
764758
national_filer_stratum = Stratum(
@@ -780,14 +774,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
780774
# State filer strata
781775
for state_fips, state_geo_stratum_id in geo_strata["state"].items():
782776
# Check if state filer stratum exists
783-
state_filer_stratum = (
784-
session.query(Stratum)
785-
.filter(
777+
state_filer_stratum = session.exec(
778+
select(Stratum).where(
786779
Stratum.parent_stratum_id == state_geo_stratum_id,
787780
Stratum.notes == f"State FIPS {state_fips} - Tax Filers",
788781
)
789-
.first()
790-
)
782+
).first()
791783

792784
if not state_filer_stratum:
793785
state_filer_stratum = Stratum(
@@ -814,15 +806,13 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
814806
# District filer strata
815807
for district_geoid, district_geo_stratum_id in geo_strata["district"].items():
816808
# Check if district filer stratum exists
817-
district_filer_stratum = (
818-
session.query(Stratum)
819-
.filter(
809+
district_filer_stratum = session.exec(
810+
select(Stratum).where(
820811
Stratum.parent_stratum_id == district_geo_stratum_id,
821812
Stratum.notes
822813
== f"Congressional District {district_geoid} - Tax Filers",
823814
)
824-
.first()
825-
)
815+
).first()
826816

827817
if not district_filer_stratum:
828818
district_filer_stratum = Stratum(
@@ -917,14 +907,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
917907
]
918908

919909
# Check if stratum already exists
920-
existing_stratum = (
921-
session.query(Stratum)
922-
.filter(
910+
existing_stratum = session.exec(
911+
select(Stratum).where(
923912
Stratum.parent_stratum_id == parent_stratum_id,
924913
Stratum.notes == note,
925914
)
926-
.first()
927-
)
915+
).first()
928916

929917
if existing_stratum:
930918
new_stratum = existing_stratum
@@ -964,15 +952,13 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
964952
("tax_unit_count", count_value),
965953
("eitc", amount_value),
966954
]:
967-
existing_target = (
968-
session.query(Target)
969-
.filter(
955+
existing_target = session.exec(
956+
select(Target).where(
970957
Target.stratum_id == new_stratum.stratum_id,
971958
Target.variable == variable,
972959
Target.period == year,
973960
)
974-
.first()
975-
)
961+
).first()
976962

977963
if existing_target:
978964
existing_target.value = value
@@ -1047,14 +1033,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
10471033
note = f"{geo_description} filers with {amount_variable_name} > 0"
10481034

10491035
# Check if child stratum already exists
1050-
existing_stratum = (
1051-
session.query(Stratum)
1052-
.filter(
1036+
existing_stratum = session.exec(
1037+
select(Stratum).where(
10531038
Stratum.parent_stratum_id == parent_stratum_id,
10541039
Stratum.notes == note,
10551040
)
1056-
.first()
1057-
)
1041+
).first()
10581042

10591043
if existing_stratum:
10601044
child_stratum = existing_stratum
@@ -1119,15 +1103,13 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
11191103
(count_variable_name, count_value),
11201104
(amount_variable_name, amount_value),
11211105
]:
1122-
existing_target = (
1123-
session.query(Target)
1124-
.filter(
1106+
existing_target = session.exec(
1107+
select(Target).where(
11251108
Target.stratum_id == child_stratum.stratum_id,
11261109
Target.variable == variable,
11271110
Target.period == year,
11281111
)
1129-
.first()
1130-
)
1112+
).first()
11311113

11321114
if existing_target:
11331115
existing_target.value = value
@@ -1170,15 +1152,13 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
11701152
)
11711153

11721154
# Check if target already exists
1173-
existing_target = (
1174-
session.query(Target)
1175-
.filter(
1155+
existing_target = session.exec(
1156+
select(Target).where(
11761157
Target.stratum_id == stratum.stratum_id,
11771158
Target.variable == "adjusted_gross_income",
11781159
Target.period == year,
11791160
)
1180-
.first()
1181-
)
1161+
).first()
11821162

11831163
if existing_target:
11841164
existing_target.value = agi_values.iloc[i][["target_value"]].values[0]
@@ -1211,14 +1191,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
12111191
note = f"National filers, AGI >= {agi_income_lower}, AGI < {agi_income_upper}"
12121192

12131193
# Check if national AGI stratum already exists
1214-
nat_stratum = (
1215-
session.query(Stratum)
1216-
.filter(
1194+
nat_stratum = session.exec(
1195+
select(Stratum).where(
12171196
Stratum.parent_stratum_id == filer_strata["national"],
12181197
Stratum.notes == note,
12191198
)
1220-
.first()
1221-
)
1199+
).first()
12221200

12231201
if not nat_stratum:
12241202
nat_stratum = Stratum(
@@ -1296,14 +1274,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
12961274
continue # Skip if not state or district (shouldn't happen, but defensive)
12971275

12981276
# Check if stratum already exists
1299-
existing_stratum = (
1300-
session.query(Stratum)
1301-
.filter(
1277+
existing_stratum = session.exec(
1278+
select(Stratum).where(
13021279
Stratum.parent_stratum_id == parent_stratum_id,
13031280
Stratum.notes == note,
13041281
)
1305-
.first()
1306-
)
1282+
).first()
13071283

13081284
if existing_stratum:
13091285
new_stratum = existing_stratum
@@ -1331,15 +1307,13 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
13311307
session.flush()
13321308

13331309
# Check if target already exists and update or create it
1334-
existing_target = (
1335-
session.query(Target)
1336-
.filter(
1310+
existing_target = session.exec(
1311+
select(Target).where(
13371312
Target.stratum_id == new_stratum.stratum_id,
13381313
Target.variable == "person_count",
13391314
Target.period == year,
13401315
)
1341-
.first()
1342-
)
1316+
).first()
13431317

13441318
if existing_target:
13451319
existing_target.value = person_count

0 commit comments

Comments
 (0)