Skip to content

Commit e61802d

Browse files
committed
Handle column default values
1 parent 02469f2 commit e61802d

File tree

4 files changed

+135
-2
lines changed

4 files changed

+135
-2
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ Unlike traditional in-memory solutions like SQLite, `sqlalchemy-memory` fully av
2727

2828
It is also perfect for **applications that need a lightweight, high-performance store** compatible with SQLAlchemy, such as backtesting engines, simulators, or other tools where you don't want to maintain a separate in-memory layer alongside your database models.
2929

30+
Data is kept purely in RAM and is **volatile**: it is **not persisted across application restarts** and is **cleared when the engine is disposed**.
31+
3032
## Features
3133

3234
- **SQLAlchemy 2.0 support**: ORM & Core expressions, sync & async modes

sqlalchemy_memory/base/store.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
from collections import defaultdict
2+
from sqlalchemy import func
3+
from sqlalchemy.sql.elements import TextClause
4+
from datetime import datetime
25

36
from ..logger import logger
47

@@ -80,6 +83,8 @@ def commit(self):
8083
if pk_value in self.data_by_pk[tablename].keys():
8184
raise Exception(f"Cannot have duplicate PK value {pk_value} for table '{tablename}'")
8285

86+
self._apply_column_defaults(obj)
87+
8388
logger.debug(f"Adding {obj} to table '{tablename}'")
8489

8590
self.data[tablename].append(obj)
@@ -155,3 +160,37 @@ def _assign_primary_key_if_needed(self, obj):
155160
self._pk_counter[table] = max(self._pk_counter[table], current_id)
156161

157162
return current_id
163+
164+
def _apply_column_defaults(self, obj):
165+
"""
166+
Apply default and server_default values to an ORM object.
167+
"""
168+
169+
for column in obj.__table__.columns:
170+
attr_name = column.name
171+
current_value = getattr(obj, attr_name, None)
172+
173+
if current_value is not None:
174+
continue
175+
176+
elif column.default is not None:
177+
if callable(column.default.arg):
178+
try:
179+
value = column.default.arg()
180+
except TypeError:
181+
value = column.default.arg(ctx=None)
182+
else:
183+
value = column.default.arg
184+
185+
setattr(obj, attr_name, value)
186+
187+
elif column.server_default is not None:
188+
if isinstance(column.server_default.arg, TextClause):
189+
text_value = column.server_default.arg.text
190+
setattr(obj, attr_name, text_value)
191+
192+
elif isinstance(column.server_default.arg, func.now().__class__):
193+
setattr(obj, attr_name, datetime.utcnow())
194+
195+
else:
196+
raise Exception(f"Unhandled server_default type: {type(column.server_default)}")

tests/models.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from sqlalchemy.orm import declarative_base, mapped_column, Mapped
2-
from sqlalchemy import JSON
2+
from sqlalchemy import JSON, func, text
3+
from datetime import datetime
34

45
Base = declarative_base()
56

@@ -17,8 +18,13 @@ class Product(Base):
1718
id: Mapped[int] = mapped_column(primary_key=True)
1819
active: Mapped[bool] = mapped_column(default=True, index=True)
1920
name: Mapped[str] = mapped_column(nullable=False)
20-
category: Mapped[str] = mapped_column(nullable=False, index=True)
21+
category: Mapped[str] = mapped_column(index=True, server_default=text("unknown"))
2122
data: Mapped[dict] = mapped_column(JSON)
2223

24+
created_at: Mapped[datetime] = mapped_column(
25+
server_default=func.now(),
26+
nullable=False
27+
)
28+
2329
def __repr__(self):
2430
return f"Product(id={self.id} name={self.name})"

tests/test_advanced.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from sqlalchemy import select, func
2+
from datetime import datetime, date
23
import pytest
34

45
from models import Item, Product
@@ -117,6 +118,91 @@ def test_json_extract_filter(self, SessionFactory, pattern, value, expected_ids)
117118
results = session.execute(stmt).scalars().all()
118119
assert {item.id for item in results} == expected_ids
119120

121+
def test_default_values(self, SessionFactory):
122+
dt = datetime(2025, 1, 1, 2, 3, 4)
123+
124+
with SessionFactory() as session:
125+
session.add_all([
126+
Product(id=5, name="foo", category="A"),
127+
Product(name="bar", active=False, created_at=dt),
128+
])
129+
session.commit()
130+
131+
products = session.execute(select(Product)).scalars().all()
132+
133+
assert products[0].active
134+
assert products[0].created_at is not None
135+
assert products[0].category == "A"
136+
137+
assert products[1].id == 6
138+
assert not products[1].active
139+
assert products[1].created_at == dt
140+
assert products[1].category == "unknown"
141+
142+
@pytest.mark.parametrize(
143+
"operator, value, expected_ids",
144+
[
145+
("is", True, {1, 3}),
146+
("is_not", True, {2}),
147+
("is", False, {2}),
148+
("is_not", False, {1, 3}),
149+
]
150+
)
151+
def test_is_filter(self, SessionFactory, operator, value, expected_ids):
152+
with SessionFactory() as session:
153+
session.add_all([
154+
Product(id=1, name="foo", active=True),
155+
Product(id=2, name="bar", active=False),
156+
Product(id=3, name="foobar", active=True),
157+
])
158+
session.commit()
159+
160+
stmt = select(Product)
161+
if operator == "is":
162+
stmt = stmt.where(Product.active.is_(value))
163+
else:
164+
stmt = stmt.where(Product.active.is_not(value))
165+
166+
results = session.execute(stmt).scalars().all()
167+
assert {item.id for item in results} == expected_ids
168+
169+
@pytest.mark.parametrize(
170+
"operator, value, expected_ids",
171+
[
172+
("==", date(2025, 1, 1), {1}),
173+
("!=", date(2025, 1, 2), {1, 3, 4}),
174+
(">", date(2025, 1, 2), {3, 4}),
175+
(">", date(2025, 1, 10), set()),
176+
]
177+
)
178+
def test_date_filter(self, SessionFactory, operator, value, expected_ids):
179+
"""
180+
stmt = select(NetLiquidationValue).where(
181+
func.DATE(NetLiquidationValue.created_at) == _date
182+
)
183+
184+
"""
185+
186+
with SessionFactory() as session:
187+
session.add_all([
188+
Product(id=1, name="foo", created_at=datetime(2025, 1, 1, 1, 1, 1)),
189+
Product(id=2, name="bar", created_at=datetime(2025, 1, 2, 2, 2, 2)),
190+
Product(id=3, name="foobar", created_at=datetime(2025, 1, 3, 3, 3, 3)),
191+
Product(id=4, name="barfoo", created_at=datetime(2025, 1, 4, 4, 4, 4)),
192+
])
193+
session.commit()
194+
195+
stmt = select(Product)
196+
if operator == "==":
197+
stmt = stmt.where(func.DATE(Product.created_at) == value)
198+
elif operator == "!=":
199+
stmt = stmt.where(func.DATE(Product.created_at) != value)
200+
elif operator == ">":
201+
stmt = stmt.where(func.DATE(Product.created_at) > value)
202+
203+
results = session.execute(stmt).scalars().all()
204+
assert {item.id for item in results} == expected_ids
205+
120206
def test_session_inception(self, SessionFactory):
121207
with SessionFactory() as session1:
122208
session1.add(Item(id=1, name="foo"))

0 commit comments

Comments
 (0)