55import pytest
66from fastapi import HTTPException , status
77
8- from diracx .core .properties import JOB_ADMINISTRATOR , NORMAL_USER
8+ from diracx .core .properties import GENERIC_PILOT , JOB_ADMINISTRATOR , NORMAL_USER
99from diracx .routers .jobs .access_policies import (
1010 ActionType ,
1111 SandboxAccessPolicy ,
@@ -27,6 +27,11 @@ class FakeJobDB:
2727 async def summary (self , * args ): ...
2828
2929
30+ class FakePilotDB :
31+ async def get_pilot_by_reference (self , * args ): ...
32+ async def get_pilot_job_ids (self , * args ): ...
33+
34+
3035class FakeSBMetadataDB :
3136 async def get_owner_id (self , * args ): ...
3237 async def get_sandbox_owner_id (self , * args ): ...
@@ -37,6 +42,11 @@ def job_db():
3742 yield FakeJobDB ()
3843
3944
45+ @pytest .fixture
46+ def pilot_db ():
47+ yield FakePilotDB ()
48+
49+
4050@pytest .fixture
4151def sandbox_metadata_db ():
4252 yield FakeSBMetadataDB ()
@@ -69,6 +79,112 @@ async def test_wms_access_policy_weird_user(job_db):
6979 )
7080
7181
82+ async def test_wms_access_policy_pilot (job_db , pilot_db , monkeypatch ):
83+
84+ normal_user = AuthorizedUserInfo (properties = [NORMAL_USER ], ** base_payload )
85+ pilot = AuthorizedUserInfo (properties = [GENERIC_PILOT ], ** base_payload )
86+
87+ # ------------------------- Simple User accessing a pilot action -------------------------
88+ # A user cannot create any resource
89+ with pytest .raises (HTTPException , match = f"{ status .HTTP_403_FORBIDDEN } " ) as excinfo :
90+ await WMSAccessPolicy .policy (
91+ WMS_POLICY_NAME ,
92+ normal_user ,
93+ action = ActionType .PILOT ,
94+ job_db = job_db ,
95+ pilot_db = pilot_db ,
96+ job_ids = [1 , 2 ],
97+ )
98+
99+ # Split to distinguish the generated part ("403 ") from the message part ("you are not a pilot")
100+ assert str (excinfo .value ) == "403: " + "you are not a pilot" , excinfo
101+
102+ # ------------------------- Lost pilot -------------------------
103+ async def get_pilot_by_reference_patch (* args ):
104+ return []
105+
106+ monkeypatch .setattr (
107+ pilot_db , "get_pilot_by_reference" , get_pilot_by_reference_patch
108+ )
109+
110+ # A pilot that has expired (removed from db) should not be able to access jobs
111+ with pytest .raises (HTTPException , match = f"{ status .HTTP_403_FORBIDDEN } " ) as excinfo :
112+ await WMSAccessPolicy .policy (
113+ WMS_POLICY_NAME ,
114+ pilot ,
115+ action = ActionType .PILOT ,
116+ pilot_db = pilot_db ,
117+ job_db = job_db ,
118+ job_ids = [1 , 2 ],
119+ )
120+
121+ assert str (excinfo .value ) == "403: " + "this pilot is not registered" , excinfo
122+
123+ # ------------------------- Pilot accessing wrong jobs -------------------------
124+ async def get_pilot_by_reference_patch (* args , ** kwargs ):
125+ return {"PilotID" : 1 }
126+
127+ async def get_pilot_job_ids_patch (* args , ** kwargs ):
128+ return []
129+
130+ monkeypatch .setattr (
131+ pilot_db , "get_pilot_by_reference" , get_pilot_by_reference_patch
132+ )
133+ monkeypatch .setattr (pilot_db , "get_pilot_job_ids" , get_pilot_job_ids_patch )
134+
135+ # A pilot that has is not associated with a job can't access a job
136+ with pytest .raises (HTTPException , match = f"{ status .HTTP_403_FORBIDDEN } " ) as excinfo :
137+ await WMSAccessPolicy .policy (
138+ WMS_POLICY_NAME ,
139+ pilot ,
140+ action = ActionType .PILOT ,
141+ pilot_db = pilot_db ,
142+ job_db = job_db ,
143+ job_ids = [1 , 2 ],
144+ )
145+
146+ assert (
147+ str (excinfo .value ) == "403: " + "this pilot can't access/modify this job"
148+ ), excinfo
149+
150+ # ------------------------- Pilot accessing some of his jobs -------------------------
151+ async def get_pilot_job_ids_patch (* args , ** kwargs ):
152+ return [1 , 2 , 3 , 4 ]
153+
154+ monkeypatch .setattr (pilot_db , "get_pilot_job_ids" , get_pilot_job_ids_patch )
155+
156+ # A pilot that is associated with a job can access a job
157+ await WMSAccessPolicy .policy (
158+ WMS_POLICY_NAME ,
159+ pilot ,
160+ action = ActionType .PILOT ,
161+ pilot_db = pilot_db ,
162+ job_db = job_db ,
163+ job_ids = [1 , 2 ],
164+ )
165+
166+ # ------------------------- Pilot accessing some of his jobs plus some forbidden -------------------------
167+ async def get_pilot_job_ids_patch (* args , ** kwargs ):
168+ return [1 , 2 , 3 , 4 ]
169+
170+ monkeypatch .setattr (pilot_db , "get_pilot_job_ids" , get_pilot_job_ids_patch )
171+
172+ # A pilot that fetches few jobs, one where he does not have the rights, and few where he has the rights
173+ with pytest .raises (HTTPException , match = f"{ status .HTTP_403_FORBIDDEN } " ) as excinfo :
174+ await WMSAccessPolicy .policy (
175+ WMS_POLICY_NAME ,
176+ pilot ,
177+ action = ActionType .PILOT ,
178+ pilot_db = pilot_db ,
179+ job_db = job_db ,
180+ job_ids = [1 , 2 , 12 ],
181+ )
182+
183+ assert (
184+ str (excinfo .value ) == "403: " + "this pilot can't access/modify this job"
185+ ), excinfo
186+
187+
72188async def test_wms_access_policy_create (job_db ):
73189
74190 admin_user = AuthorizedUserInfo (properties = [JOB_ADMINISTRATOR ], ** base_payload )
0 commit comments