Skip to content

Commit d166bcb

Browse files
committed
⚡ Improve select query performance + handle flush/commit behaviours correctly
1 parent 0d28fdc commit d166bcb

File tree

4 files changed

+120
-76
lines changed

4 files changed

+120
-76
lines changed

sqlalchemy_memory/base/query.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@
3131
class MemoryQuery(Query):
3232
def __init__(self, entities, element):
3333
super().__init__(entities, element)
34-
assert len(entities) == 1, "Only single table queries are supported"
35-
3634
self._model = entities[0]
3735

3836
self._where_criteria = []
@@ -166,6 +164,10 @@ def _execute_query(self):
166164
for condition in self._where_criteria:
167165
collection = self._apply_condition(condition, collection)
168166

167+
if len(collection) == 0:
168+
# No need to go further
169+
return collection
170+
169171
# Apply order by
170172
for clause in reversed(self._order_by or []):
171173
reverse = False

sqlalchemy_memory/base/session.py

Lines changed: 101 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,62 @@
33
from sqlalchemy.sql.dml import Insert, Delete, Update
44
from sqlalchemy.engine import IteratorResult
55
from sqlalchemy.engine.cursor import SimpleResultMetaData
6+
from functools import lru_cache
7+
from collections import defaultdict
68

79
from .query import MemoryQuery
810
from ..logger import logger
911

12+
1013
class MemorySession(Session):
1114
def __init__(self, *args, **kwargs):
1215
super().__init__(*args, **kwargs)
1316
self._query_cls = MemoryQuery
17+
self._has_pending_merge = False
18+
self.store = self.get_bind().dialect._store
1419

15-
@property
16-
def raw_connection(self):
17-
return self.connection().connection.dbapi_connection
20+
# Non-committed inserts/deletes/updates
21+
self._to_add = defaultdict(list)
22+
self._to_delete = defaultdict(list)
23+
self._to_update = defaultdict(list)
1824

19-
@property
20-
def store(self):
21-
return self.raw_connection.store
25+
self._fetched = defaultdict(dict)
26+
27+
def add(self, obj, **kwargs):
28+
tablename = obj.__tablename__
29+
if not any(id(x) == id(obj) for x in self._to_add[tablename]):
30+
self._to_add[tablename].append(obj)
31+
32+
def delete(self, obj):
33+
tablename = obj.__tablename__
34+
self._to_delete[tablename].append(obj)
35+
36+
def update(self, tablename, pk_value, data):
37+
self._to_update[tablename].append((pk_value, data))
38+
39+
def _mark_as_fetched(self, instance):
40+
tablename = instance.__tablename__
41+
42+
pk_name = self.store._get_primary_key_name(instance)
43+
pk_value = getattr(instance, pk_name)
2244

23-
def add(self, instance, **kwargs):
24-
self.store.add(instance)
45+
if pk_value in self._fetched[tablename]:
46+
# Don't mark as fetched again
47+
return
48+
49+
original_values = {
50+
col.name: getattr(instance, col.name)
51+
for col in instance.__table__.columns
52+
}
53+
self._fetched[tablename][pk_value] = original_values
2554

2655
def get(self, entity, id, **kwargs):
2756
"""
2857
Return an instance based on the given primary key identifier, or ``None`` if not found.
2958
"""
3059
instance = self.store.get_by_primary_key(entity, id)
3160
if instance:
32-
self.store.mark_as_fetched(instance)
61+
self._mark_as_fetched(instance)
3362
return instance
3463

3564
def scalars(self, statement, **kwargs):
@@ -38,16 +67,25 @@ def scalars(self, statement, **kwargs):
3867
def scalar(self, statement, **kwargs):
3968
return self.execute(statement, **kwargs).scalar()
4069

41-
def _handle_select(self, statement: Select, **kwargs):
42-
# Detect single‑entity selects: select(MyModel)
43-
cd = statement.column_descriptions
44-
if len(cd) != 1 or cd[0]["entity"] is None:
45-
raise Exception("Model not found")
70+
@staticmethod
71+
@lru_cache(maxsize=256)
72+
def _get_metadata_for_annotated_table(annotated_table):
73+
"""
74+
Build minimal cursor metadata
75+
"""
76+
col_names = [col.name for col in annotated_table._columns]
77+
return SimpleResultMetaData([
78+
(col_name, None, None, None, None, None, None)
79+
for col_name in col_names
80+
])
4681

47-
model = cd[0]["entity"]
4882

49-
# Execute the query
83+
def _handle_select(self, statement: Select, **kwargs):
5084
entities = statement._raw_columns
85+
if len(entities) != 1:
86+
raise Exception("Only single‑entity SELECTs are supported")
87+
88+
# Execute the query
5189
q = MemoryQuery(entities, self)
5290

5391
# Apply WHERE
@@ -67,19 +105,16 @@ def _handle_select(self, statement: Select, **kwargs):
67105
objs = q.all()
68106

69107
for obj in objs:
70-
self.store.mark_as_fetched(obj)
71-
72-
# Build minimal cursor metadata
73-
metadata = SimpleResultMetaData([
74-
(col.name, None, None, None, None, None, None)
75-
for col in list(model.__table__.columns)
76-
])
108+
self._mark_as_fetched(obj)
77109

78110
# Wrap each object in a single‑element tuple, so .scalars() yields it
79111
wrapped = ((obj,) for obj in objs)
80112

113+
metadata = MemorySession._get_metadata_for_annotated_table(entities[0])
114+
81115
return IteratorResult(metadata, wrapped)
82116

117+
83118
def _handle_delete(self, statement: Delete, **kwargs):
84119
q = MemoryQuery([statement.table], self)
85120

@@ -89,7 +124,7 @@ def _handle_delete(self, statement: Delete, **kwargs):
89124
collection = q.all()
90125

91126
for obj in collection:
92-
self.store.delete(obj)
127+
self.delete(obj)
93128

94129
result = IteratorResult(SimpleResultMetaData([]), iter([]))
95130
result.rowcount = len(collection)
@@ -115,7 +150,7 @@ def _handle_insert(self, statement: Insert, params=None, **kwargs):
115150
instances = []
116151
for vals in vals_list:
117152
obj = model(**vals)
118-
self.store.add(obj)
153+
self.add(obj)
119154
instances.append(obj)
120155

121156
rowcount = len(instances)
@@ -157,15 +192,13 @@ def _handle_update(self, statement: Update, **kwargs):
157192
pk_col_name = self.store._get_primary_key_name(obj)
158193

159194
pk_value = getattr(obj, pk_col_name)
160-
self.store.update(tablename, pk_value, data)
195+
self.update(tablename, pk_value, data)
161196

162197
result = IteratorResult(SimpleResultMetaData([]), iter([]))
163198
result.rowcount = len(collection)
164199
return result
165200

166201
def execute(self, statement, params=None, **kwargs):
167-
#logger.debug(f"Executing query: {statement}")
168-
169202
if isinstance(statement, Select):
170203
return self._handle_select(statement, **kwargs)
171204

@@ -190,7 +223,8 @@ def merge(self, instance, **kwargs):
190223
existing = self.store.get_by_primary_key(instance, pk_value)
191224

192225
if existing:
193-
self.store.mark_as_fetched(existing)
226+
self._mark_as_fetched(existing)
227+
self._has_pending_merge = True
194228

195229
for column in instance.__table__.columns:
196230
field = column.name
@@ -205,16 +239,49 @@ def merge(self, instance, **kwargs):
205239
self.add(instance)
206240
return instance
207241

208-
def delete(self, instance):
209-
self.store.delete(instance)
242+
@property
243+
def dirty(self):
244+
return bool(self._to_add or self._to_delete or self._to_update) or self._has_pending_merge
245+
246+
def _is_clean(self):
247+
return not self.dirty
210248

211249
def flush(self, objects=None):
212-
pass
250+
if not self._transaction or not self._transaction._connections:
251+
self.connection() # Ensure a real connection is created
252+
253+
to_transfer = [
254+
"_to_add",
255+
"_to_update",
256+
"_to_delete",
257+
"_fetched",
258+
]
259+
for key in to_transfer:
260+
item = getattr(self, key)
261+
if not item:
262+
continue
263+
setattr(self.store, key, item.copy())
264+
item.clear()
213265

214266
def rollback(self, **kwargs):
215267
logger.debug("Rolling back ...")
268+
269+
self.store._fetched = self._fetched
216270
self.store.rollback()
217271

272+
self._has_pending_merge = False
273+
274+
self._to_add.clear()
275+
self._to_delete.clear()
276+
self._to_update.clear()
277+
self._fetched.clear()
278+
279+
218280
def commit(self):
219-
logger.debug("Committing ...")
220-
self.store.commit()
281+
if self.dirty:
282+
self.flush()
283+
284+
if self.store.dirty or self._has_pending_merge:
285+
logger.debug("Committing ...")
286+
self.store.commit()
287+
self._has_pending_merge = False

sqlalchemy_memory/base/store.py

Lines changed: 8 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,43 +13,19 @@ def _reset(self):
1313
self.data = defaultdict(list)
1414
self.data_by_pk = defaultdict(dict)
1515

16-
# Uncommitted inserts/deletes/updates
17-
self._to_add = defaultdict(list)
18-
self._to_delete = defaultdict(list)
19-
self._to_update = defaultdict(list)
16+
# Non-committed inserts/deletes/updates
17+
self._to_add = {}
18+
self._to_delete = {}
19+
self._to_update = {}
2020

21-
self._fetched = defaultdict(dict)
21+
self._fetched = {}
2222

2323
# Auto increment counter per table
2424
self._pk_counter = defaultdict(int)
2525

26-
def add(self, obj):
27-
tablename = obj.__tablename__
28-
if not any(id(x) == id(obj) for x in self._to_add[tablename]):
29-
self._to_add[tablename].append(obj)
30-
31-
def delete(self, obj):
32-
tablename = obj.__tablename__
33-
self._to_delete[tablename].append(obj)
34-
35-
def update(self, tablename, pk_value, data):
36-
self._to_update[tablename].append((pk_value, data))
37-
38-
def mark_as_fetched(self, instance):
39-
tablename = instance.__tablename__
40-
41-
pk_name = self._get_primary_key_name(instance)
42-
pk_value = getattr(instance, pk_name)
43-
44-
if pk_value in self._fetched[tablename]:
45-
# Don't mark as fetched again
46-
return
47-
48-
original_values = {
49-
col.name: getattr(instance, col.name)
50-
for col in instance.__table__.columns
51-
}
52-
self._fetched[tablename][pk_value] = original_values
26+
@property
27+
def dirty(self):
28+
return bool(self._to_add or self._to_delete or self._to_update)
5329

5430
def commit(self):
5531
# apply deletes

tests/test_advanced.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,13 @@ class TestAdvanced:
1717
]
1818
)
1919
def test_like_patterns(self, SessionFactory, pattern, negate, expected_ids):
20-
with SessionFactory() as session:
21-
with session.begin():
22-
session.add_all([
23-
Item(id=1, name="foo"),
24-
Item(id=2, name="bar"),
25-
Item(id=3, name="foobar"),
26-
Item(id=4, name="barfoo"),
27-
])
20+
with SessionFactory.begin() as session:
21+
session.add_all([
22+
Item(id=1, name="foo"),
23+
Item(id=2, name="bar"),
24+
Item(id=3, name="foobar"),
25+
Item(id=4, name="barfoo"),
26+
])
2827

2928
with SessionFactory() as session:
3029
if negate:

0 commit comments

Comments
 (0)