11from __future__ import annotations
22
33from datetime import datetime , timezone
4+ from typing import Sequence
45
5- from sqlalchemy import DateTime , insert , select , update
6+ from sqlalchemy import insert , select , update
67from sqlalchemy .exc import IntegrityError , NoResultFound
78
89from diracx .core .exceptions import (
@@ -26,7 +27,7 @@ async def add_pilot_references(
2627 vo : str ,
2728 grid_type : str = "DIRAC" ,
2829 pilot_stamps : dict | None = None ,
29- ) -> None :
30+ ) -> Sequence : # Return a list of primary keys
3031
3132 if pilot_stamps is None :
3233 pilot_stamps = {}
@@ -47,10 +48,18 @@ async def add_pilot_references(
4748 for ref in pilot_ref
4849 ]
4950
50- # Insert multiple rows in a single execute call
51- stmt = insert (PilotAgents ).values (values )
52- await self .conn .execute (stmt )
53- return
51+ # Insert multiple rows in a single execute call and use 'returning' to get primary keys
52+ stmt = (
53+ insert (PilotAgents ).values (values ).returning (PilotAgents .pilot_id )
54+ ) # Assuming 'id' is the primary key
55+ result = await self .conn .execute (stmt )
56+
57+ # Use .scalars() and .all() to get the primary keys directly in a list
58+ primary_keys = (
59+ result .scalars ().all ()
60+ ) # This returns a flat list of primary keys
61+
62+ return primary_keys
5463
5564 async def increment_pilot_secret_use (
5665 self ,
@@ -100,32 +109,6 @@ async def verify_pilot_secret(
100109 # Increment the count
101110 await self .increment_pilot_secret_use (pilot_id = pilot_id )
102111
103- async def register_new_pilot (
104- self ,
105- vo : str ,
106- pilot_job_reference : str ,
107- pilot_stamp : str ,
108- grid_type : str = "DIRAC" ,
109- submission_time : DateTime | None = None , # ?
110- last_update_time : DateTime | None = None , # = now?
111- ) -> int | None :
112- stmt = insert (PilotAgents ).values (
113- vo = vo ,
114- submission_time = submission_time ,
115- last_update_time = last_update_time ,
116- pilot_job_reference = pilot_job_reference ,
117- grid_type = grid_type ,
118- pilot_stamp = pilot_stamp ,
119- )
120-
121- # Execute the request
122- res = await self .conn .execute (stmt )
123-
124- new_pilot_id = res .inserted_primary_key
125-
126- # Returns the new pilot ID
127- return int (new_pilot_id [0 ]) if new_pilot_id else None
128-
129112 async def add_pilot_credentials (self , pilot_id : int , pilot_hashed_secret : str ):
130113
131114 stmt = insert (PilotRegistrations ).values (
0 commit comments