Skip to content

Commit 6909c7c

Browse files
feat: Better handling of refresh tokens for pilots
1 parent 2097e10 commit 6909c7c

7 files changed

Lines changed: 157 additions & 78 deletions

File tree

diracx-db/src/diracx/db/sql/pilot_agents/db.py

Lines changed: 15 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

33
from datetime import datetime, timezone
4+
from typing import Sequence
45

5-
from sqlalchemy import DateTime, insert, select, update
6+
from sqlalchemy import insert, select, update
67
from sqlalchemy.exc import IntegrityError, NoResultFound
78

89
from diracx.core.exceptions import (
@@ -26,7 +27,7 @@ async def add_pilot_references(
2627
vo: str,
2728
grid_type: str = "DIRAC",
2829
pilot_stamps: dict | None = None,
29-
) -> None:
30+
) -> Sequence: # Return a list of primary keys
3031

3132
if pilot_stamps is None:
3233
pilot_stamps = {}
@@ -47,10 +48,18 @@ async def add_pilot_references(
4748
for ref in pilot_ref
4849
]
4950

50-
# Insert multiple rows in a single execute call
51-
stmt = insert(PilotAgents).values(values)
52-
await self.conn.execute(stmt)
53-
return
51+
# Insert multiple rows in a single execute call and use 'returning' to get primary keys
52+
stmt = (
53+
insert(PilotAgents).values(values).returning(PilotAgents.pilot_id)
54+
) # Assuming 'id' is the primary key
55+
result = await self.conn.execute(stmt)
56+
57+
# Use .scalars() and .all() to get the primary keys directly in a list
58+
primary_keys = (
59+
result.scalars().all()
60+
) # This returns a flat list of primary keys
61+
62+
return primary_keys
5463

5564
async def increment_pilot_secret_use(
5665
self,
@@ -100,32 +109,6 @@ async def verify_pilot_secret(
100109
# Increment the count
101110
await self.increment_pilot_secret_use(pilot_id=pilot_id)
102111

103-
async def register_new_pilot(
104-
self,
105-
vo: str,
106-
pilot_job_reference: str,
107-
pilot_stamp: str,
108-
grid_type: str = "DIRAC",
109-
submission_time: DateTime | None = None, # ?
110-
last_update_time: DateTime | None = None, # = now?
111-
) -> int | None:
112-
stmt = insert(PilotAgents).values(
113-
vo=vo,
114-
submission_time=submission_time,
115-
last_update_time=last_update_time,
116-
pilot_job_reference=pilot_job_reference,
117-
grid_type=grid_type,
118-
pilot_stamp=pilot_stamp,
119-
)
120-
121-
# Execute the request
122-
res = await self.conn.execute(stmt)
123-
124-
new_pilot_id = res.inserted_primary_key
125-
126-
# Returns the new pilot ID
127-
return int(new_pilot_id[0]) if new_pilot_id else None
128-
129112
async def add_pilot_credentials(self, pilot_id: int, pilot_hashed_secret: str):
130113

131114
stmt = insert(PilotRegistrations).values(

diracx-db/tests/pilot_agents/test_pilot_agents_db.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,9 @@ async def test_insert_and_select_single(pilot_agents_db: PilotAgentsDB):
3838

3939
async with pilot_agents_db as pilot_agents_db:
4040
pilot_reference = "pilot-reference-test"
41-
await pilot_agents_db.register_new_pilot(
42-
vo="pilot-vo",
43-
pilot_job_reference=pilot_reference,
44-
pilot_stamp="pilot-stamp",
41+
await pilot_agents_db.add_pilot_references(
42+
vo="lhcb",
43+
pilot_ref=[pilot_reference],
4544
grid_type="grid-type",
4645
)
4746

@@ -51,26 +50,26 @@ async def test_insert_and_select_single(pilot_agents_db: PilotAgentsDB):
5150
await pilot_agents_db.get_pilot_by_reference("I am a fake ref")
5251

5352
# Set values
54-
assert res["VO"] == "pilot-vo"
53+
assert res["VO"] == "lhcb"
5554
assert res["PilotJobReference"] == pilot_reference
56-
assert res["PilotStamp"] == "pilot-stamp"
5755
assert res["GridType"] == "grid-type"
5856

59-
# Default values
60-
assert res["BenchMark"] == 0.0
61-
assert res["Status"] == "Unknown"
62-
6357

6458
async def test_create_pilot_and_verify_secret(pilot_agents_db: PilotAgentsDB):
6559

6660
async with pilot_agents_db as pilot_agents_db:
6761
pilot_reference = "pilot-reference-test"
68-
pilot_id = await pilot_agents_db.register_new_pilot(
69-
vo="pilot-vo",
70-
pilot_job_reference=pilot_reference,
71-
pilot_stamp="pilot-stamp",
62+
pilot_ids = await pilot_agents_db.add_pilot_references(
63+
vo="lhcb",
64+
pilot_ref=[pilot_reference],
65+
grid_type="grid-type",
7266
)
7367

68+
assert len(pilot_ids) == 1
69+
70+
# Only one element
71+
pilot_id = pilot_ids[0]
72+
7473
secret = "AW0nd3rfulS3cr3t"
7574
pilot_hashed_secret = hash(secret)
7675

diracx-logic/src/diracx/logic/auth/token.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ async def get_oidc_token(
7373
oidc_token_info,
7474
scope,
7575
legacy_exchange,
76-
) = await get_oidc_token_info_from_refresh_flow(
77-
refresh_token, auth_db, settings
78-
)
76+
) = await get_token_info_from_refresh_flow(refresh_token, auth_db, settings)
7977
else:
8078
raise NotImplementedError(f"Grant type not implemented {grant_type}")
8179

@@ -155,7 +153,7 @@ async def get_oidc_token_info_from_authorization_flow(
155153
return (oidc_token_info, scope)
156154

157155

158-
async def get_oidc_token_info_from_refresh_flow(
156+
async def get_token_info_from_refresh_flow(
159157
refresh_token: str, auth_db: AuthDB, settings: AuthSettings
160158
) -> tuple[dict, str, bool]:
161159
"""Get OIDC token information from the refresh token DB and check few parameters before returning it."""
@@ -310,7 +308,7 @@ async def exchange_token(
310308
)
311309

312310
else:
313-
preferred_username = oidc_token_info["pilot_reference"]
311+
preferred_username = oidc_token_info["preferred_username"]
314312

315313
# Merge the VO with the subject to get a unique DIRAC sub
316314
sub = f"{vo}:{sub}"
@@ -361,20 +359,32 @@ async def generate_pilot_tokens(
361359
config: Config,
362360
settings: AuthSettings,
363361
available_properties: set[SecurityProperty],
362+
refresh_token: str | None = None,
364363
) -> tuple[AccessTokenPayload, RefreshTokenPayload]:
365364

366-
pilot = await get_pilot_informations_by_reference(
367-
pilot_db=pilot_db, pilot_job_reference=pilot_job_reference
368-
)
365+
scope = None
366+
pilot_info = None
369367

370-
pilot_info = {
371-
"pilot_reference": pilot["PilotJobReference"],
372-
"sub": pilot["PilotJobReference"],
373-
}
368+
if refresh_token is not None:
369+
pilot_info, scope, _ = await get_token_info_from_refresh_flow(
370+
refresh_token=refresh_token, auth_db=auth_db, settings=settings
371+
)
372+
else:
373+
374+
pilot = await get_pilot_informations_by_reference(
375+
pilot_db=pilot_db, pilot_job_reference=pilot_job_reference
376+
)
377+
378+
pilot_info = {
379+
"preferred_username": pilot["PilotJobReference"],
380+
"sub": pilot["PilotJobReference"],
381+
}
382+
383+
scope = generate_pilot_scope(pilot)
374384

375385
access_token, refresh_token = await exchange_token(
376386
auth_db=auth_db,
377-
scope=generate_pilot_scope(pilot),
387+
scope=scope,
378388
oidc_token_info=pilot_info,
379389
config=config,
380390
settings=settings,

diracx-routers/src/diracx/routers/auth/pilot.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,21 @@
11
from __future__ import annotations
22

3-
from fastapi import HTTPException, status
3+
from typing import Annotated
44

5-
from diracx.core.exceptions import AuthorizationError, PilotNotFoundError
5+
from fastapi import (
6+
Depends,
7+
HTTPException,
8+
status,
9+
)
10+
11+
from diracx.core.exceptions import (
12+
AuthorizationError,
13+
InvalidCredentialsError,
14+
PilotNotFoundError,
15+
)
616
from diracx.logic.auth.pilot import try_login
717
from diracx.logic.auth.token import create_token, generate_pilot_tokens
18+
from diracx.routers.pilots.access_policies import RegisteredPilotAccessPolicyCallable
819

920
from ..dependencies import (
1021
AuthDB,
@@ -14,6 +25,7 @@
1425
PilotAgentsDB,
1526
)
1627
from ..fastapi_classes import DiracxRouter
28+
from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token
1729

1830
router = DiracxRouter(require_auth=False)
1931

@@ -63,3 +75,42 @@ async def pilot_login(
6375
"access_token": create_token(access_token, settings),
6476
"refresh_token": create_token(refresh_token, settings),
6577
}
78+
79+
80+
@router.post("/pilot-refresh-token")
81+
async def refresh_pilot_tokens(
82+
pilot_db: PilotAgentsDB,
83+
auth_db: AuthDB,
84+
config: Config,
85+
settings: AuthSettings,
86+
available_properties: AvailableSecurityProperties,
87+
check_permissions: RegisteredPilotAccessPolicyCallable,
88+
refresh_token: str,
89+
pilot_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
90+
):
91+
"""Endpoint where a pilot can exchange a refresh token against a token."""
92+
await check_permissions()
93+
94+
try:
95+
new_access_token, new_refresh_token = await generate_pilot_tokens(
96+
pilot_db=pilot_db,
97+
auth_db=auth_db,
98+
pilot_job_reference=pilot_info.preferred_username,
99+
config=config,
100+
settings=settings,
101+
available_properties=available_properties,
102+
refresh_token=refresh_token,
103+
)
104+
except InvalidCredentialsError as e:
105+
raise HTTPException(
106+
status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e)
107+
) from e
108+
except ValueError as e:
109+
raise HTTPException(
110+
status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)
111+
) from e
112+
113+
return {
114+
"access_token": create_token(new_access_token, settings),
115+
"refresh_token": create_token(new_refresh_token, settings),
116+
}

diracx-routers/src/diracx/routers/pilots/access_policies.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
from collections.abc import Callable
4-
from enum import StrEnum, auto
54
from typing import Annotated
65

76
from fastapi import Depends, HTTPException, status
@@ -13,18 +12,6 @@
1312
from diracx.routers.utils.users import AuthorizedUserInfo
1413

1514

16-
class ActionType(StrEnum):
17-
#: Create a job or a sandbox
18-
CREATE = auto()
19-
#: Check job status, download a sandbox
20-
READ = auto()
21-
#: delete, kill, remove, set status, etc of a job
22-
#: delete or assign a sandbox
23-
MANAGE = auto()
24-
#: Search
25-
QUERY = auto()
26-
27-
2815
class RegisteredPilotAccessPolicy(BaseAccessPolicy):
2916

3017
@staticmethod

diracx-routers/src/diracx/routers/pilots/debug.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
@router.get("/info")
1717
async def get_pilot_info(
1818
check_permissions: RegisteredPilotAccessPolicyCallable,
19-
user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
19+
pilot_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
2020
):
2121
await check_permissions()
2222

23-
return user_info
23+
return pilot_info

diracx-routers/tests/auth/test_pilot_auth.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,17 @@ async def test_create_pilot_and_verify_secret(test_client):
4343

4444
async with db as pilot_agents_db:
4545
# Register a pilot
46-
pilot_id = await pilot_agents_db.register_new_pilot(
47-
vo=pilot_vo, pilot_job_reference=pilot_reference, pilot_stamp="pilot-stamp"
46+
pilot_ids = await pilot_agents_db.add_pilot_references(
47+
vo=pilot_vo,
48+
pilot_ref=[pilot_reference],
49+
grid_type="grid-type",
4850
)
4951

52+
assert len(pilot_ids) == 1
53+
54+
# Only one element
55+
pilot_id = pilot_ids[0]
56+
5057
# Add credentials to this pilot
5158
await pilot_agents_db.add_pilot_credentials(
5259
pilot_id=pilot_id, pilot_hashed_secret=pilot_hashed_secret
@@ -116,3 +123,45 @@ async def test_create_pilot_and_verify_secret(test_client):
116123

117124
assert r.status_code == 401
118125
assert r.json()["detail"] == "bad pilot_id / pilot_secret"
126+
127+
# ----------------- Exchange for new tokens -----------------
128+
request_data = {"refresh_token": refresh_token}
129+
r = test_client.post(
130+
"/api/auth/pilot-refresh-token",
131+
params=request_data,
132+
headers={"Authorization": f"Bearer {access_token}"},
133+
)
134+
135+
assert r.status_code == 200
136+
137+
new_access_token = r.json()["access_token"]
138+
new_refresh_token = r.json()["refresh_token"]
139+
140+
# ----------------- Get info with new token -----------------
141+
r = test_client.get(
142+
"/api/pilots/info", headers={"Authorization": f"Bearer {new_access_token}"}
143+
)
144+
145+
assert r.status_code == 200
146+
147+
# ----------------- Exchange token with old token -----------------
148+
request_data = {"refresh_token": refresh_token}
149+
r = test_client.post(
150+
"/api/auth/pilot-refresh-token",
151+
params=request_data,
152+
headers={"Authorization": f"Bearer {access_token}"},
153+
)
154+
155+
assert r.status_code == 401, r.json()
156+
157+
# ----------------- Exchange token with new token -----------------
158+
request_data = {"refresh_token": new_refresh_token}
159+
r = test_client.post(
160+
"/api/auth/pilot-refresh-token",
161+
params=request_data,
162+
headers={"Authorization": f"Bearer {new_access_token}"},
163+
)
164+
165+
# RFC6749
166+
# https://datatracker.ietf.org/doc/html/rfc6749#section-10.4
167+
assert r.status_code == 401, r.json()

0 commit comments

Comments
 (0)