Skip to content

Commit 6d4ac27

Browse files
authored
Support aggregation functions (min/max/sum/count/avg) (#3)
* Use generators to filter datasets * Finish aggregation support * . * Bump version: 0.3.1 → 0.4.0
1 parent 849ba18 commit 6d4ac27

22 files changed

Lines changed: 900 additions & 396 deletions

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 0.3.1
2+
current_version = 0.4.0
33
commit = True
44
tag = True
55

README.md

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,15 @@ Data is kept purely in RAM and is **volatile**: it is **not persisted across app
3636
- **Zero I/O overhead**: pure in‑RAM storage (`dict`/`list` under the hood)
3737
- **Commit/rollback support**
3838
- **Index support**: indexes are recognized and used for faster lookups
39-
- **Merge and `get()` support**: like real SQLAlchemy behavior
39+
- **Lazy query evaluation**: supports generator pipelines and short-circuiting
40+
- `first()`-style queries avoid scanning the full dataset
41+
- Optimized for read-heavy workloads and streaming filters
42+
43+
## Benchmark
44+
45+
Curious how `sqlalchemy-memory` stacks up?
46+
47+
[View Benchmark Results](https://sqlalchemy-memory.readthedocs.io/en/latest/benchmarks.html) comparing `sqlalchemy-memory` to `in-memory SQLite`
4048

4149
## Installation
4250

@@ -48,25 +56,6 @@ pip install sqlalchemy-memory
4856

4957
[See the official documentation for usage examples](https://sqlalchemy-memory.readthedocs.io/en/latest/)
5058

51-
52-
## Status
53-
54-
Currently supports basic functionality equivalent to:
55-
56-
- SQLite in-memory behavior for ORM + Core queries
57-
58-
- `declarative_base()` model support
59-
60-
Coming soon:
61-
62-
- `func.count()` / aggregations
63-
64-
- Joins and relationships (limited)
65-
66-
- Compound indexes
67-
68-
- Better expression support in `update(...).values()` (e.g., +=)
69-
7059
## Testing
7160

7261
Simply run `make tests`

benchmark.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
1-
from sqlalchemy import create_engine, Column, Integer, String, Boolean, select, Index, update, delete
1+
from sqlalchemy import create_engine, Column, Integer, String, Boolean, select, Float, update, delete, bindparam, literal
22
from sqlalchemy.orm import declarative_base, sessionmaker
3+
from sqlalchemy.sql import operators
4+
from sqlalchemy.sql.elements import BinaryExpression
35
from sqlalchemy_memory import MemorySession
46
import argparse
57
import time
68
import random
79
from faker import Faker
810

9-
try:
10-
from sqlalchemy_memory import create_memory_engine
11-
except ImportError:
12-
create_memory_engine = None
1311

12+
random.seed(42)
1413
Base = declarative_base()
1514
fake = Faker()
1615
CATEGORIES = list("ABCDEFGHIJK")
@@ -22,22 +21,46 @@ class Item(Base):
2221
name = Column(String)
2322
active = Column(Boolean, index=True)
2423
category = Column(String, index=True)
24+
price = Column(Float, index=True)
25+
cost = Column(Float)
2526

2627
def generate_items(n):
2728
for _ in range(n):
2829
yield Item(
2930
name=fake.name(),
3031
active=random.choice([True, False]),
31-
category=random.choice(CATEGORIES)
32+
category=random.choice(CATEGORIES),
33+
price=round(random.uniform(5, 500), 2),
34+
cost=round(random.uniform(1, 300), 2),
3235
)
3336

3437
def generate_random_select_query():
3538
clauses = []
39+
3640
if random.random() < 0.5:
37-
clauses.append(Item.active == random.choice([True, False]))
38-
if random.random() < 0.5 or not clauses:
41+
val = random.choice([True, False])
42+
op = random.choice([operators.eq, operators.ne])
43+
clauses.append(BinaryExpression(Item.active, literal(val), op))
44+
45+
if random.random() < 0.7:
3946
subset = random.sample(CATEGORIES, random.randint(1, 4))
40-
clauses.append(Item.category.in_(subset))
47+
op = random.choice([operators.in_op, operators.notin_op])
48+
param = bindparam("category_list", subset, expanding=True)
49+
clauses.append(BinaryExpression(Item.category, param, op))
50+
51+
if random.random() < 0.6:
52+
price_val = round(random.uniform(10, 400), 2)
53+
op = random.choice([operators.gt, operators.lt, operators.le, operators.gt])
54+
clauses.append(BinaryExpression(Item.price, literal(price_val), op))
55+
56+
if random.random() < 0.3:
57+
cost_val = round(random.uniform(10, 200), 2)
58+
op = random.choice([operators.gt, operators.lt, operators.le, operators.gt])
59+
clauses.append(BinaryExpression(Item.cost, literal(cost_val), op))
60+
61+
if not clauses:
62+
clauses.append(Item.active == True)
63+
4164
return select(Item).where(*clauses)
4265

4366
def inserts(Session, count):
@@ -49,15 +72,24 @@ def inserts(Session, count):
4972
print(f"Inserted {count} items in {insert_duration:.2f} seconds.")
5073
return insert_duration
5174

52-
def selects(Session, count):
75+
def selects(Session, count, fetch_type):
5376
queries = [generate_random_select_query() for _ in range(count)]
5477

5578
query_start = time.time()
5679
with Session() as session:
5780
for stmt in queries:
58-
list(session.execute(stmt).scalars())
81+
if fetch_type == "limit":
82+
stmt = stmt.limit(5)
83+
84+
result = session.execute(stmt)
85+
86+
if fetch_type == "first":
87+
result.first()
88+
else:
89+
list(result.scalars())
90+
5991
query_duration = time.time() - query_start
60-
print(f"Executed {count} select queries in {query_duration:.2f} seconds.")
92+
print(f"Executed {count} select queries ({fetch_type}) in {query_duration:.2f} seconds.")
6193
return query_duration
6294

6395
def updates(Session, random_ids):
@@ -105,7 +137,8 @@ def run_benchmark(db_type="sqlite", count=100_000):
105137
Base.metadata.create_all(engine)
106138

107139
elapsed = inserts(Session, count)
108-
elapsed += selects(Session, 500)
140+
elapsed += selects(Session, 500, fetch_type="all")
141+
elapsed += selects(Session, 500, fetch_type="limit")
109142

110143
random_ids = random.sample(range(1, count + 1), 500)
111144
elapsed += updates(Session, random_ids)

docs/benchmarks.rst

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ This benchmark compares `sqlalchemy-memory` to `in-memory SQLite` using 20,000 i
55

66
As the results show, `sqlalchemy-memory` **excels in read-heavy workloads**, delivering significantly faster query performance. While SQLite is marginally faster on the small `limit(5)` queries, `sqlalchemy-memory` matches or beats it on inserts, updates and deletes, and its overall runtime remains substantially lower, making it a strong choice for prototyping and simulation.
77

8+
`Check the benchmark script on GitHub <https://github.com/rundef/sqlalchemy-memory/blob/main/benchmark.py>`_
9+
810
.. list-table::
911
:header-rows: 1
1012
:widths: 25 25 25
@@ -13,17 +15,20 @@ As the results show, `sqlalchemy-memory` **excels in read-heavy workloads**, del
1315
- SQLite (in-memory)
1416
- sqlalchemy-memory
1517
* - Insert
16-
- 3.17 sec
17-
- 2.70 sec
18-
* - 500 Select Queries
19-
- 26.37 sec
20-
- 2.94 sec
18+
- 3.30 sec
19+
- **3.10 sec**
20+
* - 500 Select Queries (all())
21+
- 30.07 sec
22+
- **4.14 sec**
23+
* - 500 Select Queries (limit(5))
24+
- **0.24 sec**
25+
- 0.30 sec
2126
* - 500 Updates
22-
- 0.26 sec
23-
- 1.12 sec
27+
- 0.25 sec
28+
- **0.19 sec**
2429
* - 500 Deletes
25-
- 0.09 sec
26-
- 0.90 sec
27-
* - **Total Runtime**
28-
- **29.89 sec**
29-
- **7.66 sec**
30+
- **0.09 sec**
31+
- **0.09 sec**
32+
* - *Total Runtime*
33+
- 33.95 sec
34+
- **7.81 sec**

docs/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ Welcome to sqlalchemy-memory's documentation!
33

44
`sqlalchemy-memory` is a pure in-memory backend for SQLAlchemy 2.0 that supports both sync and async modes, with full compatibility for SQLAlchemy Core and ORM.
55

6+
📦 GitHub: https://github.com/rundef/sqlalchemy-memory
7+
68
Quickstart: sync example
79
------------------------
810

docs/query.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Supported Functions
1515

1616
- `DATE(column)`
1717
- `func.json_extract(col, '$.expr')`
18+
- Aggregation functions: `func.count()` / `func.sum()` / `func.min()` / `func.max()` / `func.avg()`
1819

1920
Indexes
2021
-------

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "sqlalchemy-memory"
7-
version = "0.3.1"
7+
version = "0.4.0"
88
dependencies = [
99
"sqlalchemy>=2.0,<3.0",
1010
"sortedcontainers>=2.4.0"

sqlalchemy_memory/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@
66
"AsyncMemorySession",
77
]
88

9-
__version__ = '0.3.1'
9+
__version__ = '0.4.0'

sqlalchemy_memory/base/indexes.py

Lines changed: 51 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections import defaultdict
22
from sortedcontainers import SortedDict
3-
from typing import Any, List
3+
from typing import Any, List, Generator
4+
from itertools import chain
45
from sqlalchemy.sql import operators
56

67
from ..helpers.ordered_set import OrderedSet
@@ -108,62 +109,84 @@ def on_update(self, obj, updates):
108109
self.hash_index.add(tablename, indexname, new_value, obj)
109110
self.range_index.add(tablename, indexname, new_value, obj)
110111

111-
def query(self, collection, tablename, colname, operator, value):
112+
def query(self, collection, tablename, colname, operator, value, collection_is_full_table=False):
112113
indexname = self._column_to_index(tablename, colname)
113114
if not indexname:
114115
return None
115116

116-
# Use hash index for = / != / IN / NOT IN operators
117117
if operator == operators.eq:
118118
result = self.hash_index.query(tablename, indexname, value)
119-
return list(set(result) & set(collection))
119+
if collection_is_full_table:
120+
return result
121+
return (item for item in collection if item in result)
120122

121123
elif operator == operators.ne:
122-
# All values except the given one
123124
excluded = self.hash_index.query(tablename, indexname, value)
124-
return list(set(collection) - set(excluded))
125+
return (item for item in collection if item not in excluded)
125126

126127
elif operator == operators.in_op:
127-
result = []
128-
for v in value:
129-
result.extend(self.hash_index.query(tablename, indexname, v))
130-
return list(set(result) & set(collection))
128+
result = chain.from_iterable(
129+
self.hash_index.query(tablename, indexname, v) for v in value
130+
)
131+
if collection_is_full_table:
132+
return result
133+
result = set(result)
134+
return (item for item in collection if item in result)
131135

132136
elif operator == operators.notin_op:
133-
excluded = []
134-
for v in value:
135-
excluded.extend(self.hash_index.query(tablename, indexname, v))
136-
return list(set(collection) - set(excluded))
137+
excluded = set(chain.from_iterable(
138+
self.hash_index.query(tablename, indexname, v) for v in value
139+
))
140+
return (item for item in collection if item not in excluded)
137141

138-
# Use range index
139-
if operator == operators.gt:
142+
elif operator == operators.gt:
140143
result = self.range_index.query(tablename, indexname, gt=value)
141-
return list(set(result) & set(collection))
144+
if collection_is_full_table:
145+
return result
146+
result = set(result)
147+
return (item for item in collection if item in result)
142148

143149
elif operator == operators.ge:
144150
result = self.range_index.query(tablename, indexname, gte=value)
145-
return list(set(result) & set(collection))
151+
if collection_is_full_table:
152+
return result
153+
result = set(result)
154+
return (item for item in collection if item in result)
146155

147156
elif operator == operators.lt:
148157
result = self.range_index.query(tablename, indexname, lt=value)
149-
return list(set(result) & set(collection))
158+
if collection_is_full_table:
159+
return result
160+
result = set(result)
161+
return (item for item in collection if item in result)
150162

151163
elif operator == operators.le:
152164
result = self.range_index.query(tablename, indexname, lte=value)
153-
return list(set(result) & set(collection))
165+
if collection_is_full_table:
166+
return result
167+
result = set(result)
168+
return (item for item in collection if item in result)
154169

155170
elif operator == operators.between_op and isinstance(value, (tuple, list)) and len(value) == 2:
156171
result = self.range_index.query(tablename, indexname, gte=value[0], lte=value[1])
157-
return list(set(result) & set(collection))
172+
if collection_is_full_table:
173+
return result
174+
result = set(result)
175+
return (item for item in collection if item in result)
158176

159177
elif operator == operators.not_between_op and isinstance(value, (tuple, list)) and len(value) == 2:
160-
in_range = self.range_index.query(tablename, indexname, gte=value[0], lte=value[1])
161-
return list(set(collection) - set(in_range))
178+
in_range = set(self.range_index.query(tablename, indexname, gte=value[0], lte=value[1]))
179+
return (item for item in collection if item not in in_range)
162180

163181

164182
def get_selectivity(self, tablename, colname, operator, value, total_count):
165183
"""
166-
Estimate selectivity: higher means worst filtering.
184+
Estimate the selectivity of a single WHERE condition.
185+
186+
This method is used to rank or sort WHERE conditions by their estimated
187+
filtering power. A lower selectivity value indicates that the condition
188+
is expected to filter out more rows (i.e., fewer rows remain after applying it),
189+
making it more selective.
167190
"""
168191

169192
indexname = self._column_to_index(tablename, colname)
@@ -220,7 +243,7 @@ def remove(self, tablename: str, indexname: str, value: Any, obj: Any):
220243
del self.index[tablename][indexname][value]
221244

222245
def query(self, tablename: str, indexname: str, value: Any) -> List[Any]:
223-
return list(self.index[tablename][indexname].get(value, []))
246+
return self.index[tablename][indexname].get(value, [])
224247

225248

226249
class RangeIndex:
@@ -255,7 +278,7 @@ def remove(self, tablename: str, indexname: str, value: Any, obj: Any):
255278
except ValueError:
256279
pass
257280

258-
def query(self, tablename: str, indexname: str, gt=None, gte=None, lt=None, lte=None) -> List[Any]:
281+
def query(self, tablename: str, indexname: str, gt=None, gte=None, lt=None, lte=None) -> Generator:
259282
sd = self.index[tablename][indexname]
260283

261284
# Define range bounds
@@ -264,14 +287,10 @@ def query(self, tablename: str, indexname: str, gt=None, gte=None, lt=None, lte=
264287
inclusive_min = gte is not None
265288
inclusive_max = lte is not None
266289

267-
irange = sd.irange(
290+
keys = sd.irange(
268291
minimum=min_key,
269292
maximum=max_key,
270293
inclusive=(inclusive_min, inclusive_max)
271294
)
272295

273-
result = []
274-
for key in irange:
275-
result.extend(sd[key])
276-
277-
return result
296+
return chain.from_iterable(sd[key] for key in keys)

0 commit comments

Comments
 (0)