Skip to content

Commit 98f4068

Browse files
authored
Merge pull request #4619 from fedspendingtransparency/ftr/dev-14236-pydantic-upgrade
[DEV-14236] Pydantic upgrade
2 parents 36c6910 + 45352ac commit 98f4068

File tree

10 files changed

+595
-387
lines changed

10 files changed

+595
-387
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ dependencies = [
5252
"psutil==5.9.*",
5353
"psycopg>=3.3.3",
5454
"py-gfm==2.0.0",
55-
"pydantic[dotenv]==1.9.*",
55+
"pydantic==2.12",
56+
"pydantic-settings>=2.13.1",
5657
"python-json-logger==2.0.7",
5758
"requests==2.31.*",
5859
"retrying==1.3.4",

usaspending_api/config/envs/default.py

Lines changed: 44 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,9 @@
1313
import pathlib
1414
from typing import Any, ClassVar, Union
1515

16-
from pydantic import (
17-
AnyHttpUrl,
18-
BaseSettings,
19-
PostgresDsn,
20-
SecretStr,
21-
root_validator,
22-
)
16+
from pydantic import SecretStr, model_validator
17+
from pydantic.networks import AnyHttpUrl, PostgresDsn
18+
from pydantic_settings import BaseSettings, SettingsConfigDict
2319

2420
from usaspending_api.config.utils import (
2521
ENV_SPECIFIC_OVERRIDE,
@@ -134,20 +130,22 @@ def _validate_database_conf(
134130
)
135131

136132
if enough_parts:
137-
pg_dsn = PostgresDsn(
138-
url=None,
133+
try:
134+
_port = int(values[f"{resource_conf_prefix}_PORT"])
135+
except (ValueError, TypeError):
136+
_port = None
137+
138+
pg_dsn = PostgresDsn.build(
139139
scheme=values[f"{resource_conf_prefix}_SCHEME"],
140-
user=values[f"{resource_conf_prefix}_USER"],
141-
password=values[
142-
f"{resource_conf_prefix}_PASSWORD"
143-
].get_secret_value(),
144-
host=values[f"{resource_conf_prefix}_HOST"],
145-
port=values[f"{resource_conf_prefix}_PORT"],
146-
path=(
147-
"/" + values[f"{resource_conf_prefix}_NAME"]
148-
if values[f"{resource_conf_prefix}_NAME"]
149-
else None
140+
username=values[f"{resource_conf_prefix}_USER"],
141+
password=(
142+
values[f"{resource_conf_prefix}_PASSWORD"].get_secret_value()
143+
if isinstance(values[f"{resource_conf_prefix}_PASSWORD"], SecretStr)
144+
else values[f"{resource_conf_prefix}_PASSWORD"]
150145
),
146+
host=values[f"{resource_conf_prefix}_HOST"],
147+
port=_port,
148+
path=values.get(f"{resource_conf_prefix}_NAME"),
151149
)
152150
values = eval_default_factory_from_root_validator(
153151
cls, values, url_conf_name, lambda: str(pg_dsn)
@@ -159,7 +157,7 @@ def _validate_database_conf(
159157

160158
# noinspection PyMethodParameters
161159
# Pydantic returns a classmethod for its validators, so the cls param is correct
162-
@root_validator
160+
@model_validator(mode="before")
163161
def _DATABASE_URL_and_parts_factory(cls, values: dict[str, Any]) -> dict[str, Any]:
164162
"""A root validator to backfill DATABASE_URL and USASPENDING_DB_* part config vars and validate that they are
165163
all consistent.
@@ -169,6 +167,8 @@ def _DATABASE_URL_and_parts_factory(cls, values: dict[str, Any]) -> dict[str, An
169167
- ALSO validates that the parts and whole string are consistent. A ``ValueError`` is thrown if found to
170168
be inconsistent, which will in turn raise a ``pydantic.ValidationError`` at configuration time.
171169
"""
170+
default_fields = {name: field.default for name, field in cls.model_fields.items()}
171+
values = {**default_fields, **values}
172172
# noinspection PyArgumentList
173173
cls._validate_database_conf(
174174
cls=cls,
@@ -181,7 +181,7 @@ def _DATABASE_URL_and_parts_factory(cls, values: dict[str, Any]) -> dict[str, An
181181

182182
# noinspection PyMethodParameters
183183
# Pydantic returns a classmethod for its validators, so the cls param is correct
184-
@root_validator
184+
@model_validator(mode="before")
185185
def _BROKER_DB_and_parts_factory(cls, values: dict[str, Any]) -> dict[str, Any]:
186186
"""A root validator to backfill BROKER_DB and BROKER_DB_* part config vars and validate
187187
that they are all consistent.
@@ -191,6 +191,8 @@ def _BROKER_DB_and_parts_factory(cls, values: dict[str, Any]) -> dict[str, Any]:
191191
- ALSO validates that the parts and whole string are consistent. A ``ValueError`` is thrown if found to
192192
be inconsistent, which will in turn raise a ``pydantic.ValidationError`` at configuration time.
193193
"""
194+
default_fields = {name: field.default for name, field in cls.model_fields.items()}
195+
values = {**default_fields, **values}
194196
# noinspection PyArgumentList
195197
cls._validate_database_conf(
196198
cls=cls,
@@ -203,17 +205,17 @@ def _BROKER_DB_and_parts_factory(cls, values: dict[str, Any]) -> dict[str, Any]:
203205

204206
# ==== [Elasticsearch] ====
205207
# Where to connect to elasticsearch.
206-
ES_HOSTNAME: str = None # FACTORY_PROVIDED_VALUE. See below validator-factory
208+
ES_HOSTNAME: str | None = None # FACTORY_PROVIDED_VALUE. See below validator-factory
207209
ES_SCHEME: str = "https"
208210
ES_HOST: str = ENV_SPECIFIC_OVERRIDE
209-
ES_PORT: str = None
210-
ES_USER: str = None
211-
ES_PASSWORD: SecretStr = None
212-
ES_NAME: str = None
211+
ES_PORT: str | None = None
212+
ES_USER: str | None = None
213+
ES_PASSWORD: SecretStr | None = None
214+
ES_NAME: str | None = None
213215

214216
# noinspection PyMethodParameters
215217
# Pydantic returns a classmethod for its validators, so the cls param is correct
216-
@root_validator
218+
@model_validator(mode="before")
217219
def _ES_HOSTNAME_and_parts_factory(cls, values: dict[str, Any]) -> dict[str, Any]:
218220
"""A root validator to backfill ES_HOSTNAME and ES_* part config vars and validate that they are
219221
all consistent.
@@ -223,6 +225,8 @@ def _ES_HOSTNAME_and_parts_factory(cls, values: dict[str, Any]) -> dict[str, Any
223225
- ALSO validates that the parts and whole string are consistent. A ``ValueError`` is thrown if found to
224226
be inconsistent, which will in turn raise a ``pydantic.ValidationError`` at configuration time.
225227
"""
228+
default_fields = {name: field.default for name, field in cls.model_fields.items()}
229+
values = {**default_fields, **values}
226230
# noinspection PyArgumentList
227231
cls._validate_http_url(
228232
cls=cls,
@@ -251,9 +255,7 @@ def _validate_http_url(
251255
# - it should take precedence
252256
# - its values will be used to backfill any missing URL parts stored as separate config vars
253257
if is_full_url_provided:
254-
values = backfill_url_parts_config(
255-
cls, url_conf_name, resource_conf_prefix, values
256-
)
258+
values = backfill_url_parts_config(cls, url_conf_name, resource_conf_prefix, values)
257259

258260
# If the full URL config is not provided, try to build-it-up from provided parts, then set the full URL
259261
if not is_full_url_provided:
@@ -268,21 +270,16 @@ def _validate_http_url(
268270

269271
if enough_parts:
270272
http_url = AnyHttpUrl(
271-
url=None,
272273
scheme=values[f"{resource_conf_prefix}_SCHEME"],
273-
user=values[f"{resource_conf_prefix}_USER"],
274+
username=values[f"{resource_conf_prefix}_USER"],
274275
password=(
275276
values[f"{resource_conf_prefix}_PASSWORD"].get_secret_value()
276277
if values[f"{resource_conf_prefix}_PASSWORD"]
277278
else None
278279
),
279280
host=values[f"{resource_conf_prefix}_HOST"],
280281
port=values[f"{resource_conf_prefix}_PORT"],
281-
path=(
282-
"/" + values[f"{resource_conf_prefix}_NAME"]
283-
if values[f"{resource_conf_prefix}_NAME"]
284-
else None
285-
),
282+
path=values.get(f"{resource_conf_prefix}_NAME"),
286283
)
287284
values = eval_default_factory_from_root_validator(
288285
cls, values, url_conf_name, lambda: str(http_url)
@@ -298,7 +295,7 @@ def _validate_http_url(
298295
# Those clusters are the only place we currently need this variable,
299296
# If you write code that depends on this config, make sure you
300297
# set BRANCH as an environment variable on your machine
301-
BRANCH: str = os.environ.get("BRANCH")
298+
BRANCH: str | None = os.environ.get("BRANCH")
302299

303300
# SPARK_SCHEDULER_MODE = "FAIR" # if used with weighted pools, could allow round-robin tasking of simultaneous jobs
304301
# TODO: have to deal with this if really wanting balanced (FAIR) task execution
@@ -361,10 +358,10 @@ def _validate_http_url(
361358
AWS_ACCESS_KEY: SecretStr = ENV_SPECIFIC_OVERRIDE
362359
AWS_SECRET_KEY: SecretStr = ENV_SPECIFIC_OVERRIDE
363360
# Setting AWS_PROFILE to None so boto3 doesn't try to pick up the placeholder string as an actual profile to find
364-
AWS_PROFILE: str = None # USER_SPECIFIC_OVERRIDE
365-
SPARK_S3_BUCKET: str = os.environ.get("SPARK_S3_BUCKET")
366-
BULK_DOWNLOAD_S3_BUCKET_NAME: str = os.environ.get("BULK_DOWNLOAD_S3_BUCKET_NAME")
367-
DATABASE_DOWNLOAD_S3_BUCKET_NAME: str = os.environ.get(
361+
AWS_PROFILE: str | None = None # USER_SPECIFIC_OVERRIDE
362+
SPARK_S3_BUCKET: str | None = os.environ.get("SPARK_S3_BUCKET")
363+
BULK_DOWNLOAD_S3_BUCKET_NAME: str | None = os.environ.get("BULK_DOWNLOAD_S3_BUCKET_NAME")
364+
DATABASE_DOWNLOAD_S3_BUCKET_NAME: str | None = os.environ.get(
368365
"DATABASE_DOWNLOAD_S3_BUCKET_NAME"
369366
)
370367
DELTA_LAKE_S3_PATH: str = "data/delta" # path within SPARK_S3_BUCKET where Delta output data will accumulate
@@ -380,9 +377,8 @@ def _validate_http_url(
380377
COVID19_DOWNLOAD_README_OBJECT_KEY: str = (
381378
f"files/{COVID19_DOWNLOAD_README_FILE_NAME}"
382379
)
383-
384-
class Config:
385-
pass
386-
# supporting use of a user-provided (ang git-ignored) .env file for overrides
387-
env_file = str(_PROJECT_ROOT_DIR / ".env")
388-
env_file_encoding = "utf-8"
380+
model_config = SettingsConfigDict(
381+
env_file=str(_PROJECT_ROOT_DIR / ".env"),
382+
env_file_encoding="utf-8",
383+
extra="allow",
384+
)

usaspending_api/config/envs/local.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
# - Set config variables to DefaultConfig.USER_SPECIFIC_OVERRIDE where there is expected to be a
88
# user-provided a config value for a variable (e.g. in the ../.env file)
99
########################################################################################################################
10-
from typing import ClassVar
10+
from typing import Any, ClassVar
1111

12-
from pydantic import root_validator
12+
from pydantic import model_validator
1313
from pydantic.types import SecretStr
14-
from usaspending_api.config.envs.default import DefaultConfig, _PROJECT_ROOT_DIR
14+
15+
from usaspending_api.config.envs.default import _PROJECT_ROOT_DIR, DefaultConfig
1516
from usaspending_api.config.utils import (
16-
USER_SPECIFIC_OVERRIDE,
1717
FACTORY_PROVIDED_VALUE,
18+
USER_SPECIFIC_OVERRIDE,
1819
eval_default_factory_from_root_validator,
1920
)
2021

@@ -90,20 +91,24 @@ class LocalConfig(DefaultConfig):
9091
USE_AWS: bool = False
9192
AWS_ACCESS_KEY: SecretStr = MINIO_ACCESS_KEY
9293
AWS_SECRET_KEY: SecretStr = MINIO_SECRET_KEY
93-
AWS_PROFILE: str = None
94+
AWS_PROFILE: str | None = None
9495
AWS_REGION: str = ""
9596
SPARK_S3_BUCKET: str = "data"
9697
BULK_DOWNLOAD_S3_BUCKET_NAME: str = "bulk-download"
97-
DATABASE_DOWNLOAD_S3_BUCKET_NAME = "dti-usaspending-db"
98+
DATABASE_DOWNLOAD_S3_BUCKET_NAME: str = "dti-usaspending-db"
9899

99100
# Since this config values is built by composing others, we want to late/lazily-evaluate their values,
100101
# in case the declared value is overridden by a shell env var or .env file value
101-
AWS_S3_ENDPOINT: str = FACTORY_PROVIDED_VALUE # See below validator-based factory
102+
AWS_S3_ENDPOINT: str | None = FACTORY_PROVIDED_VALUE # See below validator-based factory
103+
104+
@model_validator(mode="before")
105+
def _AWS_S3_ENDPOINT_factory(cls, values: dict[str, Any]) -> dict[str, Any]:
106+
# Merge defaults into values
107+
default_fields = {name: field.default for name, field in cls.model_fields.items()}
108+
merged_values = {**default_fields, **values}
102109

103-
@root_validator
104-
def _AWS_S3_ENDPOINT_factory(cls, values):
105-
def factory_func():
106-
return values["MINIO_HOST"] + ":" + values["MINIO_PORT"]
110+
def factory_func() -> str:
111+
return merged_values["MINIO_HOST"] + ":" + merged_values["MINIO_PORT"]
107112

108113
return eval_default_factory_from_root_validator(cls, values, "AWS_S3_ENDPOINT", factory_func)
109114

0 commit comments

Comments
 (0)