Skip to content

Commit b34f865

Browse files
fix: Fixed models and reoranizing db functions
1 parent 64aece6 commit b34f865

File tree

2 files changed

+39
-41
lines changed
  • diracx-core/src/diracx/core
  • diracx-db/src/diracx/db/sql/pilots

2 files changed

+39
-41
lines changed

diracx-core/src/diracx/core/models.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -274,17 +274,7 @@ class JobCommand(BaseModel):
274274
arguments: str | None = None
275275

276276

277-
class PilotInfo(BaseModel):
278-
sub: str
279-
pilot_stamp: str
280-
vo: str
281-
282-
283-
class PilotStampInfo(BaseModel):
284-
pilot_stamp: str
285-
286-
287-
class PilotFieldsMapping(BaseModel):
277+
class PilotFieldsMapping(BaseModel, extra="forbid"):
288278
"""All the fields that a user can modify on a Pilot (except PilotStamp)."""
289279

290280
PilotStamp: str

diracx-db/src/diracx/db/sql/pilots/db.py

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ class PilotAgentsDB(BaseSQLDB):
3232

3333
metadata = PilotAgentsDBBase.metadata
3434

35+
# ----------------------------- Insert Functions -----------------------------
36+
3537
async def add_pilots_bulk(
3638
self,
3739
pilot_stamps: list[str],
@@ -67,18 +69,6 @@ async def add_pilots_bulk(
6769

6870
await self.conn.execute(stmt)
6971

70-
async def delete_pilots_by_stamps_bulk(self, pilot_stamps: list[str]):
71-
"""Bulk delete pilots.
72-
73-
Raises PilotNotFound if one of the pilot was not found.
74-
"""
75-
stmt = delete(PilotAgents).where(PilotAgents.pilot_stamp.in_(pilot_stamps))
76-
77-
res = await self.conn.execute(stmt)
78-
79-
if res.rowcount != len(pilot_stamps):
80-
raise PilotNotFoundError(data={"pilot_stamps": str(pilot_stamps)})
81-
8272
async def associate_pilot_with_jobs(
8373
self, job_to_pilot_mapping: list[dict[str, Any]]
8474
):
@@ -124,6 +114,40 @@ async def associate_pilot_with_jobs(
124114
"Engine Specific error not caught" + str(e)
125115
) from e
126116

117+
# ----------------------------- Delete Functions -----------------------------
118+
119+
async def delete_pilots_by_stamps_bulk(self, pilot_stamps: list[str]):
120+
"""Bulk delete pilots.
121+
122+
Raises PilotNotFound if one of the pilot was not found.
123+
"""
124+
stmt = delete(PilotAgents).where(PilotAgents.pilot_stamp.in_(pilot_stamps))
125+
126+
res = await self.conn.execute(stmt)
127+
128+
if res.rowcount != len(pilot_stamps):
129+
raise PilotNotFoundError(data={"pilot_stamps": str(pilot_stamps)})
130+
131+
async def clear_pilots_bulk(
132+
self, cutoff_date: datetime, delete_only_aborted: bool
133+
) -> int:
134+
"""Bulk delete pilots that have SubmissionTime before the 'cutoff_date'.
135+
Returns the number of deletion.
136+
"""
137+
# TODO: Add test (Millisec?)
138+
stmt = delete(PilotAgents).where(PilotAgents.submission_time < cutoff_date)
139+
140+
# If delete_only_aborted is True, add the condition for 'Status' being 'Aborted'
141+
if delete_only_aborted:
142+
stmt = stmt.where(PilotAgents.status == "Aborted")
143+
144+
# Execute the statement
145+
res = await self.conn.execute(stmt)
146+
147+
return res.rowcount
148+
149+
# ----------------------------- Update Functions -----------------------------
150+
127151
async def update_pilot_fields_bulk(
128152
self, pilot_stamps_to_fields_mapping: list[PilotFieldsMapping]
129153
):
@@ -178,6 +202,8 @@ async def update_pilot_fields_bulk(
178202
data={"mapping": str(pilot_stamps_to_fields_mapping)}
179203
)
180204

205+
# ----------------------------- Search Functions -----------------------------
206+
181207
async def search_pilots(
182208
self,
183209
parameters: list[str] | None,
@@ -219,21 +245,3 @@ async def search_pilot_to_job_mapping(
219245
per_page=per_page,
220246
page=page,
221247
)
222-
223-
async def clear_pilots_bulk(
224-
self, cutoff_date: datetime, delete_only_aborted: bool
225-
) -> int:
226-
"""Bulk delete pilots that have SubmissionTime before the 'cutoff_date'.
227-
Returns the number of deletion.
228-
"""
229-
# TODO: Add test (Millisec?)
230-
stmt = delete(PilotAgents).where(PilotAgents.submission_time < cutoff_date)
231-
232-
# If delete_only_aborted is True, add the condition for 'Status' being 'Aborted'
233-
if delete_only_aborted:
234-
stmt = stmt.where(PilotAgents.status == "Aborted")
235-
236-
# Execute the statement
237-
res = await self.conn.execute(stmt)
238-
239-
return res.rowcount

0 commit comments

Comments
 (0)