Skip to content

Commit 6f601f1

Browse files
test: Adding tests to the WMS access policy, and some fixes
1 parent 444dd15 commit 6f601f1

3 files changed

Lines changed: 142 additions & 7 deletions

File tree

diracx-routers/src/diracx/routers/jobs/access_policies.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,6 @@ async def policy(
5959
pilot_db
6060
), "pilot_db is a mandatory parameter when using a pilot action"
6161
assert job_ids, "job_ids has to be defined"
62-
assert (
63-
len(job_ids) == 1
64-
), "a pilot can have only one job_id associated, and it has to be given"
65-
6662
pilot_info = user_info # For semantic
6763

6864
# Syntax to avoid code duplication
@@ -83,11 +79,11 @@ async def policy(
8379
)
8480

8581
# Equivalent of issubset, but cleaner
86-
if set(job_ids) <= pilot_jobs:
82+
if set(job_ids) <= set(pilot_jobs):
8783
return
8884

8985
raise HTTPException(
90-
status.HTTP_403_FORBIDDEN, "this pilot can't modify this job"
86+
status.HTTP_403_FORBIDDEN, "this pilot can't access/modify this job"
9187
)
9288

9389
raise HTTPException(status.HTTP_403_FORBIDDEN, "you are not a pilot")

diracx-routers/tests/jobs/test_wms_access_policy.py

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
from fastapi import HTTPException, status
77

8-
from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER
8+
from diracx.core.properties import GENERIC_PILOT, JOB_ADMINISTRATOR, NORMAL_USER
99
from diracx.routers.jobs.access_policies import (
1010
ActionType,
1111
SandboxAccessPolicy,
@@ -27,6 +27,11 @@ class FakeJobDB:
2727
async def summary(self, *args): ...
2828

2929

30+
class FakePilotDB:
31+
async def get_pilot_by_reference(self, *args): ...
32+
async def get_pilot_job_ids(self, *args): ...
33+
34+
3035
class FakeSBMetadataDB:
3136
async def get_owner_id(self, *args): ...
3237
async def get_sandbox_owner_id(self, *args): ...
@@ -37,6 +42,11 @@ def job_db():
3742
yield FakeJobDB()
3843

3944

45+
@pytest.fixture
46+
def pilot_db():
47+
yield FakePilotDB()
48+
49+
4050
@pytest.fixture
4151
def sandbox_metadata_db():
4252
yield FakeSBMetadataDB()
@@ -69,6 +79,112 @@ async def test_wms_access_policy_weird_user(job_db):
6979
)
7080

7181

82+
async def test_wms_access_policy_pilot(job_db, pilot_db, monkeypatch):
83+
84+
normal_user = AuthorizedUserInfo(properties=[NORMAL_USER], **base_payload)
85+
pilot = AuthorizedUserInfo(properties=[GENERIC_PILOT], **base_payload)
86+
87+
# ------------------------- Simple User accessing a pilot action -------------------------
88+
# A user cannot create any resource
89+
with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}") as excinfo:
90+
await WMSAccessPolicy.policy(
91+
WMS_POLICY_NAME,
92+
normal_user,
93+
action=ActionType.PILOT,
94+
job_db=job_db,
95+
pilot_db=pilot_db,
96+
job_ids=[1, 2],
97+
)
98+
99+
# Split to distinguish the generated part ("403 ") from the message part ("you are not a pilot")
100+
assert str(excinfo.value) == "403: " + "you are not a pilot", excinfo
101+
102+
# ------------------------- Lost pilot -------------------------
103+
async def get_pilot_by_reference_patch(*args):
104+
return []
105+
106+
monkeypatch.setattr(
107+
pilot_db, "get_pilot_by_reference", get_pilot_by_reference_patch
108+
)
109+
110+
# A pilot that has expired (removed from db) should not be able to access jobs
111+
with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}") as excinfo:
112+
await WMSAccessPolicy.policy(
113+
WMS_POLICY_NAME,
114+
pilot,
115+
action=ActionType.PILOT,
116+
pilot_db=pilot_db,
117+
job_db=job_db,
118+
job_ids=[1, 2],
119+
)
120+
121+
assert str(excinfo.value) == "403: " + "this pilot is not registered", excinfo
122+
123+
# ------------------------- Pilot accessing wrong jobs -------------------------
124+
async def get_pilot_by_reference_patch(*args, **kwargs):
125+
return {"PilotID": 1}
126+
127+
async def get_pilot_job_ids_patch(*args, **kwargs):
128+
return []
129+
130+
monkeypatch.setattr(
131+
pilot_db, "get_pilot_by_reference", get_pilot_by_reference_patch
132+
)
133+
monkeypatch.setattr(pilot_db, "get_pilot_job_ids", get_pilot_job_ids_patch)
134+
135+
# A pilot that has is not associated with a job can't access a job
136+
with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}") as excinfo:
137+
await WMSAccessPolicy.policy(
138+
WMS_POLICY_NAME,
139+
pilot,
140+
action=ActionType.PILOT,
141+
pilot_db=pilot_db,
142+
job_db=job_db,
143+
job_ids=[1, 2],
144+
)
145+
146+
assert (
147+
str(excinfo.value) == "403: " + "this pilot can't access/modify this job"
148+
), excinfo
149+
150+
# ------------------------- Pilot accessing some of his jobs -------------------------
151+
async def get_pilot_job_ids_patch(*args, **kwargs):
152+
return [1, 2, 3, 4]
153+
154+
monkeypatch.setattr(pilot_db, "get_pilot_job_ids", get_pilot_job_ids_patch)
155+
156+
# A pilot that is associated with a job can access a job
157+
await WMSAccessPolicy.policy(
158+
WMS_POLICY_NAME,
159+
pilot,
160+
action=ActionType.PILOT,
161+
pilot_db=pilot_db,
162+
job_db=job_db,
163+
job_ids=[1, 2],
164+
)
165+
166+
# ------------------------- Pilot accessing some of his jobs plus some forbidden -------------------------
167+
async def get_pilot_job_ids_patch(*args, **kwargs):
168+
return [1, 2, 3, 4]
169+
170+
monkeypatch.setattr(pilot_db, "get_pilot_job_ids", get_pilot_job_ids_patch)
171+
172+
# A pilot that fetches few jobs, one where he does not have the rights, and few where he has the rights
173+
with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}") as excinfo:
174+
await WMSAccessPolicy.policy(
175+
WMS_POLICY_NAME,
176+
pilot,
177+
action=ActionType.PILOT,
178+
pilot_db=pilot_db,
179+
job_db=job_db,
180+
job_ids=[1, 2, 12],
181+
)
182+
183+
assert (
184+
str(excinfo.value) == "403: " + "this pilot can't access/modify this job"
185+
), excinfo
186+
187+
72188
async def test_wms_access_policy_create(job_db):
73189

74190
admin_user = AuthorizedUserInfo(properties=[JOB_ADMINISTRATOR], **base_payload)

diracx-testing/src/diracx/testing/utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,29 @@ def unauthenticated(self):
359359
with TestClient(self.app) as client:
360360
yield client
361361

362+
@contextlib.contextmanager
363+
def pilot(self):
364+
from diracx.core.properties import GENERIC_PILOT, LIMITED_DELEGATION
365+
from diracx.routers.auth.token import create_token
366+
367+
with self.unauthenticated() as client:
368+
payload = {
369+
"sub": "testingVO:yellow-sub",
370+
"exp": datetime.now(tz=timezone.utc)
371+
+ timedelta(self.test_auth_settings.access_token_expire_minutes),
372+
"iss": ISSUER,
373+
"dirac_properties": [GENERIC_PILOT, LIMITED_DELEGATION],
374+
"jti": str(uuid4()),
375+
"preferred_username": "preferred_username",
376+
"dirac_group": "test_group",
377+
"vo": "lhcb",
378+
}
379+
token = create_token(payload, self.test_auth_settings)
380+
381+
client.headers["Authorization"] = f"Bearer {token}"
382+
client.dirac_token_payload = payload
383+
yield client
384+
362385
@contextlib.contextmanager
363386
def normal_user(self):
364387
from diracx.core.properties import NORMAL_USER

0 commit comments

Comments
 (0)