Skip to content

Commit bbe9ad7

Browse files
HeloiseJoffeStellatsuu
authored andcommitted
refactor: Change the DeclarativeBase sqlalchemy style (DIRACGrid#748)
* Refactor: Change the DeclarativeBase sqlalchemy style * Refactor: apply SQLAlchemy 2.0 migration for model mappings
1 parent 57a0e97 commit bbe9ad7

19 files changed

Lines changed: 419 additions & 247 deletions

File tree

diracx-db/src/diracx/db/sql/auth/db.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ async def insert_device_flow(
163163
for _ in range(MAX_RETRY):
164164
user_code = "".join(
165165
secrets.choice(USER_CODE_ALPHABET)
166-
for _ in range(DeviceFlows.user_code.type.length) # type: ignore
166+
for _ in range(DeviceFlows.user_code.type.length)
167167
)
168168
device_code = secrets.token_urlsafe()
169169

diracx-db/src/diracx/db/sql/auth/schema.py

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,35 @@
11
from __future__ import annotations
22

33
from enum import Enum, auto
4+
from typing import Any, Optional
5+
from uuid import UUID
46

57
from sqlalchemy import (
68
JSON,
79
Index,
810
String,
911
Uuid,
1012
)
11-
from sqlalchemy.orm import declarative_base
13+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
1214

1315
from diracx.db.sql.utils import (
14-
Column,
15-
DateNowColumn,
16-
EnumColumn,
17-
NullColumn,
16+
datetime_now,
17+
enum_column,
18+
str128,
19+
str255,
20+
str1024,
1821
)
1922

2023
USER_CODE_LENGTH = 8
2124

22-
Base = declarative_base()
25+
26+
class Base(DeclarativeBase):
27+
type_annotation_map = {
28+
str128: String(128),
29+
str255: String(255),
30+
str1024: String(1024),
31+
dict[str, Any]: JSON,
32+
}
2333

2434

2535
class FlowStatus(Enum):
@@ -47,27 +57,35 @@ class FlowStatus(Enum):
4757

4858
class DeviceFlows(Base):
4959
__tablename__ = "DeviceFlows"
50-
user_code = Column("UserCode", String(USER_CODE_LENGTH), primary_key=True)
51-
status = EnumColumn("Status", FlowStatus, server_default=FlowStatus.PENDING.name)
52-
creation_time = DateNowColumn("CreationTime")
53-
client_id = Column("ClientID", String(255))
54-
scope = Column("Scope", String(1024))
55-
device_code = Column("DeviceCode", String(128), unique=True) # Should be a hash
56-
id_token = NullColumn("IDToken", JSON())
60+
user_code: Mapped[str] = mapped_column(
61+
"UserCode", String(USER_CODE_LENGTH), primary_key=True
62+
)
63+
status: Mapped[FlowStatus] = enum_column(
64+
"Status", FlowStatus, server_default=FlowStatus.PENDING.name
65+
)
66+
creation_time: Mapped[datetime_now] = mapped_column("CreationTime")
67+
client_id: Mapped[str255] = mapped_column("ClientID")
68+
scope: Mapped[str1024] = mapped_column("Scope")
69+
device_code: Mapped[str128] = mapped_column(
70+
"DeviceCode", unique=True
71+
) # Should be a hash
72+
id_token: Mapped[Optional[dict[str, Any]]] = mapped_column("IDToken")
5773

5874

5975
class AuthorizationFlows(Base):
6076
__tablename__ = "AuthorizationFlows"
61-
uuid = Column("UUID", Uuid(as_uuid=False), primary_key=True)
62-
status = EnumColumn("Status", FlowStatus, server_default=FlowStatus.PENDING.name)
63-
client_id = Column("ClientID", String(255))
64-
creation_time = DateNowColumn("CreationTime")
65-
scope = Column("Scope", String(1024))
66-
code_challenge = Column("CodeChallenge", String(255))
67-
code_challenge_method = Column("CodeChallengeMethod", String(8))
68-
redirect_uri = Column("RedirectURI", String(255))
69-
code = NullColumn("Code", String(255)) # Should be a hash
70-
id_token = NullColumn("IDToken", JSON())
77+
uuid: Mapped[UUID] = mapped_column("UUID", Uuid(as_uuid=False), primary_key=True)
78+
status: Mapped[FlowStatus] = enum_column(
79+
"Status", FlowStatus, server_default=FlowStatus.PENDING.name
80+
)
81+
client_id: Mapped[str255] = mapped_column("ClientID")
82+
creation_time: Mapped[datetime_now] = mapped_column("CreationTime")
83+
scope: Mapped[str1024] = mapped_column("Scope")
84+
code_challenge: Mapped[str255] = mapped_column("CodeChallenge")
85+
code_challenge_method: Mapped[str] = mapped_column("CodeChallengeMethod", String(8))
86+
redirect_uri: Mapped[str255] = mapped_column("RedirectURI")
87+
code: Mapped[Optional[str255]] = mapped_column("Code") # Should be a hash
88+
id_token: Mapped[Optional[dict[str, Any]]] = mapped_column("IDToken")
7189

7290

7391
class RefreshTokenStatus(Enum):
@@ -93,13 +111,13 @@ class RefreshTokens(Base):
93111

94112
__tablename__ = "RefreshTokens"
95113
# Refresh token attributes
96-
jti = Column("JTI", Uuid(as_uuid=False), primary_key=True)
97-
status = EnumColumn(
114+
jti: Mapped[UUID] = mapped_column("JTI", Uuid(as_uuid=False), primary_key=True)
115+
status: Mapped[RefreshTokenStatus] = enum_column(
98116
"Status", RefreshTokenStatus, server_default=RefreshTokenStatus.CREATED.name
99117
)
100-
scope = Column("Scope", String(1024))
118+
scope: Mapped[str1024] = mapped_column("Scope")
101119

102120
# User attributes bound to the refresh token
103-
sub = Column("Sub", String(256), index=True)
121+
sub: Mapped[str] = mapped_column("Sub", String(256), index=True)
104122

105123
__table_args__ = (Index("index_status_sub", status, sub),)

diracx-db/src/diracx/db/sql/dummy/schema.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,31 @@
22
# in place of the SQLAlchemy one. Have a look at them
33
from __future__ import annotations
44

5-
from sqlalchemy import ForeignKey, Integer, String, Uuid
6-
from sqlalchemy.orm import declarative_base
5+
from uuid import UUID
76

8-
from diracx.db.sql.utils import Column, DateNowColumn
7+
from sqlalchemy import ForeignKey, String
8+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
99

10-
Base = declarative_base()
10+
from diracx.db.sql.utils import datetime_now, str255
11+
12+
13+
class Base(DeclarativeBase):
14+
type_annotation_map = {
15+
str255: String(255),
16+
}
1117

1218

1319
class Owners(Base):
1420
__tablename__ = "Owners"
15-
owner_id = Column("OwnerID", Integer, primary_key=True, autoincrement=True)
16-
creation_time = DateNowColumn("CreationTime")
17-
name = Column("Name", String(255))
21+
owner_id: Mapped[int] = mapped_column(
22+
"OwnerID", primary_key=True, autoincrement=True
23+
)
24+
creation_time: Mapped[datetime_now] = mapped_column("CreationTime")
25+
name: Mapped[str255] = mapped_column("Name")
1826

1927

2028
class Cars(Base):
2129
__tablename__ = "Cars"
22-
license_plate = Column("LicensePlate", Uuid(), primary_key=True)
23-
model = Column("Model", String(255))
24-
owner_id = Column("OwnerID", Integer, ForeignKey(Owners.owner_id))
30+
license_plate: Mapped[UUID] = mapped_column("LicensePlate", primary_key=True)
31+
model: Mapped[str255] = mapped_column("Model")
32+
owner_id: Mapped[int] = mapped_column("OwnerID", ForeignKey(Owners.owner_id))

diracx-db/src/diracx/db/sql/job/db.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from datetime import datetime, timezone
66
from typing import TYPE_CHECKING, Any, Iterable
77

8-
from sqlalchemy import bindparam, case, delete, literal, select, update
8+
from sqlalchemy import bindparam, case, delete, insert, literal, select, update
99

1010
if TYPE_CHECKING:
1111
from sqlalchemy.sql.elements import BindParameter
@@ -75,7 +75,7 @@ async def search(
7575
async def create_job(self, compressed_original_jdl: str):
7676
"""Insert a new job with original JDL. Returns inserted job id."""
7777
result = await self.conn.execute(
78-
JobJDLs.__table__.insert().values(
78+
insert(JobJDLs).values(
7979
JDL="",
8080
JobRequirements="",
8181
OriginalJDL=compressed_original_jdl,
@@ -91,7 +91,7 @@ async def delete_jobs(self, job_ids: list[int]):
9191
async def insert_input_data(self, lfns: dict[int, list[str]]):
9292
"""Insert input data for jobs."""
9393
await self.conn.execute(
94-
InputData.__table__.insert(),
94+
insert(InputData),
9595
[
9696
{
9797
"JobID": job_id,
@@ -105,7 +105,7 @@ async def insert_input_data(self, lfns: dict[int, list[str]]):
105105
async def insert_job_attributes(self, jobs_to_update: dict[int, dict]):
106106
"""Insert the job attributes."""
107107
await self.conn.execute(
108-
Jobs.__table__.insert(),
108+
insert(Jobs),
109109
[
110110
{
111111
"JobID": job_id,
@@ -118,9 +118,7 @@ async def insert_job_attributes(self, jobs_to_update: dict[int, dict]):
118118
async def update_job_jdls(self, jdls_to_update: dict[int, str]):
119119
"""Update the JDL, typically just after inserting the original JDL, or rescheduling, for example."""
120120
await self.conn.execute(
121-
JobJDLs.__table__.update().where(
122-
JobJDLs.__table__.c.JobID == bindparam("b_JobID")
123-
),
121+
update(JobJDLs).where(JobJDLs.__table__.c.JobID == bindparam("b_JobID")),
124122
[
125123
{
126124
"b_JobID": job_id,
@@ -186,7 +184,7 @@ async def get_job_jdls(self, job_ids, original: bool = False) -> dict[int, str]:
186184
async def set_job_commands(self, commands: list[tuple[int, str, str]]) -> None:
187185
"""Store a command to be passed to the job together with the next heart beat."""
188186
await self.conn.execute(
189-
JobCommands.__table__.insert(),
187+
insert(JobCommands),
190188
[
191189
{
192190
"JobID": job_id,
@@ -261,7 +259,7 @@ async def add_heartbeat_data(
261259
}
262260
for key, value in dynamic_data.items()
263261
]
264-
await self.conn.execute(HeartBeatLoggingInfo.__table__.insert().values(values))
262+
await self.conn.execute(insert(HeartBeatLoggingInfo).values(values))
265263

266264
async def get_job_commands(self, job_ids: Iterable[int]) -> list[JobCommand]:
267265
"""Get a command to be passed to the job together with the next heartbeat.

0 commit comments

Comments
 (0)