33from sqlalchemy .sql .dml import Insert , Delete , Update
44from sqlalchemy .engine import IteratorResult
55from sqlalchemy .engine .cursor import SimpleResultMetaData
6+ from functools import lru_cache
7+ from collections import defaultdict
68
79from .query import MemoryQuery
810from ..logger import logger
911
12+
1013class MemorySession (Session ):
1114 def __init__ (self , * args , ** kwargs ):
1215 super ().__init__ (* args , ** kwargs )
1316 self ._query_cls = MemoryQuery
17+ self ._has_pending_merge = False
18+ self .store = self .get_bind ().dialect ._store
1419
15- @property
16- def raw_connection (self ):
17- return self .connection ().connection .dbapi_connection
20+ # Non-committed inserts/deletes/updates
21+ self ._to_add = defaultdict (list )
22+ self ._to_delete = defaultdict (list )
23+ self ._to_update = defaultdict (list )
1824
19- @property
20- def store (self ):
21- return self .raw_connection .store
25+ self ._fetched = defaultdict (dict )
26+
27+ def add (self , obj , ** kwargs ):
28+ tablename = obj .__tablename__
29+ if not any (id (x ) == id (obj ) for x in self ._to_add [tablename ]):
30+ self ._to_add [tablename ].append (obj )
31+
32+ def delete (self , obj ):
33+ tablename = obj .__tablename__
34+ self ._to_delete [tablename ].append (obj )
35+
36+ def update (self , tablename , pk_value , data ):
37+ self ._to_update [tablename ].append ((pk_value , data ))
38+
39+ def _mark_as_fetched (self , instance ):
40+ tablename = instance .__tablename__
41+
42+ pk_name = self .store ._get_primary_key_name (instance )
43+ pk_value = getattr (instance , pk_name )
2244
23- def add (self , instance , ** kwargs ):
24- self .store .add (instance )
45+ if pk_value in self ._fetched [tablename ]:
46+ # Don't mark as fetched again
47+ return
48+
49+ original_values = {
50+ col .name : getattr (instance , col .name )
51+ for col in instance .__table__ .columns
52+ }
53+ self ._fetched [tablename ][pk_value ] = original_values
2554
2655 def get (self , entity , id , ** kwargs ):
2756 """
2857 Return an instance based on the given primary key identifier, or ``None`` if not found.
2958 """
3059 instance = self .store .get_by_primary_key (entity , id )
3160 if instance :
32- self .store . mark_as_fetched (instance )
61+ self ._mark_as_fetched (instance )
3362 return instance
3463
3564 def scalars (self , statement , ** kwargs ):
@@ -38,16 +67,25 @@ def scalars(self, statement, **kwargs):
3867 def scalar (self , statement , ** kwargs ):
3968 return self .execute (statement , ** kwargs ).scalar ()
4069
41- def _handle_select (self , statement : Select , ** kwargs ):
42- # Detect single‑entity selects: select(MyModel)
43- cd = statement .column_descriptions
44- if len (cd ) != 1 or cd [0 ]["entity" ] is None :
45- raise Exception ("Model not found" )
70+ @staticmethod
71+ @lru_cache (maxsize = 256 )
72+ def _get_metadata_for_annotated_table (annotated_table ):
73+ """
74+ Build minimal cursor metadata
75+ """
76+ col_names = [col .name for col in annotated_table ._columns ]
77+ return SimpleResultMetaData ([
78+ (col_name , None , None , None , None , None , None )
79+ for col_name in col_names
80+ ])
4681
47- model = cd [0 ]["entity" ]
4882
49- # Execute the query
83+ def _handle_select ( self , statement : Select , ** kwargs ):
5084 entities = statement ._raw_columns
85+ if len (entities ) != 1 :
86+ raise Exception ("Only single‑entity SELECTs are supported" )
87+
88+ # Execute the query
5189 q = MemoryQuery (entities , self )
5290
5391 # Apply WHERE
@@ -67,19 +105,16 @@ def _handle_select(self, statement: Select, **kwargs):
67105 objs = q .all ()
68106
69107 for obj in objs :
70- self .store .mark_as_fetched (obj )
71-
72- # Build minimal cursor metadata
73- metadata = SimpleResultMetaData ([
74- (col .name , None , None , None , None , None , None )
75- for col in list (model .__table__ .columns )
76- ])
108+ self ._mark_as_fetched (obj )
77109
78110 # Wrap each object in a single‑element tuple, so .scalars() yields it
79111 wrapped = ((obj ,) for obj in objs )
80112
113+ metadata = MemorySession ._get_metadata_for_annotated_table (entities [0 ])
114+
81115 return IteratorResult (metadata , wrapped )
82116
117+
83118 def _handle_delete (self , statement : Delete , ** kwargs ):
84119 q = MemoryQuery ([statement .table ], self )
85120
@@ -89,7 +124,7 @@ def _handle_delete(self, statement: Delete, **kwargs):
89124 collection = q .all ()
90125
91126 for obj in collection :
92- self .store . delete (obj )
127+ self .delete (obj )
93128
94129 result = IteratorResult (SimpleResultMetaData ([]), iter ([]))
95130 result .rowcount = len (collection )
@@ -115,7 +150,7 @@ def _handle_insert(self, statement: Insert, params=None, **kwargs):
115150 instances = []
116151 for vals in vals_list :
117152 obj = model (** vals )
118- self .store . add (obj )
153+ self .add (obj )
119154 instances .append (obj )
120155
121156 rowcount = len (instances )
@@ -157,15 +192,13 @@ def _handle_update(self, statement: Update, **kwargs):
157192 pk_col_name = self .store ._get_primary_key_name (obj )
158193
159194 pk_value = getattr (obj , pk_col_name )
160- self .store . update (tablename , pk_value , data )
195+ self .update (tablename , pk_value , data )
161196
162197 result = IteratorResult (SimpleResultMetaData ([]), iter ([]))
163198 result .rowcount = len (collection )
164199 return result
165200
166201 def execute (self , statement , params = None , ** kwargs ):
167- #logger.debug(f"Executing query: {statement}")
168-
169202 if isinstance (statement , Select ):
170203 return self ._handle_select (statement , ** kwargs )
171204
@@ -190,7 +223,8 @@ def merge(self, instance, **kwargs):
190223 existing = self .store .get_by_primary_key (instance , pk_value )
191224
192225 if existing :
193- self .store .mark_as_fetched (existing )
226+ self ._mark_as_fetched (existing )
227+ self ._has_pending_merge = True
194228
195229 for column in instance .__table__ .columns :
196230 field = column .name
@@ -205,16 +239,49 @@ def merge(self, instance, **kwargs):
205239 self .add (instance )
206240 return instance
207241
208- def delete (self , instance ):
209- self .store .delete (instance )
242+ @property
243+ def dirty (self ):
244+ return bool (self ._to_add or self ._to_delete or self ._to_update ) or self ._has_pending_merge
245+
246+ def _is_clean (self ):
247+ return not self .dirty
210248
211249 def flush (self , objects = None ):
212- pass
250+ if not self ._transaction or not self ._transaction ._connections :
251+ self .connection () # Ensure a real connection is created
252+
253+ to_transfer = [
254+ "_to_add" ,
255+ "_to_update" ,
256+ "_to_delete" ,
257+ "_fetched" ,
258+ ]
259+ for key in to_transfer :
260+ item = getattr (self , key )
261+ if not item :
262+ continue
263+ setattr (self .store , key , item .copy ())
264+ item .clear ()
213265
214266 def rollback (self , ** kwargs ):
215267 logger .debug ("Rolling back ..." )
268+
269+ self .store ._fetched = self ._fetched
216270 self .store .rollback ()
217271
272+ self ._has_pending_merge = False
273+
274+ self ._to_add .clear ()
275+ self ._to_delete .clear ()
276+ self ._to_update .clear ()
277+ self ._fetched .clear ()
278+
279+
218280 def commit (self ):
219- logger .debug ("Committing ..." )
220- self .store .commit ()
281+ if self .dirty :
282+ self .flush ()
283+
284+ if self .store .dirty or self ._has_pending_merge :
285+ logger .debug ("Committing ..." )
286+ self .store .commit ()
287+ self ._has_pending_merge = False
0 commit comments