77from fastapi import Depends , HTTPException , status
88
99from diracx .core .properties import SERVICE_ADMINISTRATOR
10+ from diracx .db .sql .job .db import JobDB
1011from diracx .db .sql .pilots .db import PilotAgentsDB
1112from diracx .logic .pilots .query import get_pilots_by_stamp
1213from diracx .routers .access_policies import BaseAccessPolicy
@@ -35,26 +36,50 @@ async def policy(
3536 action : ActionType | None = None ,
3637 pilot_db : PilotAgentsDB | None = None ,
3738 pilot_stamps : list [str ] | None = None ,
39+ job_db : JobDB | None = None ,
40+ job_ids : list [int ] | None = None ,
3841 ):
3942 assert action , "action is a mandatory parameter"
4043
4144 # Users can query
4245 # NOTE: Add into queries a VO constraint
43- if action == ActionType .READ_PILOT_FIELDS :
44- return
46+ # To manage pilots, user have to be an admin
47+ if (
48+ action == ActionType .MANAGE_PILOTS
49+ and SERVICE_ADMINISTRATOR not in user_info .properties
50+ ):
51+ raise HTTPException (
52+ status_code = status .HTTP_403_FORBIDDEN ,
53+ detail = "You don't have the permission to manage pilots." ,
54+ )
4555
46- # If we want to modify pilots, we allow only admins
47- # TODO: See if we add other types of admins
48- if SERVICE_ADMINISTRATOR in user_info .properties :
49- # If we don't provide pilot_db and pilot_stamps, we accept directly
50- # This is for example when we submit pilots, we use the user VO, so no need to verify
51- if not (pilot_db and pilot_stamps ):
52- return
56+ #
57+ # Additional checks if job_ids or pilot_stamps are provided
58+ #
5359
54- # Else, check its VO
55- assert pilot_db , "PilotDB is needed to determine pilot VO."
56- assert pilot_stamps , "PilotStamps are needed to determine pilot VO."
60+ # First, if job_ids are provided, we check who is the owner
61+ if job_db and job_ids :
62+ job_owners = await job_db .summary (
63+ ["Owner" , "VO" ],
64+ [{"parameter" : "JobID" , "operator" : "in" , "values" : job_ids }],
65+ )
66+
67+ expected_owner = {
68+ "Owner" : user_info .preferred_username ,
69+ "VO" : user_info .vo ,
70+ "count" : len (set (job_ids )),
71+ }
72+ # All the jobs belong to the user doing the query
73+ # and all of them are present
74+ if not job_owners == [expected_owner ]:
75+ raise HTTPException (
76+ status_code = status .HTTP_403_FORBIDDEN ,
77+ detail = "You don't have the rights to modify a pilot." ,
78+ )
5779
80+ # This is for example when we submit pilots, we use the user VO, so no need to verify
81+ if pilot_db and pilot_stamps :
82+ # Else, check its VO
5883 pilots = await get_pilots_by_stamp (
5984 pilot_db = pilot_db ,
6085 pilot_stamps = pilot_stamps ,
@@ -74,13 +99,6 @@ async def policy(
7499 detail = "You don't have access to all pilots." ,
75100 )
76101
77- return
78-
79- raise HTTPException (
80- status_code = status .HTTP_403_FORBIDDEN ,
81- detail = "You don't have the rights to modify a pilot." ,
82- )
83-
84102
85103CheckPilotManagementPolicyCallable = Annotated [
86104 Callable , Depends (PilotManagementAccessPolicy .check )
0 commit comments