Skip to content

Commit c2cd486

Browse files
authored
Merge pull request #285 from chaen/fix_nested_access_policy
Fix nested access policy and violent crash
2 parents b570671 + 2903afb commit c2cd486

15 files changed

Lines changed: 82 additions & 45 deletions

File tree

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ jobs:
7878
- name: Start demo
7979
run: |
8080
git clone https://github.com/DIRACGrid/diracx-charts.git ../diracx-charts
81-
../diracx-charts/run_demo.sh --enable-open-telemetry --enable-coverage --exit-when-done --set-value developer.autoReload=false $PWD
81+
../diracx-charts/run_demo.sh --enable-open-telemetry --enable-coverage --exit-when-done --set-value developer.autoReload=false --ci-values ../diracx-charts/demo/ci_values.yaml $PWD
8282
- name: Debugging information
8383
run: |
8484
DIRACX_DEMO_DIR=$PWD/../diracx-charts/.demo

diracx-core/src/diracx/core/config/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,6 @@ def __hash__(self):
204204
@cachedmethod(lambda self: self._pull_cache)
205205
def _pull(self):
206206
"""Git pull from remote repo."""
207-
print("CHRIS PULL")
208207
self.repo.remotes.origin.pull()
209208

210209
def latest_revision(self) -> tuple[str, datetime]:

diracx-core/src/diracx/core/settings.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,13 @@ def create(cls) -> Self:
7878
async def lifetime_function(self) -> AsyncIterator[None]:
7979
"""A context manager that can be used to run code at startup and shutdown."""
8080
yield
81+
82+
83+
class DevelopmentSettings(ServiceSettingsBase):
84+
"""Settings for the Development Configuration that can influence run time."""
85+
86+
model_config = SettingsConfigDict(env_prefix="DIRACX_DEV_")
87+
88+
# When then to true (only for demo/CI), crash if an access policy isn't
89+
# called
90+
crash_on_missed_access_policy: bool = False

diracx-routers/src/diracx/routers/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,15 @@
1313
from collections.abc import AsyncGenerator
1414
from functools import partial
1515
from logging import Formatter, StreamHandler
16-
from typing import Any, Awaitable, Callable, Iterable, Sequence, TypeVar, cast
16+
from typing import (
17+
Any,
18+
Awaitable,
19+
Callable,
20+
Iterable,
21+
Sequence,
22+
TypeVar,
23+
cast,
24+
)
1725

1826
import dotenv
1927
from cachetools import TTLCache
@@ -139,6 +147,7 @@ def create_app_inner(
139147
# Please see ServiceSettingsBase for more details
140148

141149
available_settings_classes: set[type[ServiceSettingsBase]] = set()
150+
142151
for service_settings in all_service_settings:
143152
cls = type(service_settings)
144153
assert cls not in available_settings_classes

diracx-routers/src/diracx/routers/access_policies.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from fastapi import Depends
2727

2828
from diracx.core.extensions import select_from_extension
29+
from diracx.routers.dependencies import DevelopmentSettings
2930
from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token
3031

3132
# FastAPI bug:
@@ -99,6 +100,7 @@ def check_permissions(
99100
policy: Callable,
100101
policy_name: str,
101102
user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
103+
dev_settings: DevelopmentSettings,
102104
):
103105
"""This wrapper just calls the actual implementation, but also makes sure
104106
that the policy has been called.
@@ -120,6 +122,7 @@ async def wrapped_policy(**kwargs):
120122
try:
121123
yield wrapped_policy
122124
finally:
125+
123126
if not has_been_called:
124127
# TODO nice error message with inspect
125128
# That should really not happen
@@ -128,9 +131,11 @@ async def wrapped_policy(**kwargs):
128131
"(PS: I hope you are in a CI)",
129132
flush=True,
130133
)
131-
# Sleep a bit to make sure the flush happened
132-
time.sleep(1)
133-
os._exit(1)
134+
# If enable, just crash, meanly
135+
if dev_settings.crash_on_missed_access_policy:
136+
# Sleep a bit to make sure the flush happened
137+
time.sleep(1)
138+
os._exit(1)
134139

135140

136141
def open_access(f):

diracx-routers/src/diracx/routers/auth/well_known.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from fastapi import Request
44
from typing_extensions import TypedDict
55

6-
from ..dependencies import Config
6+
from ..dependencies import Config, DevelopmentSettings
77
from ..fastapi_classes import DiracxRouter
88
from ..utils.users import AuthSettings
99

@@ -17,7 +17,6 @@ async def openid_configuration(
1717
request: Request,
1818
config: Config,
1919
settings: AuthSettings,
20-
# check_permissions: OpenAccessPolicyCallable,
2120
):
2221
"""OpenID Connect discovery endpoint."""
2322
# await check_permissions()
@@ -65,17 +64,20 @@ class VOInfo(TypedDict):
6564

6665
class Metadata(TypedDict):
6766
virtual_organizations: dict[str, VOInfo]
67+
development_settings: DevelopmentSettings
6868

6969

7070
@router.get("/dirac-metadata")
7171
async def installation_metadata(
7272
config: Config,
7373
# check_permissions: OpenAccessPolicyCallable,
74+
dev_settings: DevelopmentSettings,
7475
) -> Metadata:
7576
"""Get metadata about the dirac installation."""
7677
# await check_permissions()
7778
metadata: Metadata = {
7879
"virtual_organizations": {},
80+
"development_settings": dev_settings,
7981
}
8082
for vo, vo_info in config.Registry.items():
8183
groups: dict[str, GroupInfo] = {

diracx-routers/src/diracx/routers/dependencies.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from diracx.core.config import Config as _Config
1919
from diracx.core.config import ConfigSource
2020
from diracx.core.properties import SecurityProperty
21+
from diracx.core.settings import DevelopmentSettings as _DevelopmentSettings
2122
from diracx.db.sql import AuthDB as _AuthDB
2223
from diracx.db.sql import JobDB as _JobDB
2324
from diracx.db.sql import JobLoggingDB as _JobLoggingDB
@@ -46,3 +47,7 @@ def add_settings_annotation(cls: T) -> T:
4647
AvailableSecurityProperties = Annotated[
4748
set[SecurityProperty], Depends(SecurityProperty.available_properties)
4849
]
50+
51+
DevelopmentSettings = Annotated[
52+
_DevelopmentSettings, Depends(_DevelopmentSettings.create)
53+
]

diracx-routers/src/diracx/routers/job_manager/access_policies.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ async def policy(
9797

9898

9999
class SandboxAccessPolicy(BaseAccessPolicy):
100-
"""Policy for the sandbox
101-
It delegates most of it to the WMSPolicy.
100+
"""Policy for the sandbox.
101+
They are similar to the WMS access policies.
102102
"""
103103

104104
@staticmethod
@@ -108,25 +108,11 @@ async def policy(
108108
/,
109109
*,
110110
action: ActionType | None = None,
111-
job_db: JobDB | None = None,
112111
sandbox_metadata_db: SandboxMetadataDB | None = None,
113112
pfns: list[str] | None = None,
114113
required_prefix: str | None = None,
115-
job_ids: list[int] | None = None,
116-
check_wms_permissions: CheckWMSPolicyCallable | None = None,
117114
):
118-
119115
assert action, "action is a mandatory parameter"
120-
121-
# if we pass the job_db or job_ids,
122-
# delegate the check to the WMSAccessPolicy
123-
if job_db or job_ids:
124-
# Make sure that check_wms_permission is set
125-
# It should always be by fastapi Depends,
126-
# but not when we test the policy in itself
127-
assert check_wms_permissions
128-
return check_wms_permissions(action=action, job_db=job_db, job_ids=job_ids)
129-
130116
assert sandbox_metadata_db, "sandbox_metadata_db is a mandatory parameter"
131117
assert pfns, "pfns is a mandatory parameter"
132118

diracx-routers/src/diracx/routers/job_manager/sandboxes.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@
2626
from diracx.core.settings import ServiceSettingsBase
2727

2828
from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token
29-
from .access_policies import ActionType, CheckSandboxPolicyCallable
29+
from .access_policies import (
30+
ActionType,
31+
CheckSandboxPolicyCallable,
32+
CheckWMSPolicyCallable,
33+
)
3034

3135
if TYPE_CHECKING:
3236
from types_aiobotocore_s3.client import S3Client
@@ -221,7 +225,7 @@ async def get_job_sandboxes(
221225
job_id: int,
222226
sandbox_metadata_db: SandboxMetadataDB,
223227
job_db: JobDB,
224-
check_permissions: CheckSandboxPolicyCallable,
228+
check_permissions: CheckWMSPolicyCallable,
225229
) -> dict[str, list[Any]]:
226230
"""Get input and output sandboxes of given job."""
227231
await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id])
@@ -241,7 +245,7 @@ async def get_job_sandbox(
241245
sandbox_metadata_db: SandboxMetadataDB,
242246
job_db: JobDB,
243247
sandbox_type: Literal["input", "output"],
244-
check_permissions: CheckSandboxPolicyCallable,
248+
check_permissions: CheckWMSPolicyCallable,
245249
) -> list[Any]:
246250
"""Get input or output sandbox of given job."""
247251
await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id])
@@ -259,7 +263,7 @@ async def assign_sandbox_to_job(
259263
sandbox_metadata_db: SandboxMetadataDB,
260264
job_db: JobDB,
261265
settings: SandboxStoreSettings,
262-
check_permissions: CheckSandboxPolicyCallable,
266+
check_permissions: CheckWMSPolicyCallable,
263267
):
264268
"""Map the pfn as output sandbox to job."""
265269
await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id])
@@ -277,7 +281,7 @@ async def unassign_job_sandboxes(
277281
job_id: int,
278282
sandbox_metadata_db: SandboxMetadataDB,
279283
job_db: JobDB,
280-
check_permissions: CheckSandboxPolicyCallable,
284+
check_permissions: CheckWMSPolicyCallable,
281285
):
282286
"""Delete single job sandbox mapping."""
283287
await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id])
@@ -289,7 +293,7 @@ async def unassign_bulk_jobs_sandboxes(
289293
jobs_ids: Annotated[list[int], Query()],
290294
sandbox_metadata_db: SandboxMetadataDB,
291295
job_db: JobDB,
292-
check_permissions: CheckSandboxPolicyCallable,
296+
check_permissions: CheckWMSPolicyCallable,
293297
):
294298
"""Delete bulk jobs sandbox mapping."""
295299
await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=jobs_ids)

diracx-routers/tests/auth/test_legacy_exchange.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,13 @@
99

1010
DIRAC_CLIENT_ID = "myDIRACClientID"
1111
pytestmark = pytest.mark.enabled_dependencies(
12-
["AuthDB", "AuthSettings", "ConfigSource", "BaseAccessPolicy"]
12+
[
13+
"AuthDB",
14+
"AuthSettings",
15+
"ConfigSource",
16+
"BaseAccessPolicy",
17+
"DevelopmentSettings",
18+
]
1319
)
1420

1521

0 commit comments

Comments
 (0)