Skip to content

Commit 3462a00

Browse files
fix the unnecessary relaod of data on pks, raise on non-returning backends instead of current buggy behaviour (#919)
Co-authored-by: collerek <collerek@gmail.com>
1 parent 077bcc5 commit 3462a00

3 files changed

Lines changed: 211 additions & 6 deletions

File tree

ormar/databases/query_executor.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,19 @@ async def execute(self, query: Executable) -> Any:
6161
Execute a query (INSERT, UPDATE, DELETE).
6262
6363
:param query: SQLAlchemy query expression
64-
:return: For INSERT, returns last row id; for UPDATE/DELETE, returns row count
64+
:return: For INSERT, the inserted primary key or ``None`` if the backend
65+
cannot return one (e.g. Oracle MySQL inserting into a
66+
non-AUTO_INCREMENT pk with a server default — no RETURNING support).
67+
For UPDATE/DELETE, the row count.
6568
"""
6669
result: CursorResult[Any] = await self._connection.execute(query)
6770

68-
# For INSERT queries, try to get the inserted primary key
69-
# PostgreSQL/MySQL use inserted_primary_key, SQLite uses lastrowid
70-
if result.context and result.context.isinsert: # pragma: no cover
71+
# For INSERT queries, try to get the inserted primary key via the
72+
# dialect's best-available mechanism (RETURNING on PostgreSQL / SQLite
73+
# 3.35+ / MariaDB 10.5+, LAST_INSERT_ID() on MySQL AUTO_INCREMENT).
74+
# Do NOT fall back to rowcount here: rowcount is not a pk, and
75+
# returning it would silently corrupt `Model.pk` in `save()`.
76+
if result.context and result.context.isinsert:
7177
if result.inserted_primary_key:
7278
pk_value = result.inserted_primary_key[0]
7379
if pk_value is not None:
@@ -76,6 +82,8 @@ async def execute(self, query: Executable) -> Any:
7682
if hasattr(result, "lastrowid") and result.lastrowid: # pragma: no cover
7783
return result.lastrowid
7884

85+
return None # pragma: no cover
86+
7987
return result.rowcount if result.rowcount is not None else 0
8088

8189
async def execute_many(

ormar/models/model.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,16 +106,34 @@ async def save(self: T) -> T:
106106
expr = self.ormar_config.table.insert()
107107
expr = expr.values(**self_fields)
108108

109+
pkname = self.ormar_config.pkname
110+
pk_returned_from_insert = False
109111
pk = await self._execute_query(expr)
110112
if pk and isinstance(pk, self.pk_type()):
111-
setattr(self, self.ormar_config.pkname, pk)
113+
setattr(self, pkname, pk)
114+
pk_returned_from_insert = True
115+
116+
if self.pk is None:
117+
raise ModelPersistenceError( # pragma: no cover
118+
f"Could not recover the generated primary key for "
119+
f"{self.__class__.__name__} after INSERT. This happens on "
120+
"backends that lack RETURNING support for server-side "
121+
"defaults on non-AUTO_INCREMENT primary keys — most notably "
122+
"Oracle MySQL, which does not implement RETURNING in any "
123+
"version. Use autoincrement=True, provide the primary key "
124+
"client-side, or switch to a RETURNING-capable backend "
125+
"(PostgreSQL, SQLite 3.35+, MariaDB 10.5+)."
126+
)
112127

113128
self.set_save_status(True)
114-
# refresh server side defaults
129+
# refresh server-side defaults — but skip the pk if the insert already
130+
# returned it, so save() stays a single round-trip when the pk is the
131+
# only server_default field.
115132
if any(
116133
field.server_default is not None
117134
for name, field in self.ormar_config.model_fields.items()
118135
if name not in self_fields
136+
and not (pk_returned_from_insert and name == pkname)
119137
):
120138
await self.load()
121139

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
"""Tests for ``Model.save()`` behavior around ``server_default`` on the pk.
2+
3+
Covers two closely related concerns:
4+
5+
1. N+1 fix (PR #919) — when the pk is the only ``server_default`` field, the
6+
INSERT's RETURNING clause already provides the pk, so ``save()`` must not
7+
issue a second SELECT to reload the model.
8+
9+
2. Pk-recovery loud-fail — on backends that cannot return a server-generated
10+
pk (Oracle MySQL has no RETURNING), ``save()`` must raise
11+
``ModelPersistenceError`` rather than silently storing a bogus pk (the old
12+
behavior was to return ``rowcount`` from the executor, which ``save()``
13+
then mistook for the pk).
14+
"""
15+
16+
from typing import Any, List
17+
18+
import pytest
19+
from sqlalchemy import event, text
20+
21+
import ormar
22+
from ormar.exceptions import ModelPersistenceError
23+
from tests.lifespan import init_tests
24+
from tests.settings import create_config
25+
26+
base_ormar_config = create_config()
27+
28+
_IS_MYSQL = "mysql" in base_ormar_config.database.url
29+
30+
31+
class ServerDefaultPk(ormar.Model):
32+
ormar_config = base_ormar_config.copy(tablename="server_default_pk")
33+
34+
id: int = ormar.Integer(
35+
primary_key=True, autoincrement=False, server_default=text("100")
36+
)
37+
name: str = ormar.String(max_length=100)
38+
39+
40+
class ServerDefaultNonPk(ormar.Model):
41+
ormar_config = base_ormar_config.copy(tablename="server_default_nonpk")
42+
43+
id: int = ormar.Integer(primary_key=True)
44+
name: str = ormar.String(max_length=100)
45+
company: str = ormar.String(max_length=100, server_default="Acme")
46+
47+
48+
class ServerDefaultPkAndNonPk(ormar.Model):
49+
ormar_config = base_ormar_config.copy(tablename="server_default_pk_and_nonpk")
50+
51+
id: int = ormar.Integer(
52+
primary_key=True, autoincrement=False, server_default=text("200")
53+
)
54+
name: str = ormar.String(max_length=100)
55+
company: str = ormar.String(max_length=100, server_default="Acme")
56+
57+
58+
create_test_database = init_tests(base_ormar_config)
59+
60+
61+
class _StatementCounter:
62+
"""Records every statement executed on the sqlalchemy engine."""
63+
64+
def __init__(self) -> None:
65+
self.statements: List[str] = []
66+
67+
def __enter__(self) -> "_StatementCounter":
68+
sync_engine = base_ormar_config.database.engine.sync_engine
69+
70+
def before_cursor_execute(
71+
conn: Any,
72+
cursor: Any,
73+
statement: str,
74+
parameters: Any,
75+
context: Any,
76+
executemany: bool,
77+
) -> None:
78+
self.statements.append(statement)
79+
80+
self._listener = before_cursor_execute
81+
self._sync_engine = sync_engine
82+
event.listen(sync_engine, "before_cursor_execute", self._listener)
83+
return self
84+
85+
def __exit__(self, *exc: Any) -> None:
86+
event.remove(self._sync_engine, "before_cursor_execute", self._listener)
87+
88+
89+
def _table_selects(statements: List[str], tablename: str) -> List[str]:
90+
return [
91+
s
92+
for s in statements
93+
if s.lstrip().upper().startswith("SELECT") and tablename in s
94+
]
95+
96+
97+
@pytest.mark.asyncio
98+
@pytest.mark.skipif(
99+
_IS_MYSQL,
100+
reason=(
101+
"Oracle MySQL has no RETURNING clause, so a server_default on a "
102+
"non-AUTO_INCREMENT pk cannot be recovered — covered by "
103+
"test_save_raises_when_server_default_pk_cannot_be_recovered instead."
104+
),
105+
)
106+
async def test_save_does_not_reload_when_only_pk_has_server_default(): # noqa: E501 # pragma: no cover
107+
"""INSERT returns the server-generated pk, so no SELECT should follow."""
108+
async with base_ormar_config.database:
109+
async with base_ormar_config.database.transaction(force_rollback=True):
110+
with _StatementCounter() as counter:
111+
instance = ServerDefaultPk(name="first")
112+
await instance.save()
113+
114+
selects = _table_selects(
115+
counter.statements, ServerDefaultPk.ormar_config.tablename
116+
)
117+
assert instance.pk is not None
118+
assert selects == [], counter.statements
119+
120+
121+
@pytest.mark.asyncio
122+
async def test_save_still_reloads_when_non_pk_has_server_default():
123+
"""Regression guard: non-pk server defaults must still trigger a reload."""
124+
async with base_ormar_config.database:
125+
async with base_ormar_config.database.transaction(force_rollback=True):
126+
with _StatementCounter() as counter:
127+
instance = ServerDefaultNonPk(id=1, name="first")
128+
await instance.save()
129+
130+
selects = _table_selects(
131+
counter.statements, ServerDefaultNonPk.ormar_config.tablename
132+
)
133+
assert instance.company == "Acme"
134+
assert len(selects) == 1, counter.statements
135+
136+
137+
@pytest.mark.asyncio
138+
@pytest.mark.skipif(
139+
_IS_MYSQL,
140+
reason=(
141+
"Oracle MySQL has no RETURNING clause, so a server_default on a "
142+
"non-AUTO_INCREMENT pk cannot be recovered — covered by "
143+
"test_save_raises_when_server_default_pk_cannot_be_recovered instead."
144+
),
145+
)
146+
async def test_save_reloads_once_when_both_pk_and_non_pk_have_server_default(): # noqa: E501 # pragma: no cover
147+
"""Mixed case: still need exactly one reload for the non-pk column."""
148+
async with base_ormar_config.database:
149+
async with base_ormar_config.database.transaction(force_rollback=True):
150+
with _StatementCounter() as counter:
151+
instance = ServerDefaultPkAndNonPk(name="first")
152+
await instance.save()
153+
154+
selects = _table_selects(
155+
counter.statements, ServerDefaultPkAndNonPk.ormar_config.tablename
156+
)
157+
assert instance.pk is not None
158+
assert instance.company == "Acme"
159+
assert len(selects) == 1, counter.statements
160+
161+
162+
@pytest.mark.asyncio
163+
@pytest.mark.skipif(
164+
not _IS_MYSQL,
165+
reason=(
166+
"This loud-fail path only fires on backends that cannot return a "
167+
"server-generated pk. RETURNING-capable backends (PostgreSQL, "
168+
"SQLite 3.35+, MariaDB 10.5+) succeed here."
169+
),
170+
)
171+
async def test_save_raises_when_server_default_pk_cannot_be_recovered(): # noqa: E501 # pragma: no cover
172+
"""Oracle MySQL: no RETURNING → save() must raise, not silently store a
173+
bogus pk (the old bug was to coerce rowcount into ``Model.pk``)."""
174+
async with base_ormar_config.database:
175+
async with base_ormar_config.database.transaction(force_rollback=True):
176+
instance = ServerDefaultPk(name="first")
177+
with pytest.raises(ModelPersistenceError, match="primary key"):
178+
await instance.save()
179+
assert instance.pk is None

0 commit comments

Comments
 (0)