Skip to content

Commit 8217650

Browse files
author
Robin Van de Merghel
committed
feat: Adding pilot registrations
1 parent 4d91238 commit 8217650

14 files changed

Lines changed: 641 additions & 30 deletions

File tree

diracx-core/src/diracx/core/exceptions.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,41 @@ def __init__(self, job_id, detail: str | None = None):
101101
super().__init__(
102102
f"Error concerning job {job_id}" + (": {detail} " if detail else "")
103103
)
104+
105+
106+
class PilotNotFoundError(Exception):
107+
def __init__(
108+
self,
109+
pilot_ref: str | None = None,
110+
pilot_id: int | None = None,
111+
detail: str | None = None,
112+
):
113+
self.pilot_ref = pilot_ref
114+
self.pilot_id = pilot_id
115+
self.detail = detail
116+
super().__init__(
117+
"Pilot "
118+
+ (f"(Ref: {pilot_ref})" if pilot_ref else "")
119+
+ (f" (ID: {str(pilot_id)})" if pilot_id is not None else "")
120+
+ " not found"
121+
+ (f": {detail}" if detail else "")
122+
)
123+
124+
125+
class PilotAlreadyExistsError(Exception):
126+
def __init__(
127+
self,
128+
pilot_ref: str | None = None, # Changed to str based on the format
129+
pilot_id: int | None = None,
130+
detail: str | None = None,
131+
):
132+
self.pilot_ref = pilot_ref
133+
self.pilot_id = pilot_id
134+
self.detail = detail
135+
super().__init__(
136+
"Pilot "
137+
+ (f"(Ref: {pilot_ref})" if pilot_ref else "")
138+
+ (f" (ID: {str(pilot_id)})" if pilot_id is not None else "")
139+
+ " already exists"
140+
+ (f": {detail}" if detail else "")
141+
)

diracx-db/src/diracx/db/sql/pilot_agents/db.py

Lines changed: 118 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,17 @@
22

33
from datetime import datetime, timezone
44

5-
from sqlalchemy import insert
5+
from sqlalchemy import DateTime, insert, select, update
6+
from sqlalchemy.exc import IntegrityError, NoResultFound
7+
8+
from diracx.core.exceptions import (
9+
AuthorizationError,
10+
PilotAlreadyExistsError,
11+
PilotNotFoundError,
12+
)
613

714
from ..utils import BaseSQLDB
8-
from .schema import PilotAgents, PilotAgentsDBBase
15+
from .schema import PilotAgents, PilotAgentsDBBase, PilotRegistrations
916

1017

1118
class PilotAgentsDB(BaseSQLDB):
@@ -44,3 +51,112 @@ async def add_pilot_references(
4451
stmt = insert(PilotAgents).values(values)
4552
await self.conn.execute(stmt)
4653
return
54+
55+
async def increment_pilot_secret_use(
56+
self,
57+
pilot_id: int,
58+
) -> None:
59+
60+
# Prepare the update statement
61+
stmt = (
62+
update(PilotRegistrations)
63+
.values(
64+
pilot_secret_use_count=PilotRegistrations.pilot_secret_use_count + 1
65+
)
66+
.where(PilotRegistrations.pilot_id == pilot_id)
67+
)
68+
69+
# Execute the update using the connection
70+
res = await self.conn.execute(stmt)
71+
72+
if res.rowcount == 0:
73+
raise PilotNotFoundError(pilot_id=pilot_id)
74+
75+
async def verify_pilot_secret(
76+
self, pilot_job_reference: str, pilot_hashed_secret: str
77+
) -> None:
78+
79+
try:
80+
pilot = await self.get_pilot_by_reference(pilot_job_reference)
81+
except NoResultFound as e:
82+
raise PilotNotFoundError(pilot_ref=pilot_job_reference) from e
83+
84+
pilot_id = pilot["PilotID"]
85+
86+
stmt = (
87+
select(PilotRegistrations)
88+
.where(PilotRegistrations.pilot_hashed_secret == pilot_hashed_secret)
89+
.where(PilotRegistrations.pilot_id == pilot_id)
90+
)
91+
92+
# Execute the request
93+
res = await self.conn.execute(stmt)
94+
95+
result = res.fetchone()
96+
97+
if result is None:
98+
raise AuthorizationError(detail="bad pilot_id / pilot_secret")
99+
100+
# Increment the count
101+
await self.increment_pilot_secret_use(pilot_id=pilot_id)
102+
103+
async def register_new_pilot(
104+
self,
105+
vo: str,
106+
pilot_job_reference: str,
107+
pilot_stamp: str,
108+
grid_type: str = "DIRAC",
109+
submission_time: DateTime | None = None, # ?
110+
last_update_time: DateTime | None = None, # = now?
111+
) -> int | None:
112+
stmt = insert(PilotAgents).values(
113+
vo=vo,
114+
submission_time=submission_time,
115+
last_update_time=last_update_time,
116+
pilot_job_reference=pilot_job_reference,
117+
grid_type=grid_type,
118+
pilot_stamp=pilot_stamp,
119+
)
120+
121+
# Execute the request
122+
res = await self.conn.execute(stmt)
123+
124+
new_pilot_id = res.inserted_primary_key
125+
126+
# Returns the new pilot ID
127+
return int(new_pilot_id[0]) if new_pilot_id else None
128+
129+
async def add_pilot_credentials(self, pilot_id: int, pilot_hashed_secret: str):
130+
131+
stmt = insert(PilotRegistrations).values(
132+
pilot_id=pilot_id, pilot_hashed_secret=pilot_hashed_secret
133+
)
134+
135+
try:
136+
await self.conn.execute(stmt)
137+
except IntegrityError as e:
138+
if "foreign key" in str(e.orig).lower():
139+
raise PilotNotFoundError(pilot_id=pilot_id) from e
140+
if "duplicate entry" in str(e.orig).lower():
141+
raise PilotAlreadyExistsError(
142+
pilot_id=pilot_id, detail="this pilot has already credentials"
143+
) from e
144+
145+
async def fetch_all_pilots(self):
146+
stmt = select(PilotRegistrations).with_for_update()
147+
result = await self.conn.execute(stmt)
148+
149+
# Convert results into a dictionary
150+
pilots = [dict(row._mapping) for row in result]
151+
152+
return pilots
153+
154+
async def get_pilot_by_reference(self, pilot_ref: str):
155+
stmt = (
156+
select(PilotAgents)
157+
.with_for_update()
158+
.where(PilotAgents.pilot_job_reference == pilot_ref)
159+
)
160+
161+
# We assume it is unique...
162+
return dict((await self.conn.execute(stmt)).one()._mapping)

diracx-db/src/diracx/db/sql/pilot_agents/schema.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,6 @@
11
from __future__ import annotations
22

3-
from sqlalchemy import (
4-
DateTime,
5-
Double,
6-
Index,
7-
Integer,
8-
String,
9-
Text,
10-
)
3+
from sqlalchemy import DateTime, Double, ForeignKey, Index, Integer, String, Text
114
from sqlalchemy.orm import declarative_base
125

136
from ..utils import Column, EnumBackedBool, NullColumn
@@ -58,3 +51,16 @@ class PilotOutput(PilotAgentsDBBase):
5851
pilot_id = Column("PilotID", Integer, primary_key=True)
5952
std_output = Column("StdOutput", Text)
6053
std_error = Column("StdError", Text)
54+
55+
56+
class PilotRegistrations(PilotAgentsDBBase):
57+
__tablename__ = "PilotRegistrations"
58+
59+
pilot_id = Column(
60+
"PilotID",
61+
Integer,
62+
ForeignKey("PilotAgents.PilotID", ondelete="CASCADE"),
63+
primary_key=True,
64+
)
65+
pilot_hashed_secret = Column("PilotHashedSecret", String(64))
66+
pilot_secret_use_count = Column("PilotSecretUseCount", Integer, default=0)

diracx-db/tests/pilot_agents/test_pilot_agents_db.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from __future__ import annotations
22

33
import pytest
4+
from sqlalchemy.exc import NoResultFound
45

6+
from diracx.core.exceptions import AuthorizationError
57
from diracx.db.sql.pilot_agents.db import PilotAgentsDB
8+
from diracx.db.sql.utils.functions import hash
69

710

811
@pytest.fixture
9-
async def pilot_agents_db(tmp_path) -> PilotAgentsDB:
12+
async def pilot_agents_db(tmp_path):
1013
agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:")
1114
async with agents_db.engine_context():
1215
async with agents_db.engine.begin() as conn:
@@ -29,3 +32,66 @@ async def test_insert_and_select(pilot_agents_db: PilotAgentsDB):
2932
await pilot_agents_db.add_pilot_references(
3033
refs, "test_vo", grid_type="DIRAC", pilot_stamps=None
3134
)
35+
36+
37+
async def test_insert_and_select_single(pilot_agents_db: PilotAgentsDB):
38+
39+
async with pilot_agents_db as pilot_agents_db:
40+
pilot_reference = "pilot-reference-test"
41+
await pilot_agents_db.register_new_pilot(
42+
vo="pilot-vo",
43+
pilot_job_reference=pilot_reference,
44+
pilot_stamp="pilot-stamp",
45+
grid_type="grid-type",
46+
)
47+
48+
res = await pilot_agents_db.get_pilot_by_reference(pilot_ref=pilot_reference)
49+
50+
with pytest.raises(NoResultFound):
51+
await pilot_agents_db.get_pilot_by_reference("I am a fake ref")
52+
53+
# Set values
54+
assert res["VO"] == "pilot-vo"
55+
assert res["PilotJobReference"] == pilot_reference
56+
assert res["PilotStamp"] == "pilot-stamp"
57+
assert res["GridType"] == "grid-type"
58+
59+
# Default values
60+
assert res["BenchMark"] == 0.0
61+
assert res["Status"] == "Unknown"
62+
63+
64+
async def test_create_pilot_and_verify_secret(pilot_agents_db: PilotAgentsDB):
65+
66+
async with pilot_agents_db as pilot_agents_db:
67+
pilot_reference = "pilot-reference-test"
68+
pilot_id = await pilot_agents_db.register_new_pilot(
69+
vo="pilot-vo",
70+
pilot_job_reference=pilot_reference,
71+
pilot_stamp="pilot-stamp",
72+
)
73+
74+
secret = "AW0nd3rfulS3cr3t"
75+
pilot_hashed_secret = hash(secret)
76+
77+
# Add creds
78+
await pilot_agents_db.add_pilot_credentials(
79+
pilot_id=pilot_id, pilot_hashed_secret=pilot_hashed_secret
80+
)
81+
82+
assert secret is not None
83+
84+
await pilot_agents_db.verify_pilot_secret(
85+
pilot_job_reference=pilot_reference, pilot_hashed_secret=pilot_hashed_secret
86+
)
87+
88+
with pytest.raises(AuthorizationError):
89+
await pilot_agents_db.verify_pilot_secret(
90+
pilot_job_reference=pilot_reference,
91+
pilot_hashed_secret=hash("I love stawberries :)"),
92+
)
93+
94+
await pilot_agents_db.verify_pilot_secret(
95+
pilot_job_reference="I am a spider",
96+
pilot_hashed_secret=pilot_hashed_secret,
97+
)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from __future__ import annotations
2+
3+
from secrets import token_hex
4+
5+
from diracx.db.sql import PilotAgentsDB
6+
7+
# TODO: Move this hash function in diracx-logic, and rename it
8+
from diracx.db.sql.utils.functions import hash
9+
10+
11+
def generate_pilot_secret() -> str:
12+
# Can change with time
13+
return token_hex(32)
14+
15+
16+
async def add_pilot_credentials(pilot_id: int, pilot_db: PilotAgentsDB) -> str:
17+
18+
# Get a random string
19+
# Can be customized
20+
random_secret = generate_pilot_secret()
21+
22+
hashed_secret = hash(random_secret)
23+
24+
await pilot_db.add_pilot_credentials(
25+
pilot_id=pilot_id, pilot_hashed_secret=hashed_secret
26+
)
27+
28+
return random_secret
29+
30+
31+
def generate_pilot_scope(pilot: dict) -> str:
32+
return f"vo:{pilot['VO']}"
33+
34+
35+
async def try_login(
36+
pilot_reference: str, pilot_db: PilotAgentsDB, pilot_secret: str
37+
) -> None:
38+
39+
hashed_secret = hash(pilot_secret)
40+
41+
await pilot_db.verify_pilot_secret(
42+
pilot_hashed_secret=hashed_secret, pilot_job_reference=pilot_reference
43+
)

0 commit comments

Comments
 (0)