Skip to content

Commit 427a9ef

Browse files
committed
chore: introduce FactorySettings to manage some global settings
1 parent b8db2cc commit 427a9ef

18 files changed

Lines changed: 471 additions & 90 deletions

File tree

diracx-core/src/diracx/core/config/sources.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import asyncio
99
import logging
10-
import os
1110
from datetime import datetime, timezone
1211
from pathlib import Path
1312
from tempfile import TemporaryDirectory
@@ -76,7 +75,10 @@ def __init_subclass__(cls) -> None:
7675

7776
@classmethod
7877
def create(cls):
79-
return cls.create_from_url(backend_url=os.environ["DIRACX_CONFIG_BACKEND_URL"])
78+
# Avoid circular import
79+
from diracx.core.settings import FactorySettings
80+
81+
return cls.create_from_url(backend_url=FactorySettings().config_backend_url)
8082

8183
@classmethod
8284
def create_from_url(

diracx-core/src/diracx/core/settings.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414

1515
import contextlib
1616
import json
17+
import os
1718
from collections.abc import AsyncIterator
1819
from pathlib import Path
1920
from typing import Annotated, Any, Self, TypeVar, cast
2021

22+
import dotenv
2123
from cryptography.fernet import Fernet
2224
from joserfc.jwk import KeySet, KeySetSerialization
2325
from pydantic import (
@@ -29,14 +31,18 @@
2931
SecretStr,
3032
TypeAdapter,
3133
UrlConstraints,
34+
field_validator,
3235
model_validator,
3336
)
3437
from pydantic_settings import BaseSettings, SettingsConfigDict
3538
from signurlarity.aio.client import AsyncClient
3639
from signurlarity.exceptions import SignurlarityError
3740

41+
from .config.sources import ConfigSourceUrl
42+
from .extensions import DiracEntryPoint, select_from_extension
3843
from .properties import SecurityProperty
3944
from .s3 import s3_bucket_exists
45+
from .utils import dotenv_files_from_environment
4046

4147
T = TypeVar("T")
4248

@@ -350,3 +356,111 @@ def s3_client(self) -> AsyncClient:
350356
if self._client is None:
351357
raise RuntimeError("S3 client accessed before lifetime function")
352358
return self._client
359+
360+
361+
class FactorySettings(ServiceSettingsBase):
362+
"""Factory settings.
363+
364+
Settings which do not fit into dedicated classes,
365+
or are dynamically generated.
366+
"""
367+
368+
# We want to be able to read both from specific environment variables
369+
# but also to create the object directly with the attribute name
370+
# https://pydantic.dev/docs/validation/latest/concepts/alias#validation
371+
model_config = SettingsConfigDict(
372+
use_attribute_docstrings=True, validate_by_alias=True, validate_by_name=True
373+
)
374+
375+
config_backend_url: ConfigSourceUrl | None = Field(
376+
default=None,
377+
validation_alias="DIRACX_CONFIG_BACKEND_URL",
378+
)
379+
"""The URL of the configuration backend.
380+
"""
381+
382+
legacy_exchange_hashed_api_key: str = Field(
383+
default="", validation_alias="DIRACX_LEGACY_EXCHANGE_HASHED_API_KEY"
384+
)
385+
"""The hashed API key for the legacy exchange endpoint.
386+
"""
387+
388+
tasks_redis_url: str = Field(
389+
default="redis://localhost", validation_alias="DIRACX_TASKS_REDIS_URL"
390+
)
391+
"""The url for the redis server to manage tasks"""
392+
393+
enabled_services: dict[str, bool] = Field(default_factory=dict)
394+
"""The following environment variables dictates which routers are enabled."""
395+
396+
opensearch_dbs: dict[str, str] = Field(default_factory=dict)
397+
"""The following environment variables configure the OpenSearch database connections."""
398+
399+
sql_dbs: dict[str, str] = Field(default_factory=dict)
400+
"""The following environment variables configure the SQL database connections."""
401+
402+
@model_validator(mode="before")
403+
@classmethod
404+
def load_dotenv_files(cls, data: Any) -> Any:
405+
"""Load dotenv files before reading settings from environment."""
406+
for env_file in dotenv_files_from_environment("DIRACX_SERVICE_DOTENV"):
407+
if not dotenv.load_dotenv(env_file):
408+
raise NotImplementedError(f"Could not load dotenv file {env_file}")
409+
return data
410+
411+
@field_validator("enabled_services", mode="before")
412+
@classmethod
413+
def build_enabled_services(cls, value: Any) -> dict[str, bool]:
414+
"""Build enabled services from the installed service entry points."""
415+
enabled_services: dict[str, bool] = {
416+
entry_point.name: True
417+
for entry_point in select_from_extension(group=DiracEntryPoint.SERVICES)
418+
if "well-known" not in entry_point.name
419+
}
420+
421+
for service_name in enabled_services:
422+
env_name = f"DIRACX_SERVICE_{service_name.upper()}_ENABLED"
423+
if env_value := os.environ.get(env_name):
424+
enabled_services[service_name] = TypeAdapter(bool).validate_python(
425+
env_value
426+
)
427+
428+
if isinstance(value, dict):
429+
enabled_services.update(value)
430+
return enabled_services
431+
432+
@field_validator("opensearch_dbs", mode="before")
433+
@classmethod
434+
def build_opensearch_dbs(cls, value: Any) -> dict[str, str]:
435+
"""Build OpenSearch database URLs from the installed entry points."""
436+
opensearch_dbs: dict[str, str] = {
437+
entry_point.name: ""
438+
for entry_point in select_from_extension(group=DiracEntryPoint.OS_DB)
439+
}
440+
441+
for db_name in opensearch_dbs:
442+
env_name = f"DIRACX_OS_DB_{db_name.upper()}"
443+
if env_value := os.environ.get(env_name):
444+
opensearch_dbs[db_name] = env_value
445+
446+
if isinstance(value, dict):
447+
opensearch_dbs.update(value)
448+
return opensearch_dbs
449+
450+
@field_validator("sql_dbs", mode="before")
451+
@classmethod
452+
def build_sql_dbs(cls, value: Any) -> dict[str, str]:
453+
"""Build SQL database URLs from the installed entry points."""
454+
sql_dbs: dict[str, str] = {
455+
entry_point.name: ""
456+
for entry_point in select_from_extension(group=DiracEntryPoint.SQL_DB)
457+
}
458+
459+
for db_name in sql_dbs:
460+
env_name = f"DIRACX_DB_URL_{db_name.upper()}"
461+
if env_value := os.environ.get(env_name):
462+
sql_dbs[db_name] = env_value
463+
464+
if isinstance(value, dict):
465+
sql_dbs.update(value)
466+
return sql_dbs

diracx-db/src/diracx/db/os/utils.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import contextlib
44
import json
55
import logging
6-
import os
76
from abc import ABCMeta, abstractmethod
87
from collections.abc import AsyncIterator
98
from contextvars import ContextVar
@@ -14,6 +13,7 @@
1413

1514
from diracx.core.exceptions import InvalidQueryError
1615
from diracx.core.extensions import DiracEntryPoint, select_from_extension
16+
from diracx.core.settings import FactorySettings
1717
from diracx.db.exceptions import DBUnavailableError
1818

1919
logger = logging.getLogger(__name__)
@@ -38,7 +38,8 @@ class BaseOSDB(metaclass=ABCMeta):
3838
This method returns a dictionary of database names to connection parameters.
3939
The available databases are determined by the `diracx.dbs.os` entrypoint in
4040
the `pyproject.toml` file and the connection parameters are taken from the
41-
environment variables prefixed with `DIRACX_OS_DB_{DB_NAME}`.
41+
`opensearch_dbs` field in FactorySettings, which reads from environment variables
42+
prefixed with `DIRACX_OS_DB_{DB_NAME}`.
4243
4344
If extensions to DiracX are being used, there can be multiple implementations
4445
of the same database. To list the available implementations use
@@ -104,19 +105,26 @@ def available_implementations(cls, db_name: str) -> list[type[BaseOSDB]]:
104105
def available_urls(cls) -> dict[str, dict[str, Any]]:
105106
"""Return a dict of available OpenSearch database urls.
106107
107-
The list of available URLs is determined by environment variables
108+
The list of available URLs is determined by the opensearch_dbs field
109+
in FactorySettings, which reads from environment variables
108110
prefixed with ``DIRACX_OS_DB_{DB_NAME}``.
109111
"""
112+
factory_settings = FactorySettings()
113+
opensearch_dbs = factory_settings.opensearch_dbs
114+
110115
conn_kwargs: dict[str, dict[str, Any]] = {}
111116
for entry_point in select_from_extension(group=DiracEntryPoint.OS_DB):
112117
db_name = entry_point.name
113-
var_name = f"DIRACX_OS_DB_{entry_point.name.upper()}"
114-
if var_name in os.environ:
115-
try:
116-
conn_kwargs[db_name] = json.loads(os.environ[var_name])
117-
except Exception:
118-
logger.error("Error loading connection parameters for %s", db_name)
119-
raise
118+
# Get the field value from the OpenSearchDBSettings model
119+
if field_value := opensearch_dbs.get(db_name):
120+
if field_value:
121+
try:
122+
conn_kwargs[db_name] = json.loads(field_value)
123+
except Exception:
124+
logger.error(
125+
"Error loading connection parameters for %s", db_name
126+
)
127+
raise
120128
return conn_kwargs
121129

122130
@classmethod

diracx-db/src/diracx/db/sql/utils/base.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import contextlib
44
import logging
5-
import os
65
import re
76
from abc import ABCMeta
87
from collections.abc import AsyncIterator
@@ -53,8 +52,9 @@ class BaseSQLDB(metaclass=ABCMeta):
5352
The available databases are discovered by calling `BaseSQLDB.available_urls`.
5453
This method returns a mapping of database names to connection URLs. The
5554
available databases are determined by the `diracx.dbs.sql` entrypoint in the
56-
`pyproject.toml` file and the connection URLs are taken from the environment
57-
variables of the form `DIRACX_DB_URL_<db-name>`.
55+
`pyproject.toml` file and the connection URLs are taken from the
56+
`sql_dbs` field in FactorySettings, which reads from environment variables
57+
of the form `DIRACX_DB_URL_<db-name>`.
5858
5959
If extensions to DiracX are being used, there can be multiple implementations
6060
of the same database. To list the available implementations use
@@ -125,16 +125,21 @@ def available_implementations(cls, db_name: str) -> list[type["BaseSQLDB"]]:
125125
def available_urls(cls) -> dict[str, str]:
126126
"""Return a dict of available database urls.
127127
128-
The list of available URLs is determined by environment variables
128+
The list of available URLs is determined by the sql_dbs field
129+
in FactorySettings, which reads from environment variables
129130
prefixed with ``DIRACX_DB_URL_{DB_NAME}``.
130131
"""
132+
from diracx.core.settings import FactorySettings
133+
134+
factory_settings = FactorySettings()
135+
sql_dbs = factory_settings.sql_dbs
136+
131137
db_urls: dict[str, str] = {}
132138
for entry_point in select_from_extension(group=DiracEntryPoint.SQL_DB):
133139
db_name = entry_point.name
134-
var_name = f"DIRACX_DB_URL_{entry_point.name.upper()}"
135-
if var_name in os.environ:
140+
# Get the field value from the SqlDBSettings model
141+
if db_url := sql_dbs.get(db_name):
136142
try:
137-
db_url = os.environ[var_name]
138143
if db_url == "sqlite+aiosqlite:///:memory:":
139144
db_urls[db_name] = db_url
140145
# pydantic does not allow for underscore in scheme

diracx-logic/src/diracx/logic/__main__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,15 @@ async def delete_jwk(args):
9393
async def cleanup_authdb(args):
9494
"""Maintain AuthDB partitions and remove expired flows."""
9595
logger.info("Maintaining AuthDB partitions and removing expired flows")
96-
import os
9796

98-
from diracx.core.settings import AuthSettings
97+
from diracx.core.settings import AuthSettings, FactorySettings
9998
from diracx.db.sql import AuthDB
10099
from diracx.logic.auth.management import cleanup_expired_data
101100

102101
settings = AuthSettings()
103-
db_url = os.environ["DIRACX_DB_URL_AUTHDB"]
102+
factory_settings = FactorySettings()
103+
db_url = factory_settings.sql_dbs.AuthDB
104+
104105
db = AuthDB(db_url)
105106
async with db.engine_context():
106107
async with db:

diracx-routers/src/diracx/routers/auth/token.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
import logging
6-
import os
76
from http import HTTPStatus
87
from typing import Annotated, Literal
98

@@ -21,7 +20,7 @@
2120
RefreshTokenPayload,
2221
TokenResponse,
2322
)
24-
from diracx.core.settings import AuthSettings
23+
from diracx.core.settings import AuthSettings, FactorySettings
2524
from diracx.db.sql import AuthDB
2625
from diracx.logic.auth import create_token
2726
from diracx.logic.auth import get_oidc_token as get_oidc_token_bl
@@ -182,6 +181,7 @@ async def perform_legacy_exchange(
182181
auth_db: AuthDB,
183182
available_properties: AvailableSecurityProperties,
184183
settings: AuthSettings,
184+
factory_settings: FactorySettings,
185185
config: Config,
186186
all_access_policies: Annotated[
187187
dict[str, BaseAccessPolicy], Depends(BaseAccessPolicy.all_used_access_policies)
@@ -193,9 +193,7 @@ async def perform_legacy_exchange(
193193
This route is disabled if DIRACX_LEGACY_EXCHANGE_HASHED_API_KEY is not set
194194
in the environment.
195195
"""
196-
if not (
197-
expected_api_key := os.environ.get("DIRACX_LEGACY_EXCHANGE_HASHED_API_KEY")
198-
):
196+
if not (expected_api_key := factory_settings.legacy_exchange_hashed_api_key):
199197
raise HTTPException(
200198
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
201199
detail="Legacy exchange is not enabled",

diracx-routers/src/diracx/routers/factory.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,13 @@
66

77
import inspect
88
import logging
9-
import os
109
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Sequence
1110
from functools import partial
1211
from http import HTTPStatus
1312
from importlib.metadata import EntryPoint, EntryPoints, entry_points
1413
from logging import Formatter, StreamHandler
1514
from typing import Any, TypeVar, cast
1615

17-
import dotenv
1816
from cachetools import TTLCache
1917
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
2018
from fastapi.dependencies.models import Dependant
@@ -24,16 +22,14 @@
2422
from fastapi.responses import JSONResponse, Response
2523
from fastapi.routing import APIRoute
2624
from packaging.version import InvalidVersion, parse
27-
from pydantic import TypeAdapter
2825
from starlette.middleware.base import BaseHTTPMiddleware
2926
from uvicorn.logging import AccessFormatter, DefaultFormatter
3027

3128
from diracx.core.config import ConfigSource
3229
from diracx.core.exceptions import DiracError, DiracHttpResponseError, NotReadyError
3330
from diracx.core.extensions import DiracEntryPoint, select_from_extension
34-
from diracx.core.settings import ServiceSettingsBase
31+
from diracx.core.settings import FactorySettings, ServiceSettingsBase
3532
from diracx.core.sources import AsyncCacheableSource
36-
from diracx.core.utils import dotenv_files_from_environment
3733
from diracx.db.exceptions import DBUnavailableError
3834
from diracx.db.os.utils import BaseOSDB
3935
from diracx.db.sql.utils import BaseSQLDB
@@ -144,7 +140,6 @@ def create_app_inner(
144140
# Please see ServiceSettingsBase for more details
145141

146142
available_settings_classes: set[type[ServiceSettingsBase]] = set()
147-
148143
for service_settings in all_service_settings:
149144
cls = type(service_settings)
150145
assert cls not in available_settings_classes
@@ -388,17 +383,12 @@ def create_app() -> DiracFastAPI:
388383
We attempt to load each setting classes to make sure that the
389384
settings are correctly defined.
390385
"""
391-
for env_file in dotenv_files_from_environment("DIRACX_SERVICE_DOTENV"):
392-
logger.debug("Loading dotenv file: %s", env_file)
393-
if not dotenv.load_dotenv(env_file):
394-
raise NotImplementedError(f"Could not load dotenv file {env_file}")
395-
396386
# Load all available routers
397387
enabled_systems = set()
398388
settings_classes = set()
389+
factory_settings = FactorySettings()
399390
for entry_point in select_from_extension(group=DiracEntryPoint.SERVICES):
400-
env_var = f"DIRACX_SERVICE_{entry_point.name.upper()}_ENABLED"
401-
enabled = TypeAdapter(bool).validate_json(os.environ.get(env_var, "true"))
391+
enabled = factory_settings.enabled_services.get(entry_point.name, True)
402392
logger.debug("Found service %r: enabled=%s", entry_point, enabled)
403393
if not enabled:
404394
continue
@@ -485,6 +475,7 @@ async def validation_error_handler(request: Request, exc: RequestValidationError
485475
def find_dependents(
486476
obj: APIRouter | Iterable[Dependant], cls: type[T]
487477
) -> Iterable[type[T]]:
478+
488479
if isinstance(obj, APIRouter):
489480
# TODO: Support dependencies of the router itself
490481
# yield from find_dependents(obj.dependencies, cls)

0 commit comments

Comments
 (0)