1313import pathlib
1414from typing import Any , ClassVar , Union
1515
16- from pydantic import (
17- AnyHttpUrl ,
18- BaseSettings ,
19- PostgresDsn ,
20- SecretStr ,
21- root_validator ,
22- )
16+ from pydantic import SecretStr , model_validator
17+ from pydantic .networks import AnyHttpUrl , PostgresDsn
18+ from pydantic_settings import BaseSettings , SettingsConfigDict
2319
2420from usaspending_api .config .utils import (
2521 ENV_SPECIFIC_OVERRIDE ,
@@ -134,20 +130,22 @@ def _validate_database_conf(
134130 )
135131
136132 if enough_parts :
137- pg_dsn = PostgresDsn (
138- url = None ,
133+ try :
134+ _port = int (values [f"{ resource_conf_prefix } _PORT" ])
135+ except (ValueError , TypeError ):
136+ _port = None
137+
138+ pg_dsn = PostgresDsn .build (
139139 scheme = values [f"{ resource_conf_prefix } _SCHEME" ],
140- user = values [f"{ resource_conf_prefix } _USER" ],
141- password = values [
142- f"{ resource_conf_prefix } _PASSWORD"
143- ].get_secret_value (),
144- host = values [f"{ resource_conf_prefix } _HOST" ],
145- port = values [f"{ resource_conf_prefix } _PORT" ],
146- path = (
147- "/" + values [f"{ resource_conf_prefix } _NAME" ]
148- if values [f"{ resource_conf_prefix } _NAME" ]
149- else None
140+ username = values [f"{ resource_conf_prefix } _USER" ],
141+ password = (
142+ values [f"{ resource_conf_prefix } _PASSWORD" ].get_secret_value ()
143+ if isinstance (values [f"{ resource_conf_prefix } _PASSWORD" ], SecretStr )
144+ else values [f"{ resource_conf_prefix } _PASSWORD" ]
150145 ),
146+ host = values [f"{ resource_conf_prefix } _HOST" ],
147+ port = _port ,
148+ path = values .get (f"{ resource_conf_prefix } _NAME" ),
151149 )
152150 values = eval_default_factory_from_root_validator (
153151 cls , values , url_conf_name , lambda : str (pg_dsn )
@@ -159,7 +157,7 @@ def _validate_database_conf(
159157
160158 # noinspection PyMethodParameters
161159 # Pydantic returns a classmethod for its validators, so the cls param is correct
162- @root_validator
160+ @model_validator ( mode = "before" )
163161 def _DATABASE_URL_and_parts_factory (cls , values : dict [str , Any ]) -> dict [str , Any ]:
164162 """A root validator to backfill DATABASE_URL and USASPENDING_DB_* part config vars and validate that they are
165163 all consistent.
@@ -169,6 +167,8 @@ def _DATABASE_URL_and_parts_factory(cls, values: dict[str, Any]) -> dict[str, An
169167 - ALSO validates that the parts and whole string are consistent. A ``ValueError`` is thrown if found to
170168 be inconsistent, which will in turn raise a ``pydantic.ValidationError`` at configuration time.
171169 """
170+ default_fields = {name : field .default for name , field in cls .model_fields .items ()}
171+ values = {** default_fields , ** values }
172172 # noinspection PyArgumentList
173173 cls ._validate_database_conf (
174174 cls = cls ,
@@ -181,7 +181,7 @@ def _DATABASE_URL_and_parts_factory(cls, values: dict[str, Any]) -> dict[str, An
181181
182182 # noinspection PyMethodParameters
183183 # Pydantic returns a classmethod for its validators, so the cls param is correct
184- @root_validator
184+ @model_validator ( mode = "before" )
185185 def _BROKER_DB_and_parts_factory (cls , values : dict [str , Any ]) -> dict [str , Any ]:
186186 """A root validator to backfill BROKER_DB and BROKER_DB_* part config vars and validate
187187 that they are all consistent.
@@ -191,6 +191,8 @@ def _BROKER_DB_and_parts_factory(cls, values: dict[str, Any]) -> dict[str, Any]:
191191 - ALSO validates that the parts and whole string are consistent. A ``ValueError`` is thrown if found to
192192 be inconsistent, which will in turn raise a ``pydantic.ValidationError`` at configuration time.
193193 """
194+ default_fields = {name : field .default for name , field in cls .model_fields .items ()}
195+ values = {** default_fields , ** values }
194196 # noinspection PyArgumentList
195197 cls ._validate_database_conf (
196198 cls = cls ,
@@ -203,17 +205,17 @@ def _BROKER_DB_and_parts_factory(cls, values: dict[str, Any]) -> dict[str, Any]:
203205
# ==== [Elasticsearch] ====
# Where to connect to elasticsearch.
ES_HOSTNAME: str | None = None  # FACTORY_PROVIDED_VALUE. See below validator-factory
ES_SCHEME: str = "https"
ES_HOST: str = ENV_SPECIFIC_OVERRIDE
# Part config vars consumed by the ES_HOSTNAME validator-factory below; port is
# deliberately a str here since it arrives as a raw environment value.
ES_PORT: str | None = None
ES_USER: str | None = None
ES_PASSWORD: SecretStr | None = None
ES_NAME: str | None = None
213215
214216 # noinspection PyMethodParameters
215217 # Pydantic returns a classmethod for its validators, so the cls param is correct
216- @root_validator
218+ @model_validator ( mode = "before" )
217219 def _ES_HOSTNAME_and_parts_factory (cls , values : dict [str , Any ]) -> dict [str , Any ]:
218220 """A root validator to backfill ES_HOSTNAME and ES_* part config vars and validate that they are
219221 all consistent.
@@ -223,6 +225,8 @@ def _ES_HOSTNAME_and_parts_factory(cls, values: dict[str, Any]) -> dict[str, Any
223225 - ALSO validates that the parts and whole string are consistent. A ``ValueError`` is thrown if found to
224226 be inconsistent, which will in turn raise a ``pydantic.ValidationError`` at configuration time.
225227 """
228+ default_fields = {name : field .default for name , field in cls .model_fields .items ()}
229+ values = {** default_fields , ** values }
226230 # noinspection PyArgumentList
227231 cls ._validate_http_url (
228232 cls = cls ,
@@ -251,9 +255,7 @@ def _validate_http_url(
251255 # - it should take precedence
252256 # - its values will be used to backfill any missing URL parts stored as separate config vars
253257 if is_full_url_provided :
254- values = backfill_url_parts_config (
255- cls , url_conf_name , resource_conf_prefix , values
256- )
258+ values = backfill_url_parts_config (cls , url_conf_name , resource_conf_prefix , values )
257259
258260 # If the full URL config is not provided, try to build-it-up from provided parts, then set the full URL
259261 if not is_full_url_provided :
@@ -268,21 +270,16 @@ def _validate_http_url(
268270
if enough_parts:
    # Pydantic v2 URL types no longer accept component keyword args in their
    # constructor (it takes a single URL string); components must go through
    # the .build() classmethod -- the same API this change already uses for
    # PostgresDsn in _validate_database_conf. AnyHttpUrl(scheme=...) would
    # raise at runtime.

    # .build() expects port as an int (or None); in a mode="before" validator
    # the part vars may still be raw env strings -- mirror the coercion done
    # for the Postgres DSN parts.
    try:
        _port = int(values[f"{resource_conf_prefix}_PORT"])
    except (ValueError, TypeError):
        _port = None

    # The password part may be a SecretStr (field default) or a raw str
    # (env-provided, pre-validation) -- mirror the isinstance guard used in
    # the Postgres branch instead of unconditionally calling
    # .get_secret_value() on whatever is there.
    _password = values[f"{resource_conf_prefix}_PASSWORD"]
    if isinstance(_password, SecretStr):
        _password = _password.get_secret_value()

    http_url = AnyHttpUrl.build(
        scheme=values[f"{resource_conf_prefix}_SCHEME"],
        username=values[f"{resource_conf_prefix}_USER"],
        password=_password if _password else None,
        host=values[f"{resource_conf_prefix}_HOST"],
        port=_port,
        path=values.get(f"{resource_conf_prefix}_NAME"),
    )
287284 values = eval_default_factory_from_root_validator (
288285 cls , values , url_conf_name , lambda : str (http_url )
@@ -298,7 +295,7 @@ def _validate_http_url(
298295 # Those clusters are the only place we currently need this variable,
299296 # If you write code that depends on this config, make sure you
300297 # set BRANCH as an environment variable on your machine
301- BRANCH : str = os .environ .get ("BRANCH" )
298+ BRANCH : str | None = os .environ .get ("BRANCH" )
302299
303300 # SPARK_SCHEDULER_MODE = "FAIR" # if used with weighted pools, could allow round-robin tasking of simultaneous jobs
304301 # TODO: have to deal with this if really wanting balanced (FAIR) task execution
@@ -361,10 +358,10 @@ def _validate_http_url(
AWS_ACCESS_KEY: SecretStr = ENV_SPECIFIC_OVERRIDE
AWS_SECRET_KEY: SecretStr = ENV_SPECIFIC_OVERRIDE
# Setting AWS_PROFILE to None so boto3 doesn't try to pick up the placeholder string as an actual profile to find
AWS_PROFILE: str | None = None  # USER_SPECIFIC_OVERRIDE
# NOTE(review): these defaults read os.environ at class-definition time rather
# than relying on pydantic-settings' env source; pydantic will still override
# them from the environment at instantiation, so behavior is unchanged --
# confirm the early read is intentional.
SPARK_S3_BUCKET: str | None = os.environ.get("SPARK_S3_BUCKET")
BULK_DOWNLOAD_S3_BUCKET_NAME: str | None = os.environ.get("BULK_DOWNLOAD_S3_BUCKET_NAME")
DATABASE_DOWNLOAD_S3_BUCKET_NAME: str | None = os.environ.get("DATABASE_DOWNLOAD_S3_BUCKET_NAME")
DELTA_LAKE_S3_PATH: str = "data/delta"  # path within SPARK_S3_BUCKET where Delta output data will accumulate
@@ -380,9 +377,8 @@ def _validate_http_url(
COVID19_DOWNLOAD_README_OBJECT_KEY: str = f"files/{COVID19_DOWNLOAD_README_FILE_NAME}"

# Supporting use of a user-provided (and git-ignored) .env file for overrides.
# extra="allow" permits config vars that are not declared as fields on this
# class (pydantic-settings v2 forbids extras by default).
model_config = SettingsConfigDict(
    env_file=str(_PROJECT_ROOT_DIR / ".env"),
    env_file_encoding="utf-8",
    extra="allow",
)
0 commit comments