Skip to content

Commit 674345a

Browse files
committed
Fixes and improvements after integrating in a real codebase
1 parent 556e9ea commit 674345a

16 files changed

Lines changed: 222 additions & 146 deletions

File tree

README.md

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,27 @@
55
[![PyPI - Downloads](https://img.shields.io/pypi/dm/sqlalchemy-memory)](https://pypistats.org/packages/sqlalchemy-memory)
66

77

8-
**In‑memory SQLAlchemy 2.0 dialect for blazing‑fast prototyping**
8+
**In‑memory SQLAlchemy 2.0 dialect for blazing‑fast prototyping**
99

10-
A pure‑Python, zero‑configuration SQLAlchemy 2.0 dialect that lives entirely in RAM.
11-
Ideal for rapid prototyping, backtesting, and demos-no external database required.
10+
A pure‑Python SQLAlchemy 2.0 dialect that runs entirely in RAM.
11+
It avoids typical database I/O and ORM overhead while maintaining full compatibility with the SQLAlchemy 2.0 Core and ORM APIs.
12+
Ideal for rapid prototyping, backtesting engines, simulations.
1213

1314
## Why ?
1415

15-
This project was inspired by the idea of creating a fast, introspectable, no-dependency backend for SQLAlchemy. Useful for prototyping, education, and testing ORM logic without spinning up a real database engine.
16+
This project was inspired by the idea of building a **fast, introspectable, no-dependency backend** for SQLAlchemy.
1617

17-
It's also perfect for apps that need a fast, in-memory store compatible with SQLAlchemy, such as backtesting engines, simulators, or tools where you don't want to maintain a separate memory layer alongside your database models.
18+
It is useful for:
19+
20+
- Prototyping new applications
21+
22+
- Educational purposes
23+
24+
- Testing ORM logic without spinning up a real database engine
25+
26+
Unlike traditional in-memory solutions like SQLite, `sqlalchemy-memory` fully avoids serialization, connection pooling, and driver overhead, leading to much faster in-memory performance while keeping the familiar SQLAlchemy API.
27+
28+
It is also perfect for **applications that need a lightweight, high-performance store** compatible with SQLAlchemy, such as backtesting engines, simulators, or other tools where you don't want to maintain a separate in-memory layer alongside your database models.
1829

1930
## Features
2031

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
from .store import InMemoryStore
2-
31
class MemoryDBAPIConnection:
4-
def __init__(self):
5-
self.store = InMemoryStore()
2+
store = None
63

74
def commit(self):
85
self.store.commit()
@@ -12,3 +9,9 @@ def rollback(self):
129

1310
def close(self):
1411
pass
12+
13+
@classmethod
14+
def connect(cls, store, *args, **kwargs):
15+
connection = MemoryDBAPIConnection()
16+
connection.store = store
17+
return connection

sqlalchemy_memory/base/dialect.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from sqlalchemy.engine import URL, default
22
import types
3-
import copy
43

54
from .connection import MemoryDBAPIConnection
6-
from ..util import _raw_dbapi_connection
5+
from .store import InMemoryStore
76

87
class MemoryDialect(default.DefaultDialect):
98
name = "memory"
@@ -12,14 +11,21 @@ class MemoryDialect(default.DefaultDialect):
1211
supports_native_boolean = True
1312
supports_statement_cache = False
1413

14+
_store = None
15+
16+
def __init__(self, *args, **kwargs):
17+
super().__init__(*args, **kwargs)
18+
self._store = InMemoryStore()
19+
1520
def initialize(self, connection):
1621
super().initialize(connection)
1722

1823
# Turn off pool reset to preserve in-memory data
1924
connection.engine.pool._reset_on_return = None
2025

2126
def create_connect_args(self, url: URL):
22-
return [], {}
27+
db_name = url.database or "_default"
28+
return [db_name], {}
2329

2430
@classmethod
2531
def import_dbapi(cls, **kwargs):
@@ -28,23 +34,21 @@ def import_dbapi(cls, **kwargs):
2834
paramstyle="named",
2935
apilevel="2.0",
3036
threadsafety=1,
31-
Error=Exception
37+
Error=Exception,
38+
connect=lambda *a, **k: None
3239
)
33-
module.connect = lambda *a, **k: MemoryDBAPIConnection()
3440
return module
3541

36-
def do_begin(self, dbapi_conn):
37-
connection = _raw_dbapi_connection(dbapi_conn)
38-
connection.store._snapshot = copy.deepcopy(connection.store.data)
42+
def connect(self, *args, **kwargs):
43+
connection = MemoryDBAPIConnection()
44+
connection.store = self._store
45+
return connection
3946

4047
def do_commit(self, dbapi_conn):
41-
connection = _raw_dbapi_connection(dbapi_conn)
42-
connection.store.commit()
48+
self._store.commit()
4349

4450
def do_rollback(self, dbapi_conn):
45-
connection = _raw_dbapi_connection(dbapi_conn)
46-
connection.store.data = connection.store._snapshot
47-
connection.store.rollback()
51+
self._store.rollback()
4852

4953
def has_table(self, *args, **kwargs):
5054
# Patch to make Base.metadata.create_all(engine) not throw any exception

sqlalchemy_memory/base/query.py

Lines changed: 37 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,28 @@
55
from sqlalchemy.sql.functions import FunctionElement
66
from sqlalchemy.sql import operators
77
from sqlalchemy.sql.annotation import AnnotatedTable
8+
from sqlalchemy.orm.query import Query
89
from functools import cached_property
910
import fnmatch
1011

11-
from sqlalchemy.orm.query import Query
12+
from ..logger import logger
13+
from .resolvers import DateResolver, JsonExtractResolver
14+
15+
OPERATOR_ADAPTERS = {
16+
operators.is_: lambda value: lambda x, _: x is value,
17+
operators.isnot: lambda value: lambda x, _: x is not value,
18+
operators.like_op: lambda value: lambda x, _: fnmatch.fnmatchcase(x or '', value.replace('%', '*').replace('_', '?')),
19+
operators.not_like_op: lambda value: lambda x, _: not fnmatch.fnmatchcase(x or '', value.replace('%', '*').replace('_', '?')),
20+
operators.between_op: lambda bounds: lambda x, _: bounds[0] <= x <= bounds[1],
21+
operators.not_between_op: lambda bounds: lambda x, _: not (bounds[0] <= x <= bounds[1]),
22+
operators.in_op: lambda values: lambda x, _: x in values,
23+
operators.not_in_op: lambda values: lambda x, _: x not in values,
24+
}
25+
26+
FUNCTION_RESOLVERS = {
27+
"date": DateResolver,
28+
"json_extract": JsonExtractResolver,
29+
}
1230

1331
class MemoryQuery(Query):
1432
def __init__(self, entities, element):
@@ -52,17 +70,6 @@ def order_by(self, clause):
5270
self._order_by.append(clause)
5371
return self
5472

55-
def _extract_json_value(self, data_dict, path):
56-
# Traverse nested keys for a JSON path like 'ref.abc.xyz'
57-
current = data_dict or {}
58-
for key in path.split('.'):
59-
if not isinstance(current, dict):
60-
return None
61-
current = current.get(key)
62-
if current is None:
63-
return None
64-
return current
65-
6673
def _apply_condition(self, cond, collection):
6774
if not isinstance(cond, BinaryExpression):
6875
raise NotImplementedError(f"Unsupported condition type: {type(cond)}")
@@ -85,34 +92,21 @@ def _apply_condition(self, cond, collection):
8592
else:
8693
raise NotImplementedError(f"Unsupported RHS: {type(rhs)}")
8794

88-
# Handle JSON extraction: func.json_extract(column, path)
89-
if isinstance(cond.left, FunctionElement) and cond.left.name.lower() == 'json_extract':
90-
args = list(cond.left.clauses)
91-
column_expr, path_expr = args[0], args[1]
92-
attr_name = column_expr.name
93-
# Determine raw path string
94-
raw = path_expr.value if hasattr(path_expr, 'value') else str(path_expr).strip('"')
95-
96-
# Strip leading '$.' or '$'
97-
if raw.startswith('$.'):
98-
raw_path = raw[2:]
99-
elif raw.startswith('$'):
100-
raw_path = raw[1:]
101-
else:
102-
raw_path = raw
103-
104-
# Compare nested value
105-
op = cond.operator
106-
return [
107-
item for item in collection
108-
if op(
109-
self._extract_json_value(getattr(item, attr_name), raw_path),
110-
value
111-
)
112-
]
95+
col = cond.left
96+
accessor = lambda obj, attr_name: getattr(obj, attr_name)
97+
98+
if isinstance(cond.left, FunctionElement):
99+
fn_name = cond.left.name.lower()
100+
if fn_name not in FUNCTION_RESOLVERS:
101+
raise NotImplementedError(f"Unsupported LHS function: {fn_name}")
102+
103+
clauses = list(cond.left.clauses)
104+
col = clauses[0]
105+
_class = FUNCTION_RESOLVERS[fn_name]
106+
_resolver = _class(clauses[1:])
107+
accessor = _resolver.accessor
113108

114109
# Extract column name (LHS) and operator
115-
col = cond.left
116110
if not hasattr(col, "name"):
117111
raise NotImplementedError(f"Unsupported LHS: {col}")
118112
attr_name = col.name
@@ -122,43 +116,18 @@ def _apply_condition(self, cond, collection):
122116

123117
op = cond.operator
124118

125-
# special‑case SQL "IS NULL" and "IS NOT NULL"
126-
if value is None:
127-
if op is operators.is_:
128-
op = lambda x, y: x is None
129-
elif op is operators.isnot:
130-
op = lambda x, y: x is not None
131-
132-
elif op is operators.like_op:
133-
fnmatch_pattern = value.replace('%', '*').replace('_', '?')
134-
op = lambda x, y: fnmatch.fnmatchcase(x or '', fnmatch_pattern)
135-
136-
elif op is operators.not_like_op:
137-
fnmatch_pattern = value.replace('%', '*').replace('_', '?')
138-
op = lambda x, y: not fnmatch.fnmatchcase(x or '', fnmatch_pattern)
139-
140-
elif op is operators.between_op:
141-
low, high = value
142-
op = lambda x, _: low <= x <= high
143-
144-
elif op is operators.not_between_op:
145-
low, high = value
146-
op = lambda x, _: not (low <= x <= high)
147-
148-
elif op is operators.in_op:
149-
op = lambda x, y: x in y
150-
151-
elif op is operators.not_in_op:
152-
op = lambda x, y: x not in y
119+
if op in OPERATOR_ADAPTERS:
120+
op = OPERATOR_ADAPTERS[op](value)
153121

154122
return [
155123
item for item in collection
156-
if op(getattr(item, attr_name), value)
124+
if op(accessor(item, attr_name), value)
157125
]
158126

159127
def _execute_query(self):
160128
collection = self.session.store.data.get(self.tablename, [])
161129
if not collection:
130+
logger.debug(f"Table '{self.tablename}' is empty")
162131
return collection
163132

164133
# Apply conditions
@@ -178,8 +147,7 @@ def _execute_query(self):
178147
else:
179148
col = clause
180149

181-
attr = col.name
182-
collection.sort(key=lambda x: getattr(x, attr), reverse=reverse)
150+
collection = sorted(collection, key=lambda x: getattr(x, col.name), reverse=reverse)
183151

184152
# Apply offset
185153
if self._offset is not None:
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from .date import DateResolver
2+
from .json_extract import JsonExtractResolver
3+
4+
__all__ = [
5+
"DateResolver",
6+
"JsonExtractResolver",
7+
]
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
class FunctionResolver:
2+
def __init__(self, clauses):
3+
self.clauses = clauses
4+
5+
def accessor(self, item, attr_name):
6+
raise NotImplementedError
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from .abstract import FunctionResolver
2+
3+
4+
class DateResolver(FunctionResolver):
5+
def accessor(self, item, attr_name):
6+
value = getattr(item, attr_name)
7+
return value.date() if value else None
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from .abstract import FunctionResolver
2+
3+
4+
class JsonExtractResolver(FunctionResolver):
5+
def _extract_json_value(self, data_dict, path):
6+
# Traverse nested keys for a JSON path like 'ref.abc.xyz'
7+
current = data_dict or {}
8+
for key in path.split('.'):
9+
if not isinstance(current, dict):
10+
return None
11+
current = current.get(key)
12+
if current is None:
13+
return None
14+
return current
15+
16+
def accessor(self, item, attr_name):
17+
path_expr = self.clauses[0]
18+
19+
raw = path_expr.value if hasattr(path_expr, 'value') else str(path_expr).strip('"')
20+
21+
# Strip leading '$.' or '$'
22+
if raw.startswith('$.'):
23+
raw_path = raw[2:]
24+
elif raw.startswith('$'):
25+
raw_path = raw[1:]
26+
else:
27+
raw_path = raw
28+
29+
return self._extract_json_value(getattr(item, attr_name), raw_path)

sqlalchemy_memory/base/session.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from sqlalchemy.engine.cursor import SimpleResultMetaData
66

77
from .query import MemoryQuery
8+
from ..logger import logger
89

910
class MemorySession(Session):
1011
def __init__(self, *args, **kwargs):
@@ -163,6 +164,8 @@ def _handle_update(self, statement: Update, **kwargs):
163164
return result
164165

165166
def execute(self, statement, params=None, **kwargs):
167+
#logger.debug(f"Executing query: {statement}")
168+
166169
if isinstance(statement, Select):
167170
return self._handle_select(statement, **kwargs)
168171

@@ -183,17 +186,19 @@ def merge(self, instance, **kwargs):
183186
"""
184187

185188
pk_name = self.store._get_primary_key_name(instance)
186-
id = getattr(instance, pk_name)
187-
existing = self.store.get_by_primary_key(instance, id)
189+
pk_value = getattr(instance, pk_name)
190+
existing = self.store.get_by_primary_key(instance, pk_value)
188191

189192
if existing:
190-
data = {
191-
col.name: getattr(instance, col.name)
192-
for col in instance.__table__.columns
193-
if col.name != pk_name
194-
}
193+
self.store.mark_as_fetched(existing)
194+
195+
for column in instance.__table__.columns:
196+
field = column.name
197+
if field == pk_name:
198+
continue
199+
value = getattr(instance, field)
200+
setattr(existing, field, value)
195201

196-
self.store.update(instance.__tablename__, id, data)
197202
return existing
198203

199204
else:
@@ -207,7 +212,9 @@ def flush(self, objects=None):
207212
pass
208213

209214
def rollback(self, **kwargs):
215+
logger.debug("Rolling back ...")
210216
self.store.rollback()
211217

212218
def commit(self):
219+
logger.debug("Committing ...")
213220
self.store.commit()

0 commit comments

Comments
 (0)