|
2 | 2 |
|
3 | 3 | from datetime import datetime, timezone |
4 | 4 |
|
5 | | -from sqlalchemy import insert |
| 5 | +from sqlalchemy import DateTime, insert, select, update |
| 6 | +from sqlalchemy.exc import IntegrityError, NoResultFound |
| 7 | + |
| 8 | +from diracx.core.exceptions import ( |
| 9 | + AuthorizationError, |
| 10 | + PilotAlreadyExistsError, |
| 11 | + PilotNotFoundError, |
| 12 | +) |
6 | 13 |
|
7 | 14 | from ..utils import BaseSQLDB |
8 | | -from .schema import PilotAgents, PilotAgentsDBBase |
| 15 | +from .schema import PilotAgents, PilotAgentsDBBase, PilotRegistrations |
9 | 16 |
|
10 | 17 |
|
11 | 18 | class PilotAgentsDB(BaseSQLDB): |
@@ -44,3 +51,112 @@ async def add_pilot_references( |
44 | 51 | stmt = insert(PilotAgents).values(values) |
45 | 52 | await self.conn.execute(stmt) |
46 | 53 | return |
| 54 | + |
| 55 | + async def increment_pilot_secret_use( |
| 56 | + self, |
| 57 | + pilot_id: int, |
| 58 | + ) -> None: |
| 59 | + |
| 60 | + # Prepare the update statement |
| 61 | + stmt = ( |
| 62 | + update(PilotRegistrations) |
| 63 | + .values( |
| 64 | + pilot_secret_use_count=PilotRegistrations.pilot_secret_use_count + 1 |
| 65 | + ) |
| 66 | + .where(PilotRegistrations.pilot_id == pilot_id) |
| 67 | + ) |
| 68 | + |
| 69 | + # Execute the update using the connection |
| 70 | + res = await self.conn.execute(stmt) |
| 71 | + |
| 72 | + if res.rowcount == 0: |
| 73 | + raise PilotNotFoundError(pilot_id=pilot_id) |
| 74 | + |
| 75 | + async def verify_pilot_secret( |
| 76 | + self, pilot_job_reference: str, pilot_hashed_secret: str |
| 77 | + ) -> None: |
| 78 | + |
| 79 | + try: |
| 80 | + pilot = await self.get_pilot_by_reference(pilot_job_reference) |
| 81 | + except NoResultFound as e: |
| 82 | + raise PilotNotFoundError(pilot_ref=pilot_job_reference) from e |
| 83 | + |
| 84 | + pilot_id = pilot["PilotID"] |
| 85 | + |
| 86 | + stmt = ( |
| 87 | + select(PilotRegistrations) |
| 88 | + .where(PilotRegistrations.pilot_hashed_secret == pilot_hashed_secret) |
| 89 | + .where(PilotRegistrations.pilot_id == pilot_id) |
| 90 | + ) |
| 91 | + |
| 92 | + # Execute the request |
| 93 | + res = await self.conn.execute(stmt) |
| 94 | + |
| 95 | + result = res.fetchone() |
| 96 | + |
| 97 | + if result is None: |
| 98 | + raise AuthorizationError(detail="bad pilot_id / pilot_secret") |
| 99 | + |
| 100 | + # Increment the count |
| 101 | + await self.increment_pilot_secret_use(pilot_id=pilot_id) |
| 102 | + |
| 103 | + async def register_new_pilot( |
| 104 | + self, |
| 105 | + vo: str, |
| 106 | + pilot_job_reference: str, |
| 107 | + pilot_stamp: str, |
| 108 | + grid_type: str = "DIRAC", |
| 109 | + submission_time: DateTime | None = None, # ? |
| 110 | + last_update_time: DateTime | None = None, # = now? |
| 111 | + ) -> int | None: |
| 112 | + stmt = insert(PilotAgents).values( |
| 113 | + vo=vo, |
| 114 | + submission_time=submission_time, |
| 115 | + last_update_time=last_update_time, |
| 116 | + pilot_job_reference=pilot_job_reference, |
| 117 | + grid_type=grid_type, |
| 118 | + pilot_stamp=pilot_stamp, |
| 119 | + ) |
| 120 | + |
| 121 | + # Execute the request |
| 122 | + res = await self.conn.execute(stmt) |
| 123 | + |
| 124 | + new_pilot_id = res.inserted_primary_key |
| 125 | + |
| 126 | + # Returns the new pilot ID |
| 127 | + return int(new_pilot_id[0]) if new_pilot_id else None |
| 128 | + |
| 129 | + async def add_pilot_credentials(self, pilot_id: int, pilot_hashed_secret: str): |
| 130 | + |
| 131 | + stmt = insert(PilotRegistrations).values( |
| 132 | + pilot_id=pilot_id, pilot_hashed_secret=pilot_hashed_secret |
| 133 | + ) |
| 134 | + |
| 135 | + try: |
| 136 | + await self.conn.execute(stmt) |
| 137 | + except IntegrityError as e: |
| 138 | + if "foreign key" in str(e.orig).lower(): |
| 139 | + raise PilotNotFoundError(pilot_id=pilot_id) from e |
| 140 | + if "duplicate entry" in str(e.orig).lower(): |
| 141 | + raise PilotAlreadyExistsError( |
| 142 | + pilot_id=pilot_id, detail="this pilot has already credentials" |
| 143 | + ) from e |
| 144 | + |
| 145 | + async def fetch_all_pilots(self): |
| 146 | + stmt = select(PilotRegistrations).with_for_update() |
| 147 | + result = await self.conn.execute(stmt) |
| 148 | + |
| 149 | + # Convert results into a dictionary |
| 150 | + pilots = [dict(row._mapping) for row in result] |
| 151 | + |
| 152 | + return pilots |
| 153 | + |
| 154 | + async def get_pilot_by_reference(self, pilot_ref: str): |
| 155 | + stmt = ( |
| 156 | + select(PilotAgents) |
| 157 | + .with_for_update() |
| 158 | + .where(PilotAgents.pilot_job_reference == pilot_ref) |
| 159 | + ) |
| 160 | + |
| 161 | + # We assume it is unique... |
| 162 | + return dict((await self.conn.execute(stmt)).one()._mapping) |
0 commit comments